v2
This commit is contained in:
532
test/saml_test.go
Normal file
532
test/saml_test.go
Normal file
@ -0,0 +1,532 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
)
|
||||
|
||||
// mockSAMLMetadata returns a mock SAML IdP metadata XML
|
||||
func mockSAMLMetadata() string {
|
||||
return `<?xml version="1.0" encoding="UTF-8"?>
|
||||
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="https://idp.example.com">
|
||||
<md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<md:KeyDescriptor use="signing">
|
||||
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
|
||||
<ds:X509Data>
|
||||
<ds:X509Certificate>MIICertificateData</ds:X509Certificate>
|
||||
</ds:X509Data>
|
||||
</ds:KeyInfo>
|
||||
</md:KeyDescriptor>
|
||||
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
|
||||
<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/slo"/>
|
||||
</md:IDPSSODescriptor>
|
||||
</md:EntityDescriptor>`
|
||||
}
|
||||
|
||||
// mockSAMLResponse returns a mock SAML response XML with current timestamps
|
||||
func mockSAMLResponse() string {
|
||||
now := time.Now().UTC()
|
||||
issueInstant := now.Format(time.RFC3339)
|
||||
notBefore := now.Add(-5 * time.Minute).Format(time.RFC3339)
|
||||
notOnOrAfter := now.Add(60 * time.Minute).Format(time.RFC3339)
|
||||
|
||||
return fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
|
||||
<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||
xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
|
||||
ID="_response_id" Version="2.0" IssueInstant="%s"
|
||||
Destination="https://sp.example.com/acs" InResponseTo="_request_id">
|
||||
<saml:Issuer>https://idp.example.com</saml:Issuer>
|
||||
<samlp:Status>
|
||||
<samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
|
||||
</samlp:Status>
|
||||
<saml:Assertion ID="_assertion_id" Version="2.0" IssueInstant="%s">
|
||||
<saml:Issuer>https://idp.example.com</saml:Issuer>
|
||||
<saml:Subject>
|
||||
<saml:NameID Format="urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress">user@example.com</saml:NameID>
|
||||
<saml:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
|
||||
<saml:SubjectConfirmationData InResponseTo="_request_id" NotOnOrAfter="%s" Recipient="https://sp.example.com/acs"/>
|
||||
</saml:SubjectConfirmation>
|
||||
</saml:Subject>
|
||||
<saml:Conditions NotBefore="%s" NotOnOrAfter="%s">
|
||||
<saml:AudienceRestriction>
|
||||
<saml:Audience>https://sp.example.com</saml:Audience>
|
||||
</saml:AudienceRestriction>
|
||||
</saml:Conditions>
|
||||
<saml:AttributeStatement>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress">
|
||||
<saml:AttributeValue>user@example.com</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name">
|
||||
<saml:AttributeValue>Test User</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname">
|
||||
<saml:AttributeValue>Test</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname">
|
||||
<saml:AttributeValue>User</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.microsoft.com/ws/2008/06/identity/claims/role">
|
||||
<saml:AttributeValue>admin,user</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
</saml:AttributeStatement>
|
||||
<saml:AuthnStatement AuthnInstant="%s" SessionIndex="_session_index">
|
||||
<saml:AuthnContext>
|
||||
<saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</saml:AuthnContextClassRef>
|
||||
</saml:AuthnContext>
|
||||
</saml:AuthnStatement>
|
||||
</saml:Assertion>
|
||||
</samlp:Response>`, issueInstant, issueInstant, notOnOrAfter, notBefore, notOnOrAfter, issueInstant)
|
||||
}
|
||||
|
||||
func TestSAMLProvider_GetMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadataURL string
|
||||
serverResponse string
|
||||
serverStatus int
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "successful metadata fetch",
|
||||
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
||||
serverResponse: mockSAMLMetadata(),
|
||||
serverStatus: http.StatusOK,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing metadata URL",
|
||||
metadataURL: "",
|
||||
expectError: true,
|
||||
errorContains: "SAML_IDP_METADATA_URL not configured",
|
||||
},
|
||||
{
|
||||
name: "server error",
|
||||
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
||||
serverStatus: http.StatusInternalServerError,
|
||||
expectError: true,
|
||||
errorContains: "returned status 500",
|
||||
},
|
||||
{
|
||||
name: "invalid XML",
|
||||
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
||||
serverResponse: "invalid xml",
|
||||
serverStatus: http.StatusOK,
|
||||
expectError: true,
|
||||
errorContains: "Failed to parse SAML metadata",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock HTTP server
|
||||
var server *httptest.Server
|
||||
if tt.metadataURL != "" && tt.serverStatus > 0 {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tt.serverStatus)
|
||||
if tt.serverResponse != "" {
|
||||
w.Write([]byte(tt.serverResponse))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
tt.metadataURL = server.URL
|
||||
}
|
||||
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_IDP_METADATA_URL"] = tt.metadataURL
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test GetMetadata
|
||||
ctx := context.Background()
|
||||
metadata, err := provider.GetMetadata(ctx)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Nil(t, metadata)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, metadata)
|
||||
assert.Equal(t, "https://idp.example.com", metadata.EntityID)
|
||||
assert.NotEmpty(t, metadata.IDPSSODescriptor.SingleSignOnService)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAMLProvider_GenerateAuthRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spEntityID string
|
||||
acsURL string
|
||||
relayState string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "successful auth request generation",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
relayState: "test-relay-state",
|
||||
},
|
||||
{
|
||||
name: "missing SP entity ID",
|
||||
spEntityID: "",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ENTITY_ID not configured",
|
||||
},
|
||||
{
|
||||
name: "missing ACS URL",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ACS_URL not configured",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock HTTP server for metadata
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockSAMLMetadata()))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
||||
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test GenerateAuthRequest
|
||||
ctx := context.Background()
|
||||
authURL, requestID, err := provider.GenerateAuthRequest(ctx, tt.relayState)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Empty(t, authURL)
|
||||
assert.Empty(t, requestID)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, authURL)
|
||||
assert.NotEmpty(t, requestID)
|
||||
assert.Contains(t, authURL, "https://idp.example.com/sso")
|
||||
assert.Contains(t, authURL, "SAMLRequest=")
|
||||
if tt.relayState != "" {
|
||||
assert.Contains(t, authURL, "RelayState="+tt.relayState)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAMLProvider_ProcessSAMLResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
samlResponse string
|
||||
expectedRequestID string
|
||||
spEntityID string
|
||||
expectError bool
|
||||
errorContains string
|
||||
expectedUserID string
|
||||
expectedEmail string
|
||||
expectedName string
|
||||
expectedRoles []string
|
||||
}{
|
||||
{
|
||||
name: "successful SAML response processing",
|
||||
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
|
||||
expectedRequestID: "_request_id",
|
||||
spEntityID: "https://sp.example.com",
|
||||
expectedUserID: "user@example.com",
|
||||
expectedEmail: "user@example.com",
|
||||
expectedName: "Test User",
|
||||
expectedRoles: []string{"admin", "user"},
|
||||
},
|
||||
{
|
||||
name: "invalid base64 encoding",
|
||||
samlResponse: "invalid-base64",
|
||||
expectError: true,
|
||||
errorContains: "Failed to decode SAML response",
|
||||
},
|
||||
{
|
||||
name: "invalid XML",
|
||||
samlResponse: base64.StdEncoding.EncodeToString([]byte("invalid xml")),
|
||||
expectError: true,
|
||||
errorContains: "Failed to parse SAML response",
|
||||
},
|
||||
{
|
||||
name: "audience mismatch",
|
||||
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
|
||||
spEntityID: "https://wrong-sp.example.com",
|
||||
expectError: true,
|
||||
errorContains: "audience mismatch",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test ProcessSAMLResponse
|
||||
ctx := context.Background()
|
||||
authContext, err := provider.ProcessSAMLResponse(ctx, tt.samlResponse, tt.expectedRequestID)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Nil(t, authContext)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, authContext)
|
||||
assert.Equal(t, tt.expectedUserID, authContext.UserID)
|
||||
assert.Equal(t, domain.TokenTypeUser, authContext.TokenType)
|
||||
|
||||
// Check claims
|
||||
if tt.expectedEmail != "" {
|
||||
assert.Equal(t, tt.expectedEmail, authContext.Claims["email"])
|
||||
}
|
||||
if tt.expectedName != "" {
|
||||
assert.Equal(t, tt.expectedName, authContext.Claims["name"])
|
||||
}
|
||||
|
||||
// Check permissions/roles
|
||||
if len(tt.expectedRoles) > 0 {
|
||||
assert.Equal(t, tt.expectedRoles, authContext.Permissions)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAMLProvider_GenerateServiceProviderMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spEntityID string
|
||||
acsURL string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "successful SP metadata generation",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
},
|
||||
{
|
||||
name: "missing SP entity ID",
|
||||
spEntityID: "",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ENTITY_ID not configured",
|
||||
},
|
||||
{
|
||||
name: "missing ACS URL",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ACS_URL not configured",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
||||
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test GenerateServiceProviderMetadata
|
||||
metadata, err := provider.GenerateServiceProviderMetadata()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Empty(t, metadata)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, metadata)
|
||||
assert.Contains(t, metadata, tt.spEntityID)
|
||||
assert.Contains(t, metadata, tt.acsURL)
|
||||
assert.Contains(t, metadata, "EntityDescriptor")
|
||||
assert.Contains(t, metadata, "SPSSODescriptor")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for SAML operations
|
||||
func BenchmarkSAMLProvider_ProcessSAMLResponse(b *testing.B) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(b)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(b, err)
|
||||
|
||||
// Prepare SAML response
|
||||
samlResponse := base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse()))
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSAMLProvider_GenerateAuthRequest(b *testing.B) {
|
||||
// Create mock HTTP server for metadata
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockSAMLMetadata()))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
||||
cfg.values["SAML_SP_ACS_URL"] = "https://sp.example.com/acs"
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(b)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(b, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, err := provider.GenerateAuthRequest(ctx, "test-relay-state")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
func TestSAMLResponseValidation(t *testing.T) {
|
||||
// Test various SAML response validation scenarios
|
||||
tests := []struct {
|
||||
name string
|
||||
modifyXML func(string) string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "expired assertion",
|
||||
modifyXML: func(xml string) string {
|
||||
// Replace all NotOnOrAfter timestamps with past time
|
||||
pastTime := "2020-01-01T13:00:00Z"
|
||||
re := regexp.MustCompile(`NotOnOrAfter="[^"]*"`)
|
||||
return re.ReplaceAllString(xml, `NotOnOrAfter="`+pastTime+`"`)
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "assertion has expired",
|
||||
},
|
||||
{
|
||||
name: "assertion not yet valid",
|
||||
modifyXML: func(xml string) string {
|
||||
// Replace all NotBefore timestamps with future time
|
||||
futureTime := "2030-01-01T11:55:00Z"
|
||||
re := regexp.MustCompile(`NotBefore="[^"]*"`)
|
||||
return re.ReplaceAllString(xml, `NotBefore="`+futureTime+`"`)
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "assertion not yet valid",
|
||||
},
|
||||
{
|
||||
name: "failed status",
|
||||
modifyXML: func(xml string) string {
|
||||
return strings.ReplaceAll(xml,
|
||||
"urn:oasis:names:tc:SAML:2.0:status:Success",
|
||||
"urn:oasis:names:tc:SAML:2.0:status:AuthnFailed")
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "SAML authentication failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Modify SAML response
|
||||
modifiedXML := tt.modifyXML(mockSAMLResponse())
|
||||
samlResponse := base64.StdEncoding.EncodeToString([]byte(modifiedXML))
|
||||
|
||||
// Test ProcessSAMLResponse
|
||||
ctx := context.Background()
|
||||
authContext, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Nil(t, authContext)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, authContext)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user