533 lines
16 KiB
Go
533 lines
16 KiB
Go
package test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"regexp"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap/zaptest"
|
|
|
|
"github.com/kms/api-key-service/internal/auth"
|
|
"github.com/kms/api-key-service/internal/domain"
|
|
)
|
|
|
|
// mockIdPMetadataXML is a static SAML 2.0 IdP metadata document describing the
// entity https://idp.example.com. It has one signing KeyDescriptor (with a
// placeholder certificate), an HTTP-Redirect SingleSignOnService at /sso and
// an HTTP-Redirect SingleLogoutService at /slo.
const mockIdPMetadataXML = `<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="https://idp.example.com">
<md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:KeyDescriptor use="signing">
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
<ds:X509Data>
<ds:X509Certificate>MIICertificateData</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/slo"/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>`

// mockSAMLMetadata returns the mock IdP metadata XML served by the test
// httptest servers in this file.
func mockSAMLMetadata() string {
	return mockIdPMetadataXML
}
|
|
|
|
// mockSAMLResponse returns an unsigned mock SAML 2.0 Response whose validity
// window is anchored at the current time. See mockSAMLResponseAt for the
// exact timestamps and contents.
func mockSAMLResponse() string {
	return mockSAMLResponseAt(time.Now())
}

// mockSAMLResponseAt returns an unsigned mock SAML 2.0 Response with every
// timestamp derived from now, converted to UTC and formatted as RFC 3339:
//   - Response/Assertion IssueInstant and AuthnStatement AuthnInstant: now
//   - Conditions NotBefore: now - 5m
//   - Conditions and SubjectConfirmationData NotOnOrAfter: now + 60m
//
// Other fixed contents:
//   - Issuer: https://idp.example.com
//   - Audience: https://sp.example.com
//   - InResponseTo: "_request_id"
//   - NameID: user@example.com
//   - Attributes: email, name, givenname, surname, and the role "admin,user"
//
// Taking now as a parameter makes the output deterministic in tests.
func mockSAMLResponseAt(now time.Time) string {
	now = now.UTC()
	issueInstant := now.Format(time.RFC3339)
	notBefore := now.Add(-5 * time.Minute).Format(time.RFC3339)
	notOnOrAfter := now.Add(60 * time.Minute).Format(time.RFC3339)

	// Explicit argument indexes instead of a positional list with repeats:
	// %[1]s = issueInstant, %[2]s = notBefore, %[3]s = notOnOrAfter.
	return fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
  xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
  ID="_response_id" Version="2.0" IssueInstant="%[1]s"
  Destination="https://sp.example.com/acs" InResponseTo="_request_id">
  <saml:Issuer>https://idp.example.com</saml:Issuer>
  <samlp:Status>
    <samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
  </samlp:Status>
  <saml:Assertion ID="_assertion_id" Version="2.0" IssueInstant="%[1]s">
    <saml:Issuer>https://idp.example.com</saml:Issuer>
    <saml:Subject>
      <saml:NameID Format="urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress">user@example.com</saml:NameID>
      <saml:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
        <saml:SubjectConfirmationData InResponseTo="_request_id" NotOnOrAfter="%[3]s" Recipient="https://sp.example.com/acs"/>
      </saml:SubjectConfirmation>
    </saml:Subject>
    <saml:Conditions NotBefore="%[2]s" NotOnOrAfter="%[3]s">
      <saml:AudienceRestriction>
        <saml:Audience>https://sp.example.com</saml:Audience>
      </saml:AudienceRestriction>
    </saml:Conditions>
    <saml:AttributeStatement>
      <saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress">
        <saml:AttributeValue>user@example.com</saml:AttributeValue>
      </saml:Attribute>
      <saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name">
        <saml:AttributeValue>Test User</saml:AttributeValue>
      </saml:Attribute>
      <saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname">
        <saml:AttributeValue>Test</saml:AttributeValue>
      </saml:Attribute>
      <saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname">
        <saml:AttributeValue>User</saml:AttributeValue>
      </saml:Attribute>
      <saml:Attribute Name="http://schemas.microsoft.com/ws/2008/06/identity/claims/role">
        <saml:AttributeValue>admin,user</saml:AttributeValue>
      </saml:Attribute>
    </saml:AttributeStatement>
    <saml:AuthnStatement AuthnInstant="%[1]s" SessionIndex="_session_index">
      <saml:AuthnContext>
        <saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</saml:AuthnContextClassRef>
      </saml:AuthnContext>
    </saml:AuthnStatement>
  </saml:Assertion>
</samlp:Response>`, issueInstant, notBefore, notOnOrAfter)
}
|
|
|
|
func TestSAMLProvider_GetMetadata(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
metadataURL string
|
|
serverResponse string
|
|
serverStatus int
|
|
expectError bool
|
|
errorContains string
|
|
}{
|
|
{
|
|
name: "successful metadata fetch",
|
|
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
|
serverResponse: mockSAMLMetadata(),
|
|
serverStatus: http.StatusOK,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "missing metadata URL",
|
|
metadataURL: "",
|
|
expectError: true,
|
|
errorContains: "SAML_IDP_METADATA_URL not configured",
|
|
},
|
|
{
|
|
name: "server error",
|
|
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
|
serverStatus: http.StatusInternalServerError,
|
|
expectError: true,
|
|
errorContains: "returned status 500",
|
|
},
|
|
{
|
|
name: "invalid XML",
|
|
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
|
serverResponse: "invalid xml",
|
|
serverStatus: http.StatusOK,
|
|
expectError: true,
|
|
errorContains: "Failed to parse SAML metadata",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create mock HTTP server
|
|
var server *httptest.Server
|
|
if tt.metadataURL != "" && tt.serverStatus > 0 {
|
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(tt.serverStatus)
|
|
if tt.serverResponse != "" {
|
|
w.Write([]byte(tt.serverResponse))
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
tt.metadataURL = server.URL
|
|
}
|
|
|
|
// Create config
|
|
cfg := NewTestConfig()
|
|
cfg.values["SAML_IDP_METADATA_URL"] = tt.metadataURL
|
|
|
|
// Create SAML provider
|
|
logger := zaptest.NewLogger(t)
|
|
provider, err := auth.NewSAMLProvider(cfg, logger)
|
|
require.NoError(t, err)
|
|
|
|
// Test GetMetadata
|
|
ctx := context.Background()
|
|
metadata, err := provider.GetMetadata(ctx)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
if tt.errorContains != "" {
|
|
assert.Contains(t, err.Error(), tt.errorContains)
|
|
}
|
|
assert.Nil(t, metadata)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, metadata)
|
|
assert.Equal(t, "https://idp.example.com", metadata.EntityID)
|
|
assert.NotEmpty(t, metadata.IDPSSODescriptor.SingleSignOnService)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSAMLProvider_GenerateAuthRequest(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
spEntityID string
|
|
acsURL string
|
|
relayState string
|
|
expectError bool
|
|
errorContains string
|
|
}{
|
|
{
|
|
name: "successful auth request generation",
|
|
spEntityID: "https://sp.example.com",
|
|
acsURL: "https://sp.example.com/acs",
|
|
relayState: "test-relay-state",
|
|
},
|
|
{
|
|
name: "missing SP entity ID",
|
|
spEntityID: "",
|
|
acsURL: "https://sp.example.com/acs",
|
|
expectError: true,
|
|
errorContains: "SAML_SP_ENTITY_ID not configured",
|
|
},
|
|
{
|
|
name: "missing ACS URL",
|
|
spEntityID: "https://sp.example.com",
|
|
acsURL: "",
|
|
expectError: true,
|
|
errorContains: "SAML_SP_ACS_URL not configured",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create mock HTTP server for metadata
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(mockSAMLMetadata()))
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create config
|
|
cfg := NewTestConfig()
|
|
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
|
|
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
|
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
|
|
|
|
// Create SAML provider
|
|
logger := zaptest.NewLogger(t)
|
|
provider, err := auth.NewSAMLProvider(cfg, logger)
|
|
require.NoError(t, err)
|
|
|
|
// Test GenerateAuthRequest
|
|
ctx := context.Background()
|
|
authURL, requestID, err := provider.GenerateAuthRequest(ctx, tt.relayState)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
if tt.errorContains != "" {
|
|
assert.Contains(t, err.Error(), tt.errorContains)
|
|
}
|
|
assert.Empty(t, authURL)
|
|
assert.Empty(t, requestID)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotEmpty(t, authURL)
|
|
assert.NotEmpty(t, requestID)
|
|
assert.Contains(t, authURL, "https://idp.example.com/sso")
|
|
assert.Contains(t, authURL, "SAMLRequest=")
|
|
if tt.relayState != "" {
|
|
assert.Contains(t, authURL, "RelayState="+tt.relayState)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSAMLProvider_ProcessSAMLResponse(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
samlResponse string
|
|
expectedRequestID string
|
|
spEntityID string
|
|
expectError bool
|
|
errorContains string
|
|
expectedUserID string
|
|
expectedEmail string
|
|
expectedName string
|
|
expectedRoles []string
|
|
}{
|
|
{
|
|
name: "successful SAML response processing",
|
|
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
|
|
expectedRequestID: "_request_id",
|
|
spEntityID: "https://sp.example.com",
|
|
expectedUserID: "user@example.com",
|
|
expectedEmail: "user@example.com",
|
|
expectedName: "Test User",
|
|
expectedRoles: []string{"admin", "user"},
|
|
},
|
|
{
|
|
name: "invalid base64 encoding",
|
|
samlResponse: "invalid-base64",
|
|
expectError: true,
|
|
errorContains: "Failed to decode SAML response",
|
|
},
|
|
{
|
|
name: "invalid XML",
|
|
samlResponse: base64.StdEncoding.EncodeToString([]byte("invalid xml")),
|
|
expectError: true,
|
|
errorContains: "Failed to parse SAML response",
|
|
},
|
|
{
|
|
name: "audience mismatch",
|
|
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
|
|
spEntityID: "https://wrong-sp.example.com",
|
|
expectError: true,
|
|
errorContains: "audience mismatch",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create config
|
|
cfg := NewTestConfig()
|
|
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
|
|
|
// Create SAML provider
|
|
logger := zaptest.NewLogger(t)
|
|
provider, err := auth.NewSAMLProvider(cfg, logger)
|
|
require.NoError(t, err)
|
|
|
|
// Test ProcessSAMLResponse
|
|
ctx := context.Background()
|
|
authContext, err := provider.ProcessSAMLResponse(ctx, tt.samlResponse, tt.expectedRequestID)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
if tt.errorContains != "" {
|
|
assert.Contains(t, err.Error(), tt.errorContains)
|
|
}
|
|
assert.Nil(t, authContext)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, authContext)
|
|
assert.Equal(t, tt.expectedUserID, authContext.UserID)
|
|
assert.Equal(t, domain.TokenTypeUser, authContext.TokenType)
|
|
|
|
// Check claims
|
|
if tt.expectedEmail != "" {
|
|
assert.Equal(t, tt.expectedEmail, authContext.Claims["email"])
|
|
}
|
|
if tt.expectedName != "" {
|
|
assert.Equal(t, tt.expectedName, authContext.Claims["name"])
|
|
}
|
|
|
|
// Check permissions/roles
|
|
if len(tt.expectedRoles) > 0 {
|
|
assert.Equal(t, tt.expectedRoles, authContext.Permissions)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSAMLProvider_GenerateServiceProviderMetadata(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
spEntityID string
|
|
acsURL string
|
|
expectError bool
|
|
errorContains string
|
|
}{
|
|
{
|
|
name: "successful SP metadata generation",
|
|
spEntityID: "https://sp.example.com",
|
|
acsURL: "https://sp.example.com/acs",
|
|
},
|
|
{
|
|
name: "missing SP entity ID",
|
|
spEntityID: "",
|
|
acsURL: "https://sp.example.com/acs",
|
|
expectError: true,
|
|
errorContains: "SAML_SP_ENTITY_ID not configured",
|
|
},
|
|
{
|
|
name: "missing ACS URL",
|
|
spEntityID: "https://sp.example.com",
|
|
acsURL: "",
|
|
expectError: true,
|
|
errorContains: "SAML_SP_ACS_URL not configured",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create config
|
|
cfg := NewTestConfig()
|
|
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
|
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
|
|
|
|
// Create SAML provider
|
|
logger := zaptest.NewLogger(t)
|
|
provider, err := auth.NewSAMLProvider(cfg, logger)
|
|
require.NoError(t, err)
|
|
|
|
// Test GenerateServiceProviderMetadata
|
|
metadata, err := provider.GenerateServiceProviderMetadata()
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
if tt.errorContains != "" {
|
|
assert.Contains(t, err.Error(), tt.errorContains)
|
|
}
|
|
assert.Empty(t, metadata)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotEmpty(t, metadata)
|
|
assert.Contains(t, metadata, tt.spEntityID)
|
|
assert.Contains(t, metadata, tt.acsURL)
|
|
assert.Contains(t, metadata, "EntityDescriptor")
|
|
assert.Contains(t, metadata, "SPSSODescriptor")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Benchmark tests for SAML operations
|
|
func BenchmarkSAMLProvider_ProcessSAMLResponse(b *testing.B) {
|
|
// Create config
|
|
cfg := NewTestConfig()
|
|
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
|
|
|
// Create SAML provider
|
|
logger := zaptest.NewLogger(b)
|
|
provider, err := auth.NewSAMLProvider(cfg, logger)
|
|
require.NoError(b, err)
|
|
|
|
// Prepare SAML response
|
|
samlResponse := base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse()))
|
|
ctx := context.Background()
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
_, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkSAMLProvider_GenerateAuthRequest(b *testing.B) {
|
|
// Create mock HTTP server for metadata
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(mockSAMLMetadata()))
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create config
|
|
cfg := NewTestConfig()
|
|
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
|
|
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
|
cfg.values["SAML_SP_ACS_URL"] = "https://sp.example.com/acs"
|
|
|
|
// Create SAML provider
|
|
logger := zaptest.NewLogger(b)
|
|
provider, err := auth.NewSAMLProvider(cfg, logger)
|
|
require.NoError(b, err)
|
|
|
|
ctx := context.Background()
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
_, _, err := provider.GenerateAuthRequest(ctx, "test-relay-state")
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Test helper functions
|
|
func TestSAMLResponseValidation(t *testing.T) {
|
|
// Test various SAML response validation scenarios
|
|
tests := []struct {
|
|
name string
|
|
modifyXML func(string) string
|
|
expectError bool
|
|
errorContains string
|
|
}{
|
|
{
|
|
name: "expired assertion",
|
|
modifyXML: func(xml string) string {
|
|
// Replace all NotOnOrAfter timestamps with past time
|
|
pastTime := "2020-01-01T13:00:00Z"
|
|
re := regexp.MustCompile(`NotOnOrAfter="[^"]*"`)
|
|
return re.ReplaceAllString(xml, `NotOnOrAfter="`+pastTime+`"`)
|
|
},
|
|
expectError: true,
|
|
errorContains: "assertion has expired",
|
|
},
|
|
{
|
|
name: "assertion not yet valid",
|
|
modifyXML: func(xml string) string {
|
|
// Replace all NotBefore timestamps with future time
|
|
futureTime := "2030-01-01T11:55:00Z"
|
|
re := regexp.MustCompile(`NotBefore="[^"]*"`)
|
|
return re.ReplaceAllString(xml, `NotBefore="`+futureTime+`"`)
|
|
},
|
|
expectError: true,
|
|
errorContains: "assertion not yet valid",
|
|
},
|
|
{
|
|
name: "failed status",
|
|
modifyXML: func(xml string) string {
|
|
return strings.ReplaceAll(xml,
|
|
"urn:oasis:names:tc:SAML:2.0:status:Success",
|
|
"urn:oasis:names:tc:SAML:2.0:status:AuthnFailed")
|
|
},
|
|
expectError: true,
|
|
errorContains: "SAML authentication failed",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create config
|
|
cfg := NewTestConfig()
|
|
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
|
|
|
// Create SAML provider
|
|
logger := zaptest.NewLogger(t)
|
|
provider, err := auth.NewSAMLProvider(cfg, logger)
|
|
require.NoError(t, err)
|
|
|
|
// Modify SAML response
|
|
modifiedXML := tt.modifyXML(mockSAMLResponse())
|
|
samlResponse := base64.StdEncoding.EncodeToString([]byte(modifiedXML))
|
|
|
|
// Test ProcessSAMLResponse
|
|
ctx := context.Background()
|
|
authContext, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
if tt.errorContains != "" {
|
|
assert.Contains(t, err.Error(), tt.errorContains)
|
|
}
|
|
assert.Nil(t, authContext)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, authContext)
|
|
}
|
|
})
|
|
}
|
|
}
|