package test
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"github.com/kms/api-key-service/internal/auth"
"github.com/kms/api-key-service/internal/domain"
)
// mockSAMLMetadata returns mock SAML 2.0 IdP metadata XML. It is an
// EntityDescriptor with entityID "https://idp.example.com" whose
// IDPSSODescriptor advertises a placeholder signing certificate
// ("MIICertificateData" — not a real X.509 certificate) and an HTTP-Redirect
// SingleSignOnService at "https://idp.example.com/sso". Callers assert on both
// the entity ID and the SSO location.
//
// NOTE(review): the element layout follows the OASIS SAML 2.0 metadata schema;
// confirm it matches the structs auth.SAMLProvider unmarshals metadata into.
func mockSAMLMetadata() string {
	return `<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="https://idp.example.com">
	<md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
		<md:KeyDescriptor use="signing">
			<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
				<ds:X509Data>
					<ds:X509Certificate>MIICertificateData</ds:X509Certificate>
				</ds:X509Data>
			</ds:KeyInfo>
		</md:KeyDescriptor>
		<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
	</md:IDPSSODescriptor>
</md:EntityDescriptor>
`
}
// mockSAMLResponse returns an unsigned mock SAML 2.0 Response XML document
// whose timestamps are derived from the current UTC time.
//
// The response has issuer "https://idp.example.com", is InResponseTo
// "_request_id", and has a Success status. Its single assertion carries:
//   - NameID "user@example.com"
//   - audience "https://sp.example.com"
//   - attributes email, name, firstName, lastName and roles ("admin,user")
//
// Validity window: NotBefore is now-5m and both NotOnOrAfter attributes are
// now+60m, so the response is valid when called and stays valid for an hour.
//
// The six %s verbs are filled in argument order: Response IssueInstant,
// Assertion IssueInstant, SubjectConfirmationData NotOnOrAfter, Conditions
// NotBefore, Conditions NotOnOrAfter, AuthnStatement AuthnInstant.
//
// NOTE(review): the attribute Names are a reconstruction; confirm they match
// the attribute mapping in auth.SAMLProvider.
func mockSAMLResponse() string {
	now := time.Now().UTC()
	issueInstant := now.Format(time.RFC3339)
	notBefore := now.Add(-5 * time.Minute).Format(time.RFC3339)
	notOnOrAfter := now.Add(60 * time.Minute).Format(time.RFC3339)
	return fmt.Sprintf(`<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="_response_id" Version="2.0" IssueInstant="%s" Destination="https://sp.example.com/acs" InResponseTo="_request_id">
	<saml:Issuer>https://idp.example.com</saml:Issuer>
	<samlp:Status>
		<samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
	</samlp:Status>
	<saml:Assertion ID="_assertion_id" Version="2.0" IssueInstant="%s">
		<saml:Issuer>https://idp.example.com</saml:Issuer>
		<saml:Subject>
			<saml:NameID Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress">user@example.com</saml:NameID>
			<saml:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
				<saml:SubjectConfirmationData InResponseTo="_request_id" NotOnOrAfter="%s" Recipient="https://sp.example.com/acs"/>
			</saml:SubjectConfirmation>
		</saml:Subject>
		<saml:Conditions NotBefore="%s" NotOnOrAfter="%s">
			<saml:AudienceRestriction>
				<saml:Audience>https://sp.example.com</saml:Audience>
			</saml:AudienceRestriction>
		</saml:Conditions>
		<saml:AttributeStatement>
			<saml:Attribute Name="email"><saml:AttributeValue>user@example.com</saml:AttributeValue></saml:Attribute>
			<saml:Attribute Name="name"><saml:AttributeValue>Test User</saml:AttributeValue></saml:Attribute>
			<saml:Attribute Name="firstName"><saml:AttributeValue>Test</saml:AttributeValue></saml:Attribute>
			<saml:Attribute Name="lastName"><saml:AttributeValue>User</saml:AttributeValue></saml:Attribute>
			<saml:Attribute Name="roles"><saml:AttributeValue>admin,user</saml:AttributeValue></saml:Attribute>
		</saml:AttributeStatement>
		<saml:AuthnStatement AuthnInstant="%s">
			<saml:AuthnContext>
				<saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</saml:AuthnContextClassRef>
			</saml:AuthnContext>
		</saml:AuthnStatement>
	</saml:Assertion>
</samlp:Response>
`, issueInstant, issueInstant, notOnOrAfter, notBefore, notOnOrAfter, issueInstant)
}
// TestSAMLProvider_GetMetadata checks that GetMetadata fetches and parses IdP
// metadata from SAML_IDP_METADATA_URL. It also checks that a missing URL, a
// non-200 response and an unparseable body are all reported as errors.
//
// Cases with both a metadata URL and a server status are served by a local
// httptest server, and the case's URL is replaced with that server's URL.
func TestSAMLProvider_GetMetadata(t *testing.T) {
	tests := []struct {
		name           string
		metadataURL    string
		serverResponse string
		serverStatus   int
		expectError    bool
		errorContains  string
	}{
		{
			name:           "successful metadata fetch",
			metadataURL:    "https://idp.example.com/.well-known/saml-metadata",
			serverResponse: mockSAMLMetadata(),
			serverStatus:   http.StatusOK,
			expectError:    false,
		},
		{
			name:          "missing metadata URL",
			metadataURL:   "",
			expectError:   true,
			errorContains: "SAML_IDP_METADATA_URL not configured",
		},
		{
			name:          "server error",
			metadataURL:   "https://idp.example.com/.well-known/saml-metadata",
			serverStatus:  http.StatusInternalServerError,
			expectError:   true,
			errorContains: "returned status 500",
		},
		{
			name:           "invalid XML",
			metadataURL:    "https://idp.example.com/.well-known/saml-metadata",
			serverResponse: "invalid xml",
			serverStatus:   http.StatusOK,
			expectError:    true,
			errorContains:  "Failed to parse SAML metadata",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Serve the case's canned response from a local server.
			var server *httptest.Server
			if tt.metadataURL != "" && tt.serverStatus > 0 {
				server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(tt.serverStatus)
					if tt.serverResponse != "" {
						_, _ = w.Write([]byte(tt.serverResponse)) // best-effort test response
					}
				}))
				defer server.Close()
				tt.metadataURL = server.URL
			}

			cfg := NewTestConfig()
			cfg.values["SAML_IDP_METADATA_URL"] = tt.metadataURL

			logger := zaptest.NewLogger(t)
			provider, err := auth.NewSAMLProvider(cfg, logger)
			require.NoError(t, err)

			ctx := context.Background()
			metadata, err := provider.GetMetadata(ctx)
			if tt.expectError {
				// require, not assert: err.Error() below would panic on a nil
				// error, and a panic in a subtest aborts the whole test binary.
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Nil(t, metadata)
				return
			}

			// Both must hold before metadata's fields are dereferenced.
			require.NoError(t, err)
			require.NotNil(t, metadata)
			assert.Equal(t, "https://idp.example.com", metadata.EntityID)
			assert.NotEmpty(t, metadata.IDPSSODescriptor.SingleSignOnService)
		})
	}
}
// TestSAMLProvider_GenerateAuthRequest checks that GenerateAuthRequest returns
// a non-empty request ID and an IdP SSO URL. That URL must carry a SAMLRequest
// and, when one is given, the RelayState. Missing SP entity ID or ACS URL
// configuration must be rejected.
//
// Each subtest serves mockSAMLMetadata from a local httptest server, which is
// configured as SAML_IDP_METADATA_URL.
func TestSAMLProvider_GenerateAuthRequest(t *testing.T) {
	tests := []struct {
		name          string
		spEntityID    string
		acsURL        string
		relayState    string
		expectError   bool
		errorContains string
	}{
		{
			name:       "successful auth request generation",
			spEntityID: "https://sp.example.com",
			acsURL:     "https://sp.example.com/acs",
			relayState: "test-relay-state",
		},
		{
			name:          "missing SP entity ID",
			spEntityID:    "",
			acsURL:        "https://sp.example.com/acs",
			expectError:   true,
			errorContains: "SAML_SP_ENTITY_ID not configured",
		},
		{
			name:          "missing ACS URL",
			spEntityID:    "https://sp.example.com",
			acsURL:        "",
			expectError:   true,
			errorContains: "SAML_SP_ACS_URL not configured",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte(mockSAMLMetadata())) // best-effort test response
			}))
			defer server.Close()

			cfg := NewTestConfig()
			cfg.values["SAML_IDP_METADATA_URL"] = server.URL
			cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
			cfg.values["SAML_SP_ACS_URL"] = tt.acsURL

			logger := zaptest.NewLogger(t)
			provider, err := auth.NewSAMLProvider(cfg, logger)
			require.NoError(t, err)

			ctx := context.Background()
			authURL, requestID, err := provider.GenerateAuthRequest(ctx, tt.relayState)
			if tt.expectError {
				// require, not assert: err.Error() below would panic on a nil
				// error, and a panic in a subtest aborts the whole test binary.
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Empty(t, authURL)
				assert.Empty(t, requestID)
				return
			}

			require.NoError(t, err)
			assert.NotEmpty(t, authURL)
			assert.NotEmpty(t, requestID)
			assert.Contains(t, authURL, "https://idp.example.com/sso")
			assert.Contains(t, authURL, "SAMLRequest=")
			if tt.relayState != "" {
				// Plain concatenation is only valid because the relay states
				// used here contain no characters that need URL escaping.
				assert.Contains(t, authURL, "RelayState="+tt.relayState)
			}
		})
	}
}
// TestSAMLProvider_ProcessSAMLResponse checks that ProcessSAMLResponse turns a
// base64-encoded mock response into a user auth context. The context must carry
// the expected user ID, email and name claims, and roles as permissions. The
// test also checks that bad base64, bad XML and an audience mismatch against
// SAML_SP_ENTITY_ID are rejected.
func TestSAMLProvider_ProcessSAMLResponse(t *testing.T) {
	tests := []struct {
		name              string
		samlResponse      string
		expectedRequestID string
		spEntityID        string
		expectError       bool
		errorContains     string
		expectedUserID    string
		expectedEmail     string
		expectedName      string
		expectedRoles     []string
	}{
		{
			name:              "successful SAML response processing",
			samlResponse:      base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
			expectedRequestID: "_request_id",
			spEntityID:        "https://sp.example.com",
			expectedUserID:    "user@example.com",
			expectedEmail:     "user@example.com",
			expectedName:      "Test User",
			expectedRoles:     []string{"admin", "user"},
		},
		{
			name:          "invalid base64 encoding",
			samlResponse:  "invalid-base64",
			expectError:   true,
			errorContains: "Failed to decode SAML response",
		},
		{
			name:          "invalid XML",
			samlResponse:  base64.StdEncoding.EncodeToString([]byte("invalid xml")),
			expectError:   true,
			errorContains: "Failed to parse SAML response",
		},
		{
			name:          "audience mismatch",
			samlResponse:  base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
			spEntityID:    "https://wrong-sp.example.com",
			expectError:   true,
			errorContains: "audience mismatch",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := NewTestConfig()
			cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID

			logger := zaptest.NewLogger(t)
			provider, err := auth.NewSAMLProvider(cfg, logger)
			require.NoError(t, err)

			ctx := context.Background()
			authContext, err := provider.ProcessSAMLResponse(ctx, tt.samlResponse, tt.expectedRequestID)
			if tt.expectError {
				// require, not assert: err.Error() below would panic on a nil
				// error, and a panic in a subtest aborts the whole test binary.
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Nil(t, authContext)
				return
			}

			// Both must hold before authContext's fields are dereferenced.
			require.NoError(t, err)
			require.NotNil(t, authContext)
			assert.Equal(t, tt.expectedUserID, authContext.UserID)
			assert.Equal(t, domain.TokenTypeUser, authContext.TokenType)

			// Claims are checked only when the case specifies them.
			if tt.expectedEmail != "" {
				assert.Equal(t, tt.expectedEmail, authContext.Claims["email"])
			}
			if tt.expectedName != "" {
				assert.Equal(t, tt.expectedName, authContext.Claims["name"])
			}

			// Roles from the response are expected to surface as permissions.
			if len(tt.expectedRoles) > 0 {
				assert.Equal(t, tt.expectedRoles, authContext.Permissions)
			}
		})
	}
}
// TestSAMLProvider_GenerateServiceProviderMetadata checks that
// GenerateServiceProviderMetadata emits SP metadata containing the configured
// entity ID and ACS URL inside an EntityDescriptor/SPSSODescriptor. It also
// checks that a missing entity ID or ACS URL is rejected.
func TestSAMLProvider_GenerateServiceProviderMetadata(t *testing.T) {
	tests := []struct {
		name          string
		spEntityID    string
		acsURL        string
		expectError   bool
		errorContains string
	}{
		{
			name:       "successful SP metadata generation",
			spEntityID: "https://sp.example.com",
			acsURL:     "https://sp.example.com/acs",
		},
		{
			name:          "missing SP entity ID",
			spEntityID:    "",
			acsURL:        "https://sp.example.com/acs",
			expectError:   true,
			errorContains: "SAML_SP_ENTITY_ID not configured",
		},
		{
			name:          "missing ACS URL",
			spEntityID:    "https://sp.example.com",
			acsURL:        "",
			expectError:   true,
			errorContains: "SAML_SP_ACS_URL not configured",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := NewTestConfig()
			cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
			cfg.values["SAML_SP_ACS_URL"] = tt.acsURL

			logger := zaptest.NewLogger(t)
			provider, err := auth.NewSAMLProvider(cfg, logger)
			require.NoError(t, err)

			metadata, err := provider.GenerateServiceProviderMetadata()
			if tt.expectError {
				// require, not assert: err.Error() below would panic on a nil
				// error, and a panic in a subtest aborts the whole test binary.
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Empty(t, metadata)
				return
			}

			require.NoError(t, err)
			assert.NotEmpty(t, metadata)
			assert.Contains(t, metadata, tt.spEntityID)
			assert.Contains(t, metadata, tt.acsURL)
			assert.Contains(t, metadata, "EntityDescriptor")
			assert.Contains(t, metadata, "SPSSODescriptor")
		})
	}
}
// Benchmarks for SAML operations.

