Files
skybridge/internal/auth/saml.go
2025-08-22 18:57:40 -04:00

545 lines
19 KiB
Go

package auth
import (
	"bytes"
	"compress/flate"
	"context"
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"encoding/xml"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/google/uuid"
	"go.uber.org/zap"

	"github.com/kms/api-key-service/internal/config"
	"github.com/kms/api-key-service/internal/domain"
	"github.com/kms/api-key-service/internal/errors"
)
// SAMLProvider is the service-provider (SP) side of SAML 2.0 Web SSO against a
// single identity provider. All settings (IdP metadata URL, SP entity ID, ACS
// URL, optional SP key pair) are read from the ConfigProvider via SAML_* keys.
type SAMLProvider struct {
// config supplies the SAML_* settings.
config config.ConfigProvider
logger *zap.Logger
// httpClient is used to fetch the IdP metadata document.
httpClient *http.Client
// privateKey is the SP RSA key from SAML_SP_PRIVATE_KEY; nil when not configured.
privateKey *rsa.PrivateKey
// certificate is the SP certificate from SAML_SP_CERTIFICATE; nil when not configured.
certificate *x509.Certificate
}
// NewSAMLProvider returns a SAMLProvider that reads its settings from config
// and logs through logger. Metadata requests use an HTTP client with a
// 30-second timeout. The optional SP private key and certificate are loaded
// eagerly; any error from loading them is returned and no provider is built.
func NewSAMLProvider(config config.ConfigProvider, logger *zap.Logger) (*SAMLProvider, error) {
	p := &SAMLProvider{config: config, logger: logger}
	p.httpClient = &http.Client{Timeout: 30 * time.Second}

	if err := p.loadCredentials(); err != nil {
		return nil, err
	}
	return p, nil
}
// SAMLMetadata is the subset of an IdP's SAML 2.0 metadata document
// (<md:EntityDescriptor>) that this provider reads.
type SAMLMetadata struct {
	XMLName          xml.Name         `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntityDescriptor"`
	EntityID         string           `xml:"entityID,attr"`
	IDPSSODescriptor IDPSSODescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata IDPSSODescriptor"`
}

// IDPSSODescriptor is the IdP role descriptor: its keys and its single
// sign-on / single logout endpoints.
type IDPSSODescriptor struct {
	ProtocolSupportEnumeration string                `xml:"protocolSupportEnumeration,attr"`
	KeyDescriptor              []KeyDescriptor       `xml:"urn:oasis:names:tc:SAML:2.0:metadata KeyDescriptor"`
	SingleSignOnService        []SingleSignOnService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleSignOnService"`
	SingleLogoutService        []SingleLogoutService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleLogoutService"`
}

// KeyDescriptor is an <md:KeyDescriptor>; Use holds its "use" attribute
// (for example "signing").
//
// KeyInfo and its children live in the W3C XML-Signature namespace
// (http://www.w3.org/2000/09/xmldsig#). The tags must name that exact URI,
// otherwise encoding/xml silently skips the elements and no certificate is read.
type KeyDescriptor struct {
	Use     string  `xml:"use,attr"`
	KeyInfo KeyInfo `xml:"http://www.w3.org/2000/09/xmldsig# KeyInfo"`
}

// KeyInfo is a <ds:KeyInfo> element.
type KeyInfo struct {
	X509Data X509Data `xml:"http://www.w3.org/2000/09/xmldsig# X509Data"`
}

// X509Data is a <ds:X509Data> element. X509Certificate holds the element's
// text verbatim (base64 certificate data; surrounding whitespace is not stripped).
type X509Data struct {
	X509Certificate string `xml:"http://www.w3.org/2000/09/xmldsig# X509Certificate"`
}

// SingleSignOnService is an IdP SSO endpoint: its binding URI and location URL.
type SingleSignOnService struct {
	Binding  string `xml:"Binding,attr"`
	Location string `xml:"Location,attr"`
}

