package auth import ( "context" "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/pem" "encoding/xml" "fmt" "io" "net/http" "net/url" "strings" "time" "github.com/google/uuid" "go.uber.org/zap" "github.com/kms/api-key-service/internal/config" "github.com/kms/api-key-service/internal/domain" "github.com/kms/api-key-service/internal/errors" ) // SAMLProvider represents a SAML 2.0 identity provider type SAMLProvider struct { config config.ConfigProvider logger *zap.Logger httpClient *http.Client privateKey *rsa.PrivateKey certificate *x509.Certificate } // NewSAMLProvider creates a new SAML provider func NewSAMLProvider(config config.ConfigProvider, logger *zap.Logger) (*SAMLProvider, error) { provider := &SAMLProvider{ config: config, logger: logger, httpClient: &http.Client{ Timeout: 30 * time.Second, }, } // Load SP private key and certificate if configured if err := provider.loadCredentials(); err != nil { return nil, err } return provider, nil } // SAMLMetadata represents SAML IdP metadata type SAMLMetadata struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntityDescriptor"` EntityID string `xml:"entityID,attr"` IDPSSODescriptor IDPSSODescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata IDPSSODescriptor"` } // IDPSSODescriptor represents the IdP SSO descriptor type IDPSSODescriptor struct { ProtocolSupportEnumeration string `xml:"protocolSupportEnumeration,attr"` KeyDescriptor []KeyDescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata KeyDescriptor"` SingleSignOnService []SingleSignOnService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleSignOnService"` SingleLogoutService []SingleLogoutService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleLogoutService"` } // KeyDescriptor represents a key descriptor type KeyDescriptor struct { Use string `xml:"use,attr"` KeyInfo KeyInfo `xml:"urn:xmldsig KeyInfo"` } // KeyInfo represents key information type KeyInfo struct { X509Data X509Data `xml:"urn:xmldsig X509Data"` } // X509Data represents X509 certificate data type X509Data struct { X509Certificate string `xml:"urn:xmldsig X509Certificate"` } // SingleSignOnService represents SSO service endpoint type SingleSignOnService struct { Binding string `xml:"Binding,attr"` Location string `xml:"Location,attr"` } // SingleLogoutService represents SLO service endpoint type SingleLogoutService struct { Binding string `xml:"Binding,attr"` Location string `xml:"Location,attr"` } // SAMLRequest represents a SAML authentication request type SAMLRequest struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol AuthnRequest"` ID string `xml:"ID,attr"` Version string `xml:"Version,attr"` IssueInstant time.Time `xml:"IssueInstant,attr"` Destination string `xml:"Destination,attr"` AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"` ProtocolBinding string `xml:"ProtocolBinding,attr"` Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"` NameIDPolicy NameIDPolicy `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"` } // Issuer represents the SAML issuer type Issuer struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"` Value string `xml:",chardata"` } // NameIDPolicy represents the name ID policy type NameIDPolicy struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"` Format string `xml:"Format,attr"` } // SAMLResponse represents a SAML response type SAMLResponse struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Response"` ID string `xml:"ID,attr"` Version string `xml:"Version,attr"` IssueInstant time.Time `xml:"IssueInstant,attr"` Destination string `xml:"Destination,attr"` InResponseTo string `xml:"InResponseTo,attr"` Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"` Status Status `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"` Assertion Assertion `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"` } // Status represents the SAML response status type Status struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"` StatusCode StatusCode `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"` } // StatusCode represents the status code type StatusCode struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"` Value string `xml:"Value,attr"` } // Assertion represents a SAML assertion type Assertion struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"` ID string `xml:"ID,attr"` Version string `xml:"Version,attr"` IssueInstant time.Time `xml:"IssueInstant,attr"` Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"` Subject Subject `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"` Conditions Conditions `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"` AttributeStatement AttributeStatement `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"` AuthnStatement AuthnStatement `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"` } // Subject represents the assertion subject type Subject struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"` NameID NameID `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"` SubjectConfirmation SubjectConfirmation `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"` } // NameID represents the name identifier type NameID struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"` Format string `xml:"Format,attr"` Value string `xml:",chardata"` } // SubjectConfirmation represents subject confirmation type SubjectConfirmation struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"` Method string `xml:"Method,attr"` SubjectConfirmationData SubjectConfirmationData `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"` } // SubjectConfirmationData represents subject confirmation data type SubjectConfirmationData struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"` InResponseTo string `xml:"InResponseTo,attr"` NotOnOrAfter time.Time `xml:"NotOnOrAfter,attr"` Recipient string `xml:"Recipient,attr"` } // Conditions represents assertion conditions type Conditions struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"` NotBefore time.Time `xml:"NotBefore,attr"` NotOnOrAfter time.Time `xml:"NotOnOrAfter,attr"` AudienceRestriction AudienceRestriction `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"` } // AudienceRestriction represents audience restriction type AudienceRestriction struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"` Audience Audience `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"` } // Audience represents the intended audience type Audience struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"` Value string `xml:",chardata"` } // AttributeStatement represents attribute statement type AttributeStatement struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"` Attribute []Attribute `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"` } // Attribute represents a SAML attribute type Attribute struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"` Name string `xml:"Name,attr"` AttributeValue []AttributeValue `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"` } // AttributeValue represents an attribute value type AttributeValue struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"` Type string `xml:"http://www.w3.org/2001/XMLSchema-instance type,attr"` Value string `xml:",chardata"` } // AuthnStatement represents authentication statement type AuthnStatement struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"` AuthnInstant time.Time `xml:"AuthnInstant,attr"` SessionIndex string `xml:"SessionIndex,attr"` AuthnContext AuthnContext `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"` } // AuthnContext represents authentication context type AuthnContext struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"` AuthnContextClassRef string `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContextClassRef"` } // GetMetadata fetches the SAML IdP metadata func (p *SAMLProvider) GetMetadata(ctx context.Context) (*SAMLMetadata, error) { metadataURL := p.config.GetString("SAML_IDP_METADATA_URL") if metadataURL == "" { return nil, errors.NewConfigurationError("SAML_IDP_METADATA_URL not configured") } p.logger.Debug("Fetching SAML IdP metadata", zap.String("url", metadataURL)) req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil) if err != nil { return nil, errors.NewInternalError("Failed to create metadata request").WithInternal(err) } resp, err := p.httpClient.Do(req) if err != nil { return nil, errors.NewInternalError("Failed to fetch IdP metadata").WithInternal(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, errors.NewInternalError(fmt.Sprintf("Metadata endpoint returned status %d", resp.StatusCode)) } body, err := io.ReadAll(resp.Body) if err != nil { return nil, errors.NewInternalError("Failed to read metadata response").WithInternal(err) } var metadata SAMLMetadata if err := xml.Unmarshal(body, &metadata); err != nil { return nil, errors.NewInternalError("Failed to parse SAML metadata").WithInternal(err) } p.logger.Debug("SAML IdP metadata fetched successfully", zap.String("entity_id", metadata.EntityID)) return &metadata, nil } // GenerateAuthRequest generates a SAML authentication request func (p *SAMLProvider) GenerateAuthRequest(ctx context.Context, relayState string) (string, string, error) { metadata, err := p.GetMetadata(ctx) if err != nil { return "", "", err } // Find SSO endpoint var ssoEndpoint string for _, sso := range metadata.IDPSSODescriptor.SingleSignOnService { if sso.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" { ssoEndpoint = sso.Location break } } if ssoEndpoint == "" { return "", "", errors.NewConfigurationError("No HTTP-Redirect SSO endpoint found in IdP metadata") } // Generate request ID requestID := "_" + uuid.New().String() // Get SP configuration spEntityID := p.config.GetString("SAML_SP_ENTITY_ID") acsURL := p.config.GetString("SAML_SP_ACS_URL") if spEntityID == "" { return "", "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured") } if acsURL == "" { return "", "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured") } // Create SAML request samlRequest := SAMLRequest{ ID: requestID, Version: "2.0", IssueInstant: time.Now().UTC(), Destination: ssoEndpoint, AssertionConsumerServiceURL: acsURL, ProtocolBinding: "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST", Issuer: Issuer{ Value: spEntityID, }, NameIDPolicy: NameIDPolicy{ Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress", }, } // Marshal to XML xmlData, err := xml.MarshalIndent(samlRequest, "", " ") if err != nil { return "", "", errors.NewInternalError("Failed to marshal SAML request").WithInternal(err) } // Add XML declaration xmlRequest := `` + "\n" + string(xmlData) // Base64 encode and URL encode encodedRequest := base64.StdEncoding.EncodeToString([]byte(xmlRequest)) // Build redirect URL params := url.Values{ "SAMLRequest": {encodedRequest}, "RelayState": {relayState}, } redirectURL := ssoEndpoint + "?" + params.Encode() p.logger.Debug("Generated SAML authentication request", zap.String("request_id", requestID), zap.String("sso_endpoint", ssoEndpoint)) return redirectURL, requestID, nil } // ProcessSAMLResponse processes a SAML response and extracts user information func (p *SAMLProvider) ProcessSAMLResponse(ctx context.Context, samlResponse string, expectedRequestID string) (*domain.AuthContext, error) { p.logger.Debug("Processing SAML response") // Base64 decode the response decodedResponse, err := base64.StdEncoding.DecodeString(samlResponse) if err != nil { return nil, errors.NewValidationError("Failed to decode SAML response").WithInternal(err) } // Parse XML var response SAMLResponse if err := xml.Unmarshal(decodedResponse, &response); err != nil { return nil, errors.NewValidationError("Failed to parse SAML response").WithInternal(err) } // Validate response if err := p.validateSAMLResponse(&response, expectedRequestID); err != nil { return nil, err } // Extract user information from assertion authContext, err := p.extractUserInfo(&response.Assertion) if err != nil { return nil, err } p.logger.Debug("SAML response processed successfully", zap.String("user_id", authContext.UserID)) return authContext, nil } // validateSAMLResponse validates a SAML response func (p *SAMLProvider) validateSAMLResponse(response *SAMLResponse, expectedRequestID string) error { // Check status if response.Status.StatusCode.Value != "urn:oasis:names:tc:SAML:2.0:status:Success" { return errors.NewAuthenticationError("SAML authentication failed: " + response.Status.StatusCode.Value) } // Validate InResponseTo if expectedRequestID != "" && response.InResponseTo != expectedRequestID { return errors.NewValidationError("SAML response InResponseTo does not match request ID") } // Validate assertion conditions assertion := &response.Assertion now := time.Now().UTC() if now.Before(assertion.Conditions.NotBefore) { return errors.NewValidationError("SAML assertion not yet valid") } if now.After(assertion.Conditions.NotOnOrAfter) { return errors.NewValidationError("SAML assertion has expired") } // Validate audience expectedAudience := p.config.GetString("SAML_SP_ENTITY_ID") if assertion.Conditions.AudienceRestriction.Audience.Value != expectedAudience { return errors.NewValidationError("SAML assertion audience mismatch") } // In production, you should also validate the signature // This requires implementing XML signature validation return nil } // extractUserInfo extracts user information from SAML assertion func (p *SAMLProvider) extractUserInfo(assertion *Assertion) (*domain.AuthContext, error) { // Extract user ID from NameID userID := assertion.Subject.NameID.Value if userID == "" { return nil, errors.NewValidationError("SAML assertion missing NameID") } // Extract attributes claims := make(map[string]string) claims["sub"] = userID claims["name_id_format"] = assertion.Subject.NameID.Format // Process attribute statements for _, attr := range assertion.AttributeStatement.Attribute { if len(attr.AttributeValue) > 0 { // Use the first value if multiple values exist claims[attr.Name] = attr.AttributeValue[0].Value } } // Map common attributes to standard claims if email, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress"]; exists { claims["email"] = email } if name, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name"]; exists { claims["name"] = name } if givenName, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname"]; exists { claims["given_name"] = givenName } if surname, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname"]; exists { claims["family_name"] = surname } // Extract permissions/roles if available var permissions []string if roles, exists := claims["http://schemas.microsoft.com/ws/2008/06/identity/claims/role"]; exists { permissions = strings.Split(roles, ",") } authContext := &domain.AuthContext{ UserID: userID, TokenType: domain.TokenTypeUser, Claims: claims, Permissions: permissions, } return authContext, nil } // GenerateServiceProviderMetadata generates SP metadata XML func (p *SAMLProvider) GenerateServiceProviderMetadata() (string, error) { spEntityID := p.config.GetString("SAML_SP_ENTITY_ID") acsURL := p.config.GetString("SAML_SP_ACS_URL") if spEntityID == "" { return "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured") } if acsURL == "" { return "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured") } // This is a simplified SP metadata generation // In production, you should use a proper SAML library metadata := fmt.Sprintf(` `, spEntityID, acsURL) return metadata, nil } // loadCredentials loads SP private key and certificate func (p *SAMLProvider) loadCredentials() error { // Load private key if configured privateKeyPEM := p.config.GetString("SAML_SP_PRIVATE_KEY") if privateKeyPEM != "" { block, _ := pem.Decode([]byte(privateKeyPEM)) if block == nil { return errors.NewConfigurationError("Failed to decode SAML SP private key") } privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { // Try PKCS8 format key, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { return errors.NewConfigurationError("Failed to parse SAML SP private key").WithInternal(err) } var ok bool privateKey, ok = key.(*rsa.PrivateKey) if !ok { return errors.NewConfigurationError("SAML SP private key is not RSA") } } p.privateKey = privateKey } // Load certificate if configured certificatePEM := p.config.GetString("SAML_SP_CERTIFICATE") if certificatePEM != "" { block, _ := pem.Decode([]byte(certificatePEM)) if block == nil { return errors.NewConfigurationError("Failed to decode SAML SP certificate") } certificate, err := x509.ParseCertificate(block.Bytes) if err != nil { return errors.NewConfigurationError("Failed to parse SAML SP certificate").WithInternal(err) } p.certificate = certificate } return nil }