// BenchmarkSAMLProvider_ProcessSAMLResponse times ProcessSAMLResponse on one
// pre-encoded mock response. The SP entity ID is "https://sp.example.com" and
// the request ID is "_request_id". Config, provider and encoding are prepared
// before the timer is reset, so only the processing loop is measured. Any
// processing error aborts the benchmark.
func BenchmarkSAMLProvider_ProcessSAMLResponse(b *testing.B) {
	config := NewTestConfig()
	config.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"

	provider, setupErr := auth.NewSAMLProvider(config, zaptest.NewLogger(b))
	require.NoError(b, setupErr)

	encoded := base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse()))
	ctx := context.Background()

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, procErr := provider.ProcessSAMLResponse(ctx, encoded, "_request_id"); procErr != nil {
			b.Fatal(procErr)
		}
	}
}
// BenchmarkSAMLProvider_GenerateAuthRequest times GenerateAuthRequest with the
// relay state "test-relay-state". IdP metadata is served by a local httptest
// server for the lifetime of the benchmark, and setup happens before the timer
// is reset. Any generation error aborts the benchmark.
func BenchmarkSAMLProvider_GenerateAuthRequest(b *testing.B) {
	metadataServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(http.StatusOK)
		rw.Write([]byte(mockSAMLMetadata()))
	}))
	defer metadataServer.Close()

	config := NewTestConfig()
	config.values["SAML_IDP_METADATA_URL"] = metadataServer.URL
	config.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
	config.values["SAML_SP_ACS_URL"] = "https://sp.example.com/acs"

	provider, setupErr := auth.NewSAMLProvider(config, zaptest.NewLogger(b))
	require.NoError(b, setupErr)

	ctx := context.Background()

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, _, genErr := provider.GenerateAuthRequest(ctx, "test-relay-state"); genErr != nil {
			b.Fatal(genErr)
		}
	}
}
// TestSAMLResponseValidation checks that ProcessSAMLResponse rejects a mock
// response that is otherwise valid but has been tampered with in one of three
// ways:
//   - expired: NotOnOrAfter is in the past
//   - not yet valid: NotBefore is in the future
//   - failed status: the status code is not Success
func TestSAMLResponseValidation(t *testing.T) {
	// Compiled once here instead of on every modifyXML call.
	notOnOrAfterRE := regexp.MustCompile(`NotOnOrAfter="[^"]*"`)
	notBeforeRE := regexp.MustCompile(`NotBefore="[^"]*"`)

	tests := []struct {
		name          string
		modifyXML     func(string) string
		expectError   bool
		errorContains string
	}{
		{
			name: "expired assertion",
			modifyXML: func(xml string) string {
				// Replace all NotOnOrAfter timestamps with a fixed past time.
				pastTime := "2020-01-01T13:00:00Z"
				return notOnOrAfterRE.ReplaceAllString(xml, `NotOnOrAfter="`+pastTime+`"`)
			},
			expectError:   true,
			errorContains: "assertion has expired",
		},
		{
			name: "assertion not yet valid",
			modifyXML: func(xml string) string {
				// Replace all NotBefore timestamps with a future time. The time
				// is computed from now because a fixed calendar date (previously
				// 2030-01-01) would become a past time and break this case once
				// that date passes.
				futureTime := time.Now().UTC().Add(24 * time.Hour).Format(time.RFC3339)
				return notBeforeRE.ReplaceAllString(xml, `NotBefore="`+futureTime+`"`)
			},
			expectError:   true,
			errorContains: "assertion not yet valid",
		},
		{
			name: "failed status",
			modifyXML: func(xml string) string {
				return strings.ReplaceAll(xml,
					"urn:oasis:names:tc:SAML:2.0:status:Success",
					"urn:oasis:names:tc:SAML:2.0:status:AuthnFailed")
			},
			expectError:   true,
			errorContains: "SAML authentication failed",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := NewTestConfig()
			cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"

			logger := zaptest.NewLogger(t)
			provider, err := auth.NewSAMLProvider(cfg, logger)
			require.NoError(t, err)

			// Build a fresh mock response, tamper with it, then encode it.
			modifiedXML := tt.modifyXML(mockSAMLResponse())
			samlResponse := base64.StdEncoding.EncodeToString([]byte(modifiedXML))

			ctx := context.Background()
			authContext, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
			if tt.expectError {
				// require, not assert: err.Error() below would panic on a nil
				// error, and a panic in a subtest aborts the whole test binary.
				require.Error(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
				assert.Nil(t, authContext)
				return
			}

			assert.NoError(t, err)
			assert.NotNil(t, authContext)
		})
	}
}