package test

import (
	"context"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/http/httptest"
	"regexp"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap/zaptest"

	"github.com/kms/api-key-service/internal/auth"
	"github.com/kms/api-key-service/internal/domain"
)

// mockSAMLMetadata returns a mock SAML IdP metadata XML document.
//
// NOTE(review): the literal below contains only the certificate placeholder
// text; the XML markup appears to have been stripped. The tests consuming it
// expect an EntityID of "https://idp.example.com", a non-empty
// IDPSSODescriptor.SingleSignOnService list, and an SSO location that makes
// the generated auth URL contain "https://idp.example.com/sso" — restore the
// markup accordingly.
func mockSAMLMetadata() string {
	return ` MIICertificateData `
}

// mockSAMLResponse returns a mock SAML response XML with current timestamps:
// issued now, with NotBefore five minutes ago and NotOnOrAfter sixty minutes
// from now.
//
// NOTE(review): the format string below contains no formatting verbs even
// though six values are passed, so `go vet`'s printf check (which `go test`
// runs by default) rejects this call. The XML markup appears to have been
// stripped; restore it with six %s verbs consuming, in order: issueInstant,
// issueInstant, notOnOrAfter, notBefore, notOnOrAfter, issueInstant.
// TestSAMLResponseValidation additionally expects the markup to contain
// NotBefore/NotOnOrAfter attributes and the status value
// "urn:oasis:names:tc:SAML:2.0:status:Success".
func mockSAMLResponse() string {
	now := time.Now().UTC()
	issueInstant := now.Format(time.RFC3339)
	notBefore := now.Add(-5 * time.Minute).Format(time.RFC3339)
	notOnOrAfter := now.Add(60 * time.Minute).Format(time.RFC3339)

	return fmt.Sprintf(` https://idp.example.com https://idp.example.com user@example.com https://sp.example.com user@example.com Test User Test User admin,user urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport `, issueInstant, issueInstant, notOnOrAfter, notBefore, notOnOrAfter, issueInstant)
}

// TestSAMLProvider_GetMetadata checks fetching IdP metadata from the URL in
// SAML_IDP_METADATA_URL: success, unconfigured URL, non-200 response and
// unparsable body.
func TestSAMLProvider_GetMetadata(t *testing.T) {
	tests := []struct {
		name string
		// metadataURL: empty means "not configured". When non-empty and
		// serverStatus > 0, the mock server's URL is used instead.
		metadataURL    string
		serverResponse string
		serverStatus   int
		expectError    bool
		errorContains  string
	}{
		{
			name:           "successful metadata fetch",
			metadataURL:    "https://idp.example.com/.well-known/saml-metadata",
			serverResponse: mockSAMLMetadata(),
			serverStatus:   http.StatusOK,
			expectError:    false,
		},
		{
			name:          "missing metadata URL",
			metadataURL:   "",
			expectError:   true,
			errorContains: "SAML_IDP_METADATA_URL not configured",
		},
		{
			name:          "server error",
			metadataURL:   "https://idp.example.com/.well-known/saml-metadata",
			serverStatus:  http.StatusInternalServerError,
			expectError:   true,
			errorContains: "returned status 500",
		},
		{
			name:           "invalid XML",
			metadataURL:    "https://idp.example.com/.well-known/saml-metadata",
			serverResponse: "invalid xml",
			serverStatus:   http.StatusOK,
			expectError:    true,
			errorContains:  "Failed to parse SAML metadata",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Work on a local copy instead of mutating the table entry.
			metadataURL := tt.metadataURL
			if metadataURL != "" && tt.serverStatus > 0 {
				server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(tt.serverStatus)
					if tt.serverResponse != "" {
						// Write errors are ignored; the test asserts on the
						// provider's result instead.
						_, _ = w.Write([]byte(tt.serverResponse))
					}
				}))
				defer server.Close()
				metadataURL = server.URL
			}

			cfg := NewTestConfig()
			cfg.values["SAML_IDP_METADATA_URL"] = metadataURL

			provider, err := auth.NewSAMLProvider(cfg, zaptest.NewLogger(t))
			require.NoError(t, err)

			metadata, err := provider.GetMetadata(context.Background())
			if tt.expectError {
				// require (not assert): a nil error must end the subtest
				// rather than panic on err.Error() below.
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Nil(t, metadata)
				return
			}

			require.NoError(t, err)
			require.NotNil(t, metadata)
			assert.Equal(t, "https://idp.example.com", metadata.EntityID)
			assert.NotEmpty(t, metadata.IDPSSODescriptor.SingleSignOnService)
		})
	}
}

