// File: skybridge/test/saml_test.go
// Last modified: 2025-08-22 18:57:40 -04:00
// Size: 533 lines, 16 KiB (Go)

package test
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"github.com/kms/api-key-service/internal/auth"
"github.com/kms/api-key-service/internal/domain"
)
// mockSAMLMetadata returns a canned SAML 2.0 IdP metadata document for tests.
//
// The document declares entityID "https://idp.example.com", a single signing
// KeyDescriptor whose X509Certificate is a placeholder string (not a real
// certificate), and HTTP-Redirect SingleSignOnService / SingleLogoutService
// endpoints at https://idp.example.com/sso and https://idp.example.com/slo.
func mockSAMLMetadata() string {
	const idpMetadataXML = `<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="https://idp.example.com">
<md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:KeyDescriptor use="signing">
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
<ds:X509Data>
<ds:X509Certificate>MIICertificateData</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/slo"/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>`
	return idpMetadataXML
}
// mockSAMLResponse returns a mock SAML response XML with current timestamps
func mockSAMLResponse() string {
now := time.Now().UTC()
issueInstant := now.Format(time.RFC3339)
notBefore := now.Add(-5 * time.Minute).Format(time.RFC3339)
notOnOrAfter := now.Add(60 * time.Minute).Format(time.RFC3339)
return fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
ID="_response_id" Version="2.0" IssueInstant="%s"
Destination="https://sp.example.com/acs" InResponseTo="_request_id">
<saml:Issuer>https://idp.example.com</saml:Issuer>
<samlp:Status>
<samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
</samlp:Status>
<saml:Assertion ID="_assertion_id" Version="2.0" IssueInstant="%s">
<saml:Issuer>https://idp.example.com</saml:Issuer>
<saml:Subject>
<saml:NameID Format="urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress">user@example.com</saml:NameID>
<saml:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
<saml:SubjectConfirmationData InResponseTo="_request_id" NotOnOrAfter="%s" Recipient="https://sp.example.com/acs"/>
</saml:SubjectConfirmation>
</saml:Subject>
<saml:Conditions NotBefore="%s" NotOnOrAfter="%s">
<saml:AudienceRestriction>
<saml:Audience>https://sp.example.com</saml:Audience>
</saml:AudienceRestriction>
</saml:Conditions>
<saml:AttributeStatement>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress">
<saml:AttributeValue>user@example.com</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name">
<saml:AttributeValue>Test User</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname">
<saml:AttributeValue>Test</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname">
<saml:AttributeValue>User</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.microsoft.com/ws/2008/06/identity/claims/role">
<saml:AttributeValue>admin,user</saml:AttributeValue>
</saml:Attribute>
</saml:AttributeStatement>
<saml:AuthnStatement AuthnInstant="%s" SessionIndex="_session_index">
<saml:AuthnContext>
<saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</saml:AuthnContextClassRef>
</saml:AuthnContext>
</saml:AuthnStatement>
</saml:Assertion>
</samlp:Response>`, issueInstant, issueInstant, notOnOrAfter, notBefore, notOnOrAfter, issueInstant)
}
func TestSAMLProvider_GetMetadata(t *testing.T) {
tests := []struct {
name string
metadataURL string
serverResponse string
serverStatus int
expectError bool
errorContains string
}{
{
name: "successful metadata fetch",
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
serverResponse: mockSAMLMetadata(),
serverStatus: http.StatusOK,
expectError: false,
},
{
name: "missing metadata URL",
metadataURL: "",
expectError: true,
errorContains: "SAML_IDP_METADATA_URL not configured",
},
{
name: "server error",
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
serverStatus: http.StatusInternalServerError,
expectError: true,
errorContains: "returned status 500",
},
{
name: "invalid XML",
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
serverResponse: "invalid xml",
serverStatus: http.StatusOK,
expectError: true,
errorContains: "Failed to parse SAML metadata",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock HTTP server
var server *httptest.Server
if tt.metadataURL != "" && tt.serverStatus > 0 {
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.serverStatus)
if tt.serverResponse != "" {
w.Write([]byte(tt.serverResponse))
}
}))
defer server.Close()
tt.metadataURL = server.URL
}
// Create config
cfg := NewTestConfig()
cfg.values["SAML_IDP_METADATA_URL"] = tt.metadataURL
// Create SAML provider
logger := zaptest.NewLogger(t)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(t, err)
// Test GetMetadata
ctx := context.Background()
metadata, err := provider.GetMetadata(ctx)
if tt.expectError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Nil(t, metadata)
} else {
assert.NoError(t, err)
assert.NotNil(t, metadata)
assert.Equal(t, "https://idp.example.com", metadata.EntityID)
assert.NotEmpty(t, metadata.IDPSSODescriptor.SingleSignOnService)
}
})
}
}
// TestSAMLProvider_GenerateAuthRequest checks GenerateAuthRequest with a
// provider whose IdP metadata is served by a local httptest server.
//
// In the success case the returned URL must point at the SSO location
// advertised by mockSAMLMetadata (https://idp.example.com/sso) and carry a
// SAMLRequest parameter and, when given, the relay state. The provider must
// also reject configurations missing SAML_SP_ENTITY_ID or SAML_SP_ACS_URL.
func TestSAMLProvider_GenerateAuthRequest(t *testing.T) {
	tests := []struct {
		name          string
		spEntityID    string
		acsURL        string
		relayState    string
		expectError   bool
		errorContains string
	}{
		{
			name:       "successful auth request generation",
			spEntityID: "https://sp.example.com",
			acsURL:     "https://sp.example.com/acs",
			relayState: "test-relay-state",
		},
		{
			name:          "missing SP entity ID",
			spEntityID:    "",
			acsURL:        "https://sp.example.com/acs",
			expectError:   true,
			errorContains: "SAML_SP_ENTITY_ID not configured",
		},
		{
			name:          "missing ACS URL",
			spEntityID:    "https://sp.example.com",
			acsURL:        "",
			expectError:   true,
			errorContains: "SAML_SP_ACS_URL not configured",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				// Write errors are ignored; the assertions below check the outcome.
				_, _ = w.Write([]byte(mockSAMLMetadata()))
			}))
			defer server.Close()

			cfg := NewTestConfig()
			cfg.values["SAML_IDP_METADATA_URL"] = server.URL
			cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
			cfg.values["SAML_SP_ACS_URL"] = tt.acsURL

			logger := zaptest.NewLogger(t)
			provider, err := auth.NewSAMLProvider(cfg, logger)
			require.NoError(t, err)

			authURL, requestID, err := provider.GenerateAuthRequest(context.Background(), tt.relayState)
			if tt.expectError {
				// require, not assert: err.Error() below panics if err is nil.
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Empty(t, authURL)
				assert.Empty(t, requestID)
				return
			}

			require.NoError(t, err)
			assert.NotEmpty(t, authURL)
			assert.NotEmpty(t, requestID)
			assert.Contains(t, authURL, "https://idp.example.com/sso")
			assert.Contains(t, authURL, "SAMLRequest=")
			if tt.relayState != "" {
				assert.Contains(t, authURL, "RelayState="+tt.relayState)
			}
		})
	}
}
// TestSAMLProvider_ProcessSAMLResponse checks ProcessSAMLResponse on a valid,
// base64-encoded mock response and on inputs that must be rejected.
//
// The audience-mismatch case passes the same expected request ID as the
// success case ("_request_id", matching the mock's InResponseTo), so a wrong
// audience is the only defect in that input. Otherwise a request-ID check, if
// the provider performs one, could fail first and hide the audience error.
func TestSAMLProvider_ProcessSAMLResponse(t *testing.T) {
	tests := []struct {
		name              string
		samlResponse      string
		expectedRequestID string
		spEntityID        string
		expectError       bool
		errorContains     string
		expectedUserID    string
		expectedEmail     string
		expectedName      string
		expectedRoles     []string
	}{
		{
			name:              "successful SAML response processing",
			samlResponse:      base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
			expectedRequestID: "_request_id",
			spEntityID:        "https://sp.example.com",
			expectedUserID:    "user@example.com",
			expectedEmail:     "user@example.com",
			expectedName:      "Test User",
			expectedRoles:     []string{"admin", "user"},
		},
		{
			name:          "invalid base64 encoding",
			samlResponse:  "invalid-base64",
			expectError:   true,
			errorContains: "Failed to decode SAML response",
		},
		{
			name:          "invalid XML",
			samlResponse:  base64.StdEncoding.EncodeToString([]byte("invalid xml")),
			expectError:   true,
			errorContains: "Failed to parse SAML response",
		},
		{
			name:              "audience mismatch",
			samlResponse:      base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
			expectedRequestID: "_request_id",
			spEntityID:        "https://wrong-sp.example.com",
			expectError:       true,
			errorContains:     "audience mismatch",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := NewTestConfig()
			cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID

			logger := zaptest.NewLogger(t)
			provider, err := auth.NewSAMLProvider(cfg, logger)
			require.NoError(t, err)

			authContext, err := provider.ProcessSAMLResponse(context.Background(), tt.samlResponse, tt.expectedRequestID)
			if tt.expectError {
				// require, not assert: err.Error() below panics if err is nil.
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Nil(t, authContext)
				return
			}

			require.NoError(t, err)
			// require, not assert: the field accesses below panic if authContext is nil.
			require.NotNil(t, authContext)
			assert.Equal(t, tt.expectedUserID, authContext.UserID)
			assert.Equal(t, domain.TokenTypeUser, authContext.TokenType)

			if tt.expectedEmail != "" {
				assert.Equal(t, tt.expectedEmail, authContext.Claims["email"])
			}
			if tt.expectedName != "" {
				assert.Equal(t, tt.expectedName, authContext.Claims["name"])
			}
			// The mock's comma-separated role attribute is expected as permissions.
			if len(tt.expectedRoles) > 0 {
				assert.Equal(t, tt.expectedRoles, authContext.Permissions)
			}
		})
	}
}
// TestSAMLProvider_GenerateServiceProviderMetadata checks that
// GenerateServiceProviderMetadata emits a document containing an
// EntityDescriptor with an SPSSODescriptor that embeds the configured SP
// entity ID and ACS URL. It also checks that the call fails when either
// setting is missing.
func TestSAMLProvider_GenerateServiceProviderMetadata(t *testing.T) {
	tests := []struct {
		name          string
		spEntityID    string
		acsURL        string
		expectError   bool
		errorContains string
	}{
		{
			name:       "successful SP metadata generation",
			spEntityID: "https://sp.example.com",
			acsURL:     "https://sp.example.com/acs",
		},
		{
			name:          "missing SP entity ID",
			spEntityID:    "",
			acsURL:        "https://sp.example.com/acs",
			expectError:   true,
			errorContains: "SAML_SP_ENTITY_ID not configured",
		},
		{
			name:          "missing ACS URL",
			spEntityID:    "https://sp.example.com",
			acsURL:        "",
			expectError:   true,
			errorContains: "SAML_SP_ACS_URL not configured",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := NewTestConfig()
			cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
			cfg.values["SAML_SP_ACS_URL"] = tt.acsURL

			logger := zaptest.NewLogger(t)
			provider, err := auth.NewSAMLProvider(cfg, logger)
			require.NoError(t, err)

			metadata, err := provider.GenerateServiceProviderMetadata()
			if tt.expectError {
				// require, not assert: err.Error() below panics if err is nil.
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Empty(t, metadata)
				return
			}

			require.NoError(t, err)
			assert.NotEmpty(t, metadata)
			assert.Contains(t, metadata, tt.spEntityID)
			assert.Contains(t, metadata, tt.acsURL)
			assert.Contains(t, metadata, "EntityDescriptor")
			assert.Contains(t, metadata, "SPSSODescriptor")
		})
	}
}
// Benchmark tests for SAML operations.

// BenchmarkSAMLProvider_ProcessSAMLResponse measures ProcessSAMLResponse on a
// base64-encoded mockSAMLResponse with the matching SP entity ID and request
// ID. The response is built once during setup, before the timer is reset.
// Its NotOnOrAfter lies 60 minutes after construction, so every iteration of
// a normal benchmark run sees the same still-valid input.
func BenchmarkSAMLProvider_ProcessSAMLResponse(b *testing.B) {
	cfg := NewTestConfig()
	cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"

	provider, err := auth.NewSAMLProvider(cfg, zaptest.NewLogger(b))
	require.NoError(b, err)

	encoded := base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse()))
	ctx := context.Background()

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, err := provider.ProcessSAMLResponse(ctx, encoded, "_request_id"); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkSAMLProvider_GenerateAuthRequest measures GenerateAuthRequest with a
// fully configured provider (IdP metadata URL, SP entity ID, ACS URL) and a
// fixed relay state. The IdP metadata endpoint is a local httptest server
// serving mockSAMLMetadata. All setup happens before the timer is reset.
func BenchmarkSAMLProvider_GenerateAuthRequest(b *testing.B) {
	idp := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(mockSAMLMetadata()))
	}))
	defer idp.Close()

	cfg := NewTestConfig()
	cfg.values["SAML_IDP_METADATA_URL"] = idp.URL
	cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
	cfg.values["SAML_SP_ACS_URL"] = "https://sp.example.com/acs"

	provider, err := auth.NewSAMLProvider(cfg, zaptest.NewLogger(b))
	require.NoError(b, err)

	ctx := context.Background()

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, _, err := provider.GenerateAuthRequest(ctx, "test-relay-state"); err != nil {
			b.Fatal(err)
		}
	}
}
// TestSAMLResponseValidation feeds ProcessSAMLResponse variants of
// mockSAMLResponse that each break one validation rule and checks the
// resulting error:
//   - a validity window lying entirely in the past (expired),
//   - a validity window lying entirely in the future (not yet valid),
//   - a non-success top-level status code.
//
// Timestamps are derived from the current time instead of being hard-coded.
// A fixed "future" date (previously 2030-01-01) would eventually fall into the
// past and make the not-yet-valid case fail spuriously.
func TestSAMLResponseValidation(t *testing.T) {
	now := time.Now().UTC()

	// setTimeAttr rewrites the value of every attr="..." occurrence in doc to
	// ts formatted with time.RFC3339, the format mockSAMLResponse uses.
	setTimeAttr := func(doc, attr string, ts time.Time) string {
		re := regexp.MustCompile(regexp.QuoteMeta(attr) + `="[^"]*"`)
		return re.ReplaceAllLiteralString(doc, attr+`="`+ts.Format(time.RFC3339)+`"`)
	}

	tests := []struct {
		name          string
		modifyXML     func(string) string
		expectError   bool
		errorContains string
	}{
		{
			name: "expired assertion",
			// Whole window in the past: NotBefore now-48h, every NotOnOrAfter now-24h.
			modifyXML: func(doc string) string {
				doc = setTimeAttr(doc, "NotBefore", now.Add(-48*time.Hour))
				return setTimeAttr(doc, "NotOnOrAfter", now.Add(-24*time.Hour))
			},
			expectError:   true,
			errorContains: "assertion has expired",
		},
		{
			name: "assertion not yet valid",
			// Whole window in the future: NotBefore now+24h, every NotOnOrAfter now+48h.
			modifyXML: func(doc string) string {
				doc = setTimeAttr(doc, "NotBefore", now.Add(24*time.Hour))
				return setTimeAttr(doc, "NotOnOrAfter", now.Add(48*time.Hour))
			},
			expectError:   true,
			errorContains: "assertion not yet valid",
		},
		{
			name: "failed status",
			modifyXML: func(doc string) string {
				return strings.ReplaceAll(doc,
					"urn:oasis:names:tc:SAML:2.0:status:Success",
					"urn:oasis:names:tc:SAML:2.0:status:AuthnFailed")
			},
			expectError:   true,
			errorContains: "SAML authentication failed",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := NewTestConfig()
			cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"

			logger := zaptest.NewLogger(t)
			provider, err := auth.NewSAMLProvider(cfg, logger)
			require.NoError(t, err)

			modifiedXML := tt.modifyXML(mockSAMLResponse())
			samlResponse := base64.StdEncoding.EncodeToString([]byte(modifiedXML))

			authContext, err := provider.ProcessSAMLResponse(context.Background(), samlResponse, "_request_id")
			if tt.expectError {
				// require, not assert: err.Error() below panics if err is nil.
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Nil(t, authContext)
				return
			}

			require.NoError(t, err)
			assert.NotNil(t, authContext)
		})
	}
}