// SingleLogoutService is an IdP SLO endpoint: its binding URI and location URL.
type SingleLogoutService struct {
	Binding  string `xml:"Binding,attr"`
	Location string `xml:"Location,attr"`
}
// Protocol messages exchanged with the IdP: the outgoing AuthnRequest and the
// incoming Response with its Assertion. Each struct maps one SAML element;
// only the attributes and children this provider uses are declared.
type (
	// SAMLRequest is the <samlp:AuthnRequest> sent to the IdP's SSO endpoint.
	SAMLRequest struct {
		XMLName                     xml.Name     `xml:"urn:oasis:names:tc:SAML:2.0:protocol AuthnRequest"`
		ID                          string       `xml:"ID,attr"`
		Version                     string       `xml:"Version,attr"`
		IssueInstant                time.Time    `xml:"IssueInstant,attr"`
		Destination                 string       `xml:"Destination,attr"`
		AssertionConsumerServiceURL string       `xml:"AssertionConsumerServiceURL,attr"`
		ProtocolBinding             string       `xml:"ProtocolBinding,attr"`
		Issuer                      Issuer       `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
		NameIDPolicy                NameIDPolicy `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"`
	}

	// Issuer is a <saml:Issuer> element; Value is its text content.
	Issuer struct {
		XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
		Value   string   `xml:",chardata"`
	}

	// NameIDPolicy is the <samlp:NameIDPolicy> of an AuthnRequest.
	NameIDPolicy struct {
		XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"`
		Format  string   `xml:"Format,attr"`
	}

	// SAMLResponse is the <samlp:Response> returned by the IdP.
	SAMLResponse struct {
		XMLName      xml.Name  `xml:"urn:oasis:names:tc:SAML:2.0:protocol Response"`
		ID           string    `xml:"ID,attr"`
		Version      string    `xml:"Version,attr"`
		IssueInstant time.Time `xml:"IssueInstant,attr"`
		Destination  string    `xml:"Destination,attr"`
		InResponseTo string    `xml:"InResponseTo,attr"`
		Issuer       Issuer    `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
		Status       Status    `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"`
		Assertion    Assertion `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"`
	}

	// Status wraps the response's top-level <samlp:StatusCode>.
	Status struct {
		XMLName    xml.Name   `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"`
		StatusCode StatusCode `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"`
	}

	// StatusCode carries the status URI in its Value attribute.
	StatusCode struct {
		XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"`
		Value   string   `xml:"Value,attr"`
	}

	// Assertion is a plaintext <saml:Assertion>.
	Assertion struct {
		XMLName            xml.Name           `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"`
		ID                 string             `xml:"ID,attr"`
		Version            string             `xml:"Version,attr"`
		IssueInstant       time.Time          `xml:"IssueInstant,attr"`
		Issuer             Issuer             `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
		Subject            Subject            `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"`
		Conditions         Conditions         `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"`
		AttributeStatement AttributeStatement `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"`
		AuthnStatement     AuthnStatement     `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"`
	}

	// Subject identifies the principal and how it was confirmed.
	Subject struct {
		XMLName             xml.Name            `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"`
		NameID              NameID              `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"`
		SubjectConfirmation SubjectConfirmation `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"`
	}

	// NameID is the subject identifier (text content) and its Format URI.
	NameID struct {
		XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"`
		Format  string   `xml:"Format,attr"`
		Value   string   `xml:",chardata"`
	}

	// SubjectConfirmation names the confirmation Method and carries its data.
	SubjectConfirmation struct {
		XMLName                 xml.Name                `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"`
		Method                  string                  `xml:"Method,attr"`
		SubjectConfirmationData SubjectConfirmationData `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"`
	}

	// SubjectConfirmationData holds the InResponseTo, NotOnOrAfter and
	// Recipient attributes of the subject confirmation.
	SubjectConfirmationData struct {
		XMLName      xml.Name  `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"`
		InResponseTo string    `xml:"InResponseTo,attr"`
		NotOnOrAfter time.Time `xml:"NotOnOrAfter,attr"`
		Recipient    string    `xml:"Recipient,attr"`
	}

	// Conditions bounds the assertion's validity window and audience.
	// Absent time attributes are left as the zero time.Time.
	Conditions struct {
		XMLName             xml.Name            `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"`
		NotBefore           time.Time           `xml:"NotBefore,attr"`
		NotOnOrAfter        time.Time           `xml:"NotOnOrAfter,attr"`
		AudienceRestriction AudienceRestriction `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"`
	}

	// AudienceRestriction is modelled with a single <saml:Audience>.
	AudienceRestriction struct {
		XMLName  xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"`
		Audience Audience `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"`
	}

	// Audience is the intended audience's identifier (text content).
	Audience struct {
		XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"`
		Value   string   `xml:",chardata"`
	}

	// AttributeStatement lists the assertion's attributes.
	AttributeStatement struct {
		XMLName   xml.Name    `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"`
		Attribute []Attribute `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"`
	}

	// Attribute is a named attribute with zero or more values.
	Attribute struct {
		XMLName        xml.Name         `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"`
		Name           string           `xml:"Name,attr"`
		AttributeValue []AttributeValue `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"`
	}

	// AttributeValue is one value of an Attribute; Type is its xsi:type.
	AttributeValue struct {
		XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"`
		Type    string   `xml:"http://www.w3.org/2001/XMLSchema-instance type,attr"`
		Value   string   `xml:",chardata"`
	}

	// AuthnStatement records when the subject authenticated, the session
	// index, and the authentication context.
	AuthnStatement struct {
		XMLName      xml.Name     `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"`
		AuthnInstant time.Time    `xml:"AuthnInstant,attr"`
		SessionIndex string       `xml:"SessionIndex,attr"`
		AuthnContext AuthnContext `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"`
	}

	// AuthnContext carries the authentication context class reference URI.
	AuthnContext struct {
		XMLName              xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"`
		AuthnContextClassRef string   `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContextClassRef"`
	}
)
// GetMetadata fetches and parses the IdP metadata document from the URL
// configured as SAML_IDP_METADATA_URL. Nothing is cached; every call performs
// a fresh HTTP GET using the provider's client and ctx.
//
// It returns a configuration error when the URL is not set, and an internal
// error when the request fails, the endpoint does not answer 200 OK, the body
// is larger than 10 MiB, or the body is not a SAML EntityDescriptor.
func (p *SAMLProvider) GetMetadata(ctx context.Context) (*SAMLMetadata, error) {
	// Upper bound on the metadata body so a misbehaving or hostile endpoint
	// cannot make us buffer an arbitrarily large response.
	const maxSAMLMetadataBytes = 10 << 20 // 10 MiB

	metadataURL := p.config.GetString("SAML_IDP_METADATA_URL")
	if metadataURL == "" {
		return nil, errors.NewConfigurationError("SAML_IDP_METADATA_URL not configured")
	}

	p.logger.Debug("Fetching SAML IdP metadata", zap.String("url", metadataURL))

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil)
	if err != nil {
		return nil, errors.NewInternalError("Failed to create metadata request").WithInternal(err)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, errors.NewInternalError("Failed to fetch IdP metadata").WithInternal(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, errors.NewInternalError(fmt.Sprintf("Metadata endpoint returned status %d", resp.StatusCode))
	}

	// Read one byte past the limit so an oversized document is reported as
	// such instead of being truncated into a confusing XML parse error.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxSAMLMetadataBytes+1))
	if err != nil {
		return nil, errors.NewInternalError("Failed to read metadata response").WithInternal(err)
	}
	if len(body) > maxSAMLMetadataBytes {
		return nil, errors.NewInternalError(fmt.Sprintf("IdP metadata exceeds %d bytes", maxSAMLMetadataBytes))
	}

	var metadata SAMLMetadata
	if err := xml.Unmarshal(body, &metadata); err != nil {
		return nil, errors.NewInternalError("Failed to parse SAML metadata").WithInternal(err)
	}

	p.logger.Debug("SAML IdP metadata fetched successfully",
		zap.String("entity_id", metadata.EntityID))

	return &metadata, nil
}
// GenerateAuthRequest builds a SAML 2.0 AuthnRequest addressed to the IdP's
// HTTP-Redirect single sign-on endpoint (taken from the IdP metadata). It
// returns the URL to redirect the browser to and the generated request ID.
//
// The caller should keep the request ID (for example in the user's session)
// and pass it to ProcessSAMLResponse as expectedRequestID. relayState is sent
// as the RelayState query parameter when it is non-empty.
//
// The message is encoded as the HTTP-Redirect binding requires: raw DEFLATE
// (RFC 1951), then base64, then URL query encoding. The request is not signed.
//
// Errors: any error from GetMetadata; a configuration error when the metadata
// lists no HTTP-Redirect SSO endpoint or SAML_SP_ENTITY_ID / SAML_SP_ACS_URL
// are unset; an internal error when marshalling or compression fails.
func (p *SAMLProvider) GenerateAuthRequest(ctx context.Context, relayState string) (string, string, error) {
	metadata, err := p.GetMetadata(ctx)
	if err != nil {
		return "", "", err
	}

	var ssoEndpoint string
	for _, sso := range metadata.IDPSSODescriptor.SingleSignOnService {
		if sso.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" {
			ssoEndpoint = sso.Location
			break
		}
	}
	if ssoEndpoint == "" {
		return "", "", errors.NewConfigurationError("No HTTP-Redirect SSO endpoint found in IdP metadata")
	}

	// The leading underscore keeps the ID a valid xs:ID (an NCName cannot
	// start with a digit, which a UUID may).
	requestID := "_" + uuid.New().String()

	spEntityID := p.config.GetString("SAML_SP_ENTITY_ID")
	acsURL := p.config.GetString("SAML_SP_ACS_URL")
	if spEntityID == "" {
		return "", "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured")
	}
	if acsURL == "" {
		return "", "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured")
	}

	samlRequest := SAMLRequest{
		ID:      requestID,
		Version: "2.0",
		// Whole seconds keep IssueInstant in the plain form
		// "2006-01-02T15:04:05Z" instead of carrying nanoseconds.
		IssueInstant:                time.Now().UTC().Truncate(time.Second),
		Destination:                 ssoEndpoint,
		AssertionConsumerServiceURL: acsURL,
		ProtocolBinding:             "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
		Issuer: Issuer{
			Value: spEntityID,
		},
		NameIDPolicy: NameIDPolicy{
			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress",
		},
	}

	xmlData, err := xml.Marshal(samlRequest)
	if err != nil {
		return "", "", errors.NewInternalError("Failed to marshal SAML request").WithInternal(err)
	}
	xmlRequest := append([]byte(xml.Header), xmlData...)

	// HTTP-Redirect requires DEFLATE before base64; IdPs inflate the
	// SAMLRequest parameter and reject plain base64.
	encodedRequest, err := deflateSAMLMessage(xmlRequest)
	if err != nil {
		return "", "", errors.NewInternalError("Failed to encode SAML request").WithInternal(err)
	}

	params := url.Values{"SAMLRequest": {encodedRequest}}
	if relayState != "" {
		params.Set("RelayState", relayState)
	}
	redirectURL := appendSAMLQuery(ssoEndpoint, params.Encode())

	p.logger.Debug("Generated SAML authentication request",
		zap.String("request_id", requestID),
		zap.String("sso_endpoint", ssoEndpoint))

	return redirectURL, requestID, nil
}