// TestSAMLProvider_GenerateAuthRequest checks auth-request URL generation and
// its dependence on SAML_SP_ENTITY_ID and SAML_SP_ACS_URL being configured.
func TestSAMLProvider_GenerateAuthRequest(t *testing.T) {
	tests := []struct {
		name          string
		spEntityID    string
		acsURL        string
		relayState    string
		expectError   bool
		errorContains string
	}{
		{
			name:       "successful auth request generation",
			spEntityID: "https://sp.example.com",
			acsURL:     "https://sp.example.com/acs",
			relayState: "test-relay-state",
		},
		{
			name:          "missing SP entity ID",
			spEntityID:    "",
			acsURL:        "https://sp.example.com/acs",
			expectError:   true,
			errorContains: "SAML_SP_ENTITY_ID not configured",
		},
		{
			name:          "missing ACS URL",
			spEntityID:    "https://sp.example.com",
			acsURL:        "",
			expectError:   true,
			errorContains: "SAML_SP_ACS_URL not configured",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Mock IdP serving the metadata document at SAML_IDP_METADATA_URL.
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte(mockSAMLMetadata()))
			}))
			defer server.Close()

			cfg := NewTestConfig()
			cfg.values["SAML_IDP_METADATA_URL"] = server.URL
			cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
			cfg.values["SAML_SP_ACS_URL"] = tt.acsURL

			provider, err := auth.NewSAMLProvider(cfg, zaptest.NewLogger(t))
			require.NoError(t, err)

			authURL, requestID, err := provider.GenerateAuthRequest(context.Background(), tt.relayState)
			if tt.expectError {
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Empty(t, authURL)
				assert.Empty(t, requestID)
				return
			}

			require.NoError(t, err)
			assert.NotEmpty(t, authURL)
			assert.NotEmpty(t, requestID)
			assert.Contains(t, authURL, "https://idp.example.com/sso")
			assert.Contains(t, authURL, "SAMLRequest=")
			if tt.relayState != "" {
				// Raw substring match: relay states containing characters that
				// need URL escaping would need the escaped form here.
				assert.Contains(t, authURL, "RelayState="+tt.relayState)
			}
		})
	}
}

// TestSAMLProvider_ProcessSAMLResponse checks decoding, parsing and audience
// validation of a base64-encoded SAML response, and the resulting auth context.
func TestSAMLProvider_ProcessSAMLResponse(t *testing.T) {
	tests := []struct {
		name              string
		samlResponse      string
		expectedRequestID string
		spEntityID        string
		expectError       bool
		errorContains     string
		expectedUserID    string
		expectedEmail     string
		expectedName      string
		expectedRoles     []string
	}{
		{
			name:              "successful SAML response processing",
			samlResponse:      base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
			expectedRequestID: "_request_id",
			spEntityID:        "https://sp.example.com",
			expectedUserID:    "user@example.com",
			expectedEmail:     "user@example.com",
			expectedName:      "Test User",
			expectedRoles:     []string{"admin", "user"},
		},
		{
			name:          "invalid base64 encoding",
			samlResponse:  "invalid-base64",
			expectError:   true,
			errorContains: "Failed to decode SAML response",
		},
		{
			name:          "invalid XML",
			samlResponse:  base64.StdEncoding.EncodeToString([]byte("invalid xml")),
			expectError:   true,
			errorContains: "Failed to parse SAML response",
		},
		{
			name:          "audience mismatch",
			samlResponse:  base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
			spEntityID:    "https://wrong-sp.example.com",
			expectError:   true,
			errorContains: "audience mismatch",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := NewTestConfig()
			cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID

			provider, err := auth.NewSAMLProvider(cfg, zaptest.NewLogger(t))
			require.NoError(t, err)

			authContext, err := provider.ProcessSAMLResponse(context.Background(), tt.samlResponse, tt.expectedRequestID)
			if tt.expectError {
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Nil(t, authContext)
				return
			}

			require.NoError(t, err)
			// require: the field accesses below would panic on a nil context.
			require.NotNil(t, authContext)
			assert.Equal(t, tt.expectedUserID, authContext.UserID)
			assert.Equal(t, domain.TokenTypeUser, authContext.TokenType)

			if tt.expectedEmail != "" {
				assert.Equal(t, tt.expectedEmail, authContext.Claims["email"])
			}
			if tt.expectedName != "" {
				assert.Equal(t, tt.expectedName, authContext.Claims["name"])
			}
			// Roles from the assertion are expected to surface as Permissions.
			if len(tt.expectedRoles) > 0 {
				assert.Equal(t, tt.expectedRoles, authContext.Permissions)
			}
		})
	}
}

// TestSAMLProvider_GenerateServiceProviderMetadata checks SP metadata
// generation and its dependence on SAML_SP_ENTITY_ID and SAML_SP_ACS_URL.
func TestSAMLProvider_GenerateServiceProviderMetadata(t *testing.T) {
	tests := []struct {
		name          string
		spEntityID    string
		acsURL        string
		expectError   bool
		errorContains string
	}{
		{
			name:       "successful SP metadata generation",
			spEntityID: "https://sp.example.com",
			acsURL:     "https://sp.example.com/acs",
		},
		{
			name:          "missing SP entity ID",
			spEntityID:    "",
			acsURL:        "https://sp.example.com/acs",
			expectError:   true,
			errorContains: "SAML_SP_ENTITY_ID not configured",
		},
		{
			name:          "missing ACS URL",
			spEntityID:    "https://sp.example.com",
			acsURL:        "",
			expectError:   true,
			errorContains: "SAML_SP_ACS_URL not configured",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := NewTestConfig()
			cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
			cfg.values["SAML_SP_ACS_URL"] = tt.acsURL

			provider, err := auth.NewSAMLProvider(cfg, zaptest.NewLogger(t))
			require.NoError(t, err)

			metadata, err := provider.GenerateServiceProviderMetadata()
			if tt.expectError {
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Empty(t, metadata)
				return
			}

			require.NoError(t, err)
			assert.NotEmpty(t, metadata)
			assert.Contains(t, metadata, tt.spEntityID)
			assert.Contains(t, metadata, tt.acsURL)
			assert.Contains(t, metadata, "EntityDescriptor")
			assert.Contains(t, metadata, "SPSSODescriptor")
		})
	}
}

// Benchmark tests for SAML operations

// BenchmarkSAMLProvider_ProcessSAMLResponse measures ProcessSAMLResponse on
// the base64-encoded mock response.
func BenchmarkSAMLProvider_ProcessSAMLResponse(b *testing.B) {
	cfg := NewTestConfig()
	cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"

	provider, err := auth.NewSAMLProvider(cfg, zaptest.NewLogger(b))
	require.NoError(b, err)

	samlResponse := base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse()))
	ctx := context.Background()

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id"); err != nil {
			b.Fatal(err)
		}
	}
}

// BenchmarkSAMLProvider_GenerateAuthRequest measures GenerateAuthRequest
// against a mock IdP metadata server.
func BenchmarkSAMLProvider_GenerateAuthRequest(b *testing.B) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(mockSAMLMetadata()))
	}))
	defer server.Close()

	cfg := NewTestConfig()
	cfg.values["SAML_IDP_METADATA_URL"] = server.URL
	cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
	cfg.values["SAML_SP_ACS_URL"] = "https://sp.example.com/acs"

	provider, err := auth.NewSAMLProvider(cfg, zaptest.NewLogger(b))
	require.NoError(b, err)

	ctx := context.Background()

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, _, err := provider.GenerateAuthRequest(ctx, "test-relay-state"); err != nil {
			b.Fatal(err)
		}
	}
}

// TestSAMLResponseValidation checks that ProcessSAMLResponse rejects a mock
// response whose validity window or status has been altered.
func TestSAMLResponseValidation(t *testing.T) {
	tests := []struct {
		name          string
		modifyXML     func(string) string
		expectError   bool
		errorContains string
	}{
		{
			name: "expired assertion",
			modifyXML: func(xml string) string {
				// Replace all NotOnOrAfter timestamps with a time that is
				// always in the past.
				pastTime := "2020-01-01T13:00:00Z"
				re := regexp.MustCompile(`NotOnOrAfter="[^"]*"`)
				return re.ReplaceAllString(xml, `NotOnOrAfter="`+pastTime+`"`)
			},
			expectError:   true,
			errorContains: "assertion has expired",
		},
		{
			name: "assertion not yet valid",
			modifyXML: func(xml string) string {
				// Replace all NotBefore timestamps with a time one day ahead.
				// Computed from now, not hard-coded, so the case still yields
				// a not-yet-valid assertion after any fixed date has passed.
				futureTime := time.Now().UTC().Add(24 * time.Hour).Format(time.RFC3339)
				re := regexp.MustCompile(`NotBefore="[^"]*"`)
				return re.ReplaceAllString(xml, `NotBefore="`+futureTime+`"`)
			},
			expectError:   true,
			errorContains: "assertion not yet valid",
		},
		{
			name: "failed status",
			modifyXML: func(xml string) string {
				// NOTE(review): only effective if mockSAMLResponse contains
				// the Success status URI.
				return strings.ReplaceAll(xml, "urn:oasis:names:tc:SAML:2.0:status:Success", "urn:oasis:names:tc:SAML:2.0:status:AuthnFailed")
			},
			expectError:   true,
			errorContains: "SAML authentication failed",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := NewTestConfig()
			cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"

			provider, err := auth.NewSAMLProvider(cfg, zaptest.NewLogger(t))
			require.NoError(t, err)

			modifiedXML := tt.modifyXML(mockSAMLResponse())
			samlResponse := base64.StdEncoding.EncodeToString([]byte(modifiedXML))

			authContext, err := provider.ProcessSAMLResponse(context.Background(), samlResponse, "_request_id")
			if tt.expectError {
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Nil(t, authContext)
				return
			}

			require.NoError(t, err)
			assert.NotNil(t, authContext)
		})
	}
}