// deflateSAMLMessage compresses msg with raw DEFLATE (no zlib or gzip framing)
// and base64-encodes the result, which is the encoding the SAML HTTP-Redirect
// binding uses for the SAMLRequest/SAMLResponse query parameters.
func deflateSAMLMessage(msg []byte) (string, error) {
	var buf bytes.Buffer
	w, err := flate.NewWriter(&buf, flate.BestCompression)
	if err != nil {
		return "", err
	}
	if _, err := w.Write(msg); err != nil {
		return "", err
	}
	// Close flushes the final block; the stream is incomplete without it.
	if err := w.Close(); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}

// appendSAMLQuery appends an already-encoded query string to endpoint. It
// joins with '&' when endpoint already has a query component (for example
// "https://idp.example/sso?tenant=x"), so existing parameters are kept.
func appendSAMLQuery(endpoint, query string) string {
	switch {
	case !strings.Contains(endpoint, "?"):
		return endpoint + "?" + query
	case strings.HasSuffix(endpoint, "?"), strings.HasSuffix(endpoint, "&"):
		return endpoint + query
	default:
		return endpoint + "&" + query
	}
}
// ProcessSAMLResponse processes a SAML response and extracts user information
func (p *SAMLProvider) ProcessSAMLResponse(ctx context.Context, samlResponse string, expectedRequestID string) (*domain.AuthContext, error) {
p.logger.Debug("Processing SAML response")
// Base64 decode the response
decodedResponse, err := base64.StdEncoding.DecodeString(samlResponse)
if err != nil {
return nil, errors.NewValidationError("Failed to decode SAML response").WithInternal(err)
}
// Parse XML
var response SAMLResponse
if err := xml.Unmarshal(decodedResponse, &response); err != nil {
return nil, errors.NewValidationError("Failed to parse SAML response").WithInternal(err)
}
// Validate response
if err := p.validateSAMLResponse(&response, expectedRequestID); err != nil {
return nil, err
}
// Extract user information from assertion
authContext, err := p.extractUserInfo(&response.Assertion)
if err != nil {
return nil, err
}
p.logger.Debug("SAML response processed successfully",
zap.String("user_id", authContext.UserID))
return authContext, nil
}
// validateSAMLResponse performs the protocol checks on a parsed SAML response
// before any user information from it is trusted. In order, it requires that:
//
//   - the top-level status code is Success (otherwise an authentication error);
//   - when expectedRequestID is non-empty, the response InResponseTo equals it
//     (an empty expectedRequestID skips this, e.g. for IdP-initiated SSO);
//   - the response carries a plaintext <Assertion>;
//   - the current time lies in [Conditions.NotBefore, Conditions.NotOnOrAfter).
//     A missing NotOnOrAfter (zero time) is treated as expired;
//   - SAML_SP_ENTITY_ID is configured and equals the assertion's Audience;
//   - when SAML_SP_ACS_URL is configured, a non-empty response Destination and
//     SubjectConfirmationData Recipient both equal it;
//   - a SubjectConfirmationData NotOnOrAfter, when present, is still in the
//     future, and its InResponseTo, when present, matches expectedRequestID.
//
// WARNING: XML signatures are NOT verified. Anyone who can POST to the ACS
// endpoint can forge a response that passes every check above. This provider
// must not be trusted in production until the response/assertion signature is
// verified against the IdP signing certificate.
func (p *SAMLProvider) validateSAMLResponse(response *SAMLResponse, expectedRequestID string) error {
	if response.Status.StatusCode.Value != "urn:oasis:names:tc:SAML:2.0:status:Success" {
		return errors.NewAuthenticationError("SAML authentication failed: " + response.Status.StatusCode.Value)
	}

	if expectedRequestID != "" && response.InResponseTo != expectedRequestID {
		return errors.NewValidationError("SAML response InResponseTo does not match request ID")
	}

	assertion := &response.Assertion
	// encoding/xml only sets XMLName when the element is present. An absent
	// assertion (e.g. the IdP sent an <EncryptedAssertion>) would otherwise
	// surface below as a misleading "expired" error.
	if assertion.XMLName.Local == "" {
		return errors.NewValidationError("SAML response contains no assertion")
	}

	now := time.Now().UTC()
	if now.Before(assertion.Conditions.NotBefore) {
		return errors.NewValidationError("SAML assertion not yet valid")
	}
	// NotOnOrAfter is exclusive: the assertion is already invalid at that
	// instant. A zero (absent) value also fails here, so unbounded assertions
	// are refused.
	if !now.Before(assertion.Conditions.NotOnOrAfter) {
		return errors.NewValidationError("SAML assertion has expired")
	}

	expectedAudience := p.config.GetString("SAML_SP_ENTITY_ID")
	if expectedAudience == "" {
		// Without this guard an assertion lacking an <Audience> would "match"
		// the empty configuration value.
		return errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured")
	}
	if strings.TrimSpace(assertion.Conditions.AudienceRestriction.Audience.Value) != expectedAudience {
		return errors.NewValidationError("SAML assertion audience mismatch")
	}

	acsURL := p.config.GetString("SAML_SP_ACS_URL")
	if acsURL != "" && response.Destination != "" && response.Destination != acsURL {
		return errors.NewValidationError("SAML response destination mismatch")
	}

	confirmation := &assertion.Subject.SubjectConfirmation.SubjectConfirmationData
	if !confirmation.NotOnOrAfter.IsZero() && !now.Before(confirmation.NotOnOrAfter) {
		return errors.NewValidationError("SAML subject confirmation has expired")
	}
	if expectedRequestID != "" && confirmation.InResponseTo != "" && confirmation.InResponseTo != expectedRequestID {
		return errors.NewValidationError("SAML subject confirmation InResponseTo does not match request ID")
	}
	if acsURL != "" && confirmation.Recipient != "" && confirmation.Recipient != acsURL {
		return errors.NewValidationError("SAML subject confirmation recipient mismatch")
	}

	return nil
}
// extractUserInfo maps an already-validated assertion to a domain.AuthContext
// with token type domain.TokenTypeUser.
//
// The whitespace-trimmed NameID becomes UserID and the "sub" claim; an empty
// NameID is a validation error. "name_id_format" holds the NameID Format.
// Every attribute that has values adds a claim keyed by its Name, holding its
// first (trimmed) value. Well-known WS-Federation claim URIs are also copied
// to the short names email, name, given_name and family_name.
//
// Permissions come from every value of the Microsoft role claim. Each value is
// additionally split on commas; blanks and duplicates are dropped, and the
// first-seen order is kept. Permissions is nil when no role is present.
func (p *SAMLProvider) extractUserInfo(assertion *Assertion) (*domain.AuthContext, error) {
	// Trim so pretty-printed XML (<NameID>\n  alice\n</NameID>) yields a clean ID.
	userID := strings.TrimSpace(assertion.Subject.NameID.Value)
	if userID == "" {
		return nil, errors.NewValidationError("SAML assertion missing NameID")
	}

	const roleClaim = "http://schemas.microsoft.com/ws/2008/06/identity/claims/role"

	claims := make(map[string]string)
	var permissions []string
	seenRoles := make(map[string]struct{})

	for _, attr := range assertion.AttributeStatement.Attribute {
		if len(attr.AttributeValue) == 0 {
			continue
		}
		// Claims are single-valued strings, so only the first value is kept.
		claims[attr.Name] = strings.TrimSpace(attr.AttributeValue[0].Value)

		if attr.Name != roleClaim {
			continue
		}
		// Roles are commonly sent as one AttributeValue per role, so gather
		// all of them rather than only the first.
		for _, value := range attr.AttributeValue {
			for _, role := range strings.Split(value.Value, ",") {
				role = strings.TrimSpace(role)
				if role == "" {
					continue
				}
				if _, dup := seenRoles[role]; dup {
					continue
				}
				seenRoles[role] = struct{}{}
				permissions = append(permissions, role)
			}
		}
	}

	standardClaims := []struct{ uri, name string }{
		{"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", "email"},
		{"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name", "name"},
		{"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname", "given_name"},
		{"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname", "family_name"},
	}
	for _, sc := range standardClaims {
		if v, ok := claims[sc.uri]; ok {
			claims[sc.name] = v
		}
	}

	// Set last so attributes that happen to be named "sub" or
	// "name_id_format" cannot override the NameID-derived values.
	claims["sub"] = userID
	claims["name_id_format"] = assertion.Subject.NameID.Format

	authContext := &domain.AuthContext{
		UserID:      userID,
		TokenType:   domain.TokenTypeUser,
		Claims:      claims,
		Permissions: permissions,
	}
	return authContext, nil
}
// GenerateServiceProviderMetadata generates SP metadata XML
func (p *SAMLProvider) GenerateServiceProviderMetadata() (string, error) {
spEntityID := p.config.GetString("SAML_SP_ENTITY_ID")
acsURL := p.config.GetString("SAML_SP_ACS_URL")
if spEntityID == "" {
return "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured")
}
if acsURL == "" {
return "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured")
}
// This is a simplified SP metadata generation
// In production, you should use a proper SAML library
metadata := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="%s">
<md:SPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="%s" index="0"/>
</md:SPSSODescriptor>
</md:EntityDescriptor>`, spEntityID, acsURL)
return metadata, nil
}
// loadCredentials loads SP private key and certificate
func (p *SAMLProvider) loadCredentials() error {
// Load private key if configured
privateKeyPEM := p.config.GetString("SAML_SP_PRIVATE_KEY")
if privateKeyPEM != "" {
block, _ := pem.Decode([]byte(privateKeyPEM))
if block == nil {
return errors.NewConfigurationError("Failed to decode SAML SP private key")
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
// Try PKCS8 format
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return errors.NewConfigurationError("Failed to parse SAML SP private key").WithInternal(err)
}
var ok bool
privateKey, ok = key.(*rsa.PrivateKey)
if !ok {
return errors.NewConfigurationError("SAML SP private key is not RSA")
}
}
p.privateKey = privateKey
}
// Load certificate if configured
certificatePEM := p.config.GetString("SAML_SP_CERTIFICATE")
if certificatePEM != "" {
block, _ := pem.Decode([]byte(certificatePEM))
if block == nil {
return errors.NewConfigurationError("Failed to decode SAML SP certificate")
}
certificate, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return errors.NewConfigurationError("Failed to parse SAML SP certificate").WithInternal(err)
}
p.certificate = certificate
}
return nil
}