package auth import ( "context" "crypto/rand" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/url" "strings" "time" "go.uber.org/zap" "github.com/kms/api-key-service/internal/config" "github.com/kms/api-key-service/internal/domain" "github.com/kms/api-key-service/internal/errors" ) // OAuth2Provider represents an OAuth2/OIDC provider type OAuth2Provider struct { config config.ConfigProvider logger *zap.Logger httpClient *http.Client } // NewOAuth2Provider creates a new OAuth2 provider func NewOAuth2Provider(config config.ConfigProvider, logger *zap.Logger) *OAuth2Provider { return &OAuth2Provider{ config: config, logger: logger, httpClient: &http.Client{ Timeout: 30 * time.Second, }, } } // OIDCDiscoveryDocument represents the OIDC discovery document type OIDCDiscoveryDocument struct { Issuer string `json:"issuer"` AuthorizationEndpoint string `json:"authorization_endpoint"` TokenEndpoint string `json:"token_endpoint"` UserInfoEndpoint string `json:"userinfo_endpoint"` JWKSUri string `json:"jwks_uri"` ScopesSupported []string `json:"scopes_supported"` ResponseTypesSupported []string `json:"response_types_supported"` GrantTypesSupported []string `json:"grant_types_supported"` } // TokenResponse represents the OAuth2 token response type TokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` RefreshToken string `json:"refresh_token,omitempty"` IDToken string `json:"id_token,omitempty"` Scope string `json:"scope,omitempty"` } // UserInfo represents user information from the provider type UserInfo struct { Sub string `json:"sub"` Email string `json:"email"` EmailVerified bool `json:"email_verified"` Name string `json:"name"` GivenName string `json:"given_name"` FamilyName string `json:"family_name"` Picture string `json:"picture"` PreferredUsername string `json:"preferred_username"` } // GetDiscoveryDocument fetches the OIDC discovery document func (p *OAuth2Provider) GetDiscoveryDocument(ctx context.Context) (*OIDCDiscoveryDocument, error) { providerURL := p.config.GetString("SSO_PROVIDER_URL") if providerURL == "" { return nil, errors.NewConfigurationError("SSO_PROVIDER_URL not configured") } // Construct discovery URL discoveryURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid_configuration" p.logger.Debug("Fetching OIDC discovery document", zap.String("url", discoveryURL)) req, err := http.NewRequestWithContext(ctx, "GET", discoveryURL, nil) if err != nil { return nil, errors.NewInternalError("Failed to create discovery request").WithInternal(err) } resp, err := p.httpClient.Do(req) if err != nil { return nil, errors.NewInternalError("Failed to fetch discovery document").WithInternal(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, errors.NewInternalError(fmt.Sprintf("Discovery endpoint returned status %d", resp.StatusCode)) } body, err := io.ReadAll(resp.Body) if err != nil { return nil, errors.NewInternalError("Failed to read discovery response").WithInternal(err) } var discovery OIDCDiscoveryDocument if err := json.Unmarshal(body, &discovery); err != nil { return nil, errors.NewInternalError("Failed to parse discovery document").WithInternal(err) } p.logger.Debug("OIDC discovery document fetched successfully", zap.String("issuer", discovery.Issuer), zap.String("auth_endpoint", discovery.AuthorizationEndpoint), zap.String("token_endpoint", discovery.TokenEndpoint)) return &discovery, nil } // GenerateAuthURL generates the OAuth2 authorization URL func (p *OAuth2Provider) GenerateAuthURL(ctx context.Context, state, redirectURI string) (string, error) { discovery, err := p.GetDiscoveryDocument(ctx) if err != nil { return "", err } clientID := p.config.GetString("SSO_CLIENT_ID") if clientID == "" { return "", errors.NewConfigurationError("SSO_CLIENT_ID not configured") } // Generate PKCE code verifier and challenge codeVerifier, err := p.generateCodeVerifier() if err != nil { return "", errors.NewInternalError("Failed to generate PKCE code verifier").WithInternal(err) } codeChallenge := p.generateCodeChallenge(codeVerifier) // Build authorization URL params := url.Values{ "response_type": {"code"}, "client_id": {clientID}, "redirect_uri": {redirectURI}, "scope": {"openid profile email"}, "state": {state}, "code_challenge": {codeChallenge}, "code_challenge_method": {"S256"}, } authURL := discovery.AuthorizationEndpoint + "?" + params.Encode() p.logger.Debug("Generated OAuth2 authorization URL", zap.String("client_id", clientID), zap.String("redirect_uri", redirectURI), zap.String("state", state)) // Store code verifier for later use (in production, this should be stored in a secure session store) // For now, we'll return it as part of the response or store it in cache return authURL, nil } // ExchangeCodeForToken exchanges authorization code for access token func (p *OAuth2Provider) ExchangeCodeForToken(ctx context.Context, code, redirectURI, codeVerifier string) (*TokenResponse, error) { discovery, err := p.GetDiscoveryDocument(ctx) if err != nil { return nil, err } clientID := p.config.GetString("SSO_CLIENT_ID") clientSecret := p.config.GetString("SSO_CLIENT_SECRET") if clientID == "" { return nil, errors.NewConfigurationError("SSO_CLIENT_ID not configured") } if clientSecret == "" { return nil, errors.NewConfigurationError("SSO_CLIENT_SECRET not configured") } // Prepare token exchange request data := url.Values{ "grant_type": {"authorization_code"}, "code": {code}, "redirect_uri": {redirectURI}, "client_id": {clientID}, "client_secret": {clientSecret}, "code_verifier": {codeVerifier}, } p.logger.Debug("Exchanging authorization code for token", zap.String("token_endpoint", discovery.TokenEndpoint), zap.String("client_id", clientID)) req, err := http.NewRequestWithContext(ctx, "POST", discovery.TokenEndpoint, strings.NewReader(data.Encode())) if err != nil { return nil, errors.NewInternalError("Failed to create token request").WithInternal(err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") resp, err := p.httpClient.Do(req) if err != nil { return nil, errors.NewInternalError("Failed to exchange code for token").WithInternal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, errors.NewInternalError("Failed to read token response").WithInternal(err) } if resp.StatusCode != http.StatusOK { p.logger.Error("Token exchange failed", zap.Int("status_code", resp.StatusCode), zap.String("response", string(body))) return nil, errors.NewAuthenticationError("Failed to exchange authorization code") } var tokenResp TokenResponse if err := json.Unmarshal(body, &tokenResp); err != nil { return nil, errors.NewInternalError("Failed to parse token response").WithInternal(err) } p.logger.Debug("Successfully exchanged code for token", zap.String("token_type", tokenResp.TokenType), zap.Int("expires_in", tokenResp.ExpiresIn)) return &tokenResp, nil } // GetUserInfo retrieves user information using the access token func (p *OAuth2Provider) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) { discovery, err := p.GetDiscoveryDocument(ctx) if err != nil { return nil, err } if discovery.UserInfoEndpoint == "" { return nil, errors.NewConfigurationError("UserInfo endpoint not available") } p.logger.Debug("Fetching user info", zap.String("endpoint", discovery.UserInfoEndpoint)) req, err := http.NewRequestWithContext(ctx, "GET", discovery.UserInfoEndpoint, nil) if err != nil { return nil, errors.NewInternalError("Failed to create userinfo request").WithInternal(err) } req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Accept", "application/json") resp, err := p.httpClient.Do(req) if err != nil { return nil, errors.NewInternalError("Failed to fetch user info").WithInternal(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { p.logger.Error("UserInfo request failed", zap.Int("status_code", resp.StatusCode)) return nil, errors.NewAuthenticationError("Failed to fetch user information") } body, err := io.ReadAll(resp.Body) if err != nil { return nil, errors.NewInternalError("Failed to read userinfo response").WithInternal(err) } var userInfo UserInfo if err := json.Unmarshal(body, &userInfo); err != nil { return nil, errors.NewInternalError("Failed to parse user info").WithInternal(err) } p.logger.Debug("Successfully fetched user info", zap.String("sub", userInfo.Sub), zap.String("email", userInfo.Email), zap.String("name", userInfo.Name)) return &userInfo, nil } // ValidateIDToken validates an OIDC ID token (basic validation) func (p *OAuth2Provider) ValidateIDToken(ctx context.Context, idToken string) (*domain.AuthContext, error) { // This is a simplified implementation // In production, you should validate the JWT signature using the provider's JWKS p.logger.Debug("Validating ID token") // For now, we'll just decode the token without signature verification // This should be replaced with proper JWT validation using the provider's public keys parts := strings.Split(idToken, ".") if len(parts) != 3 { return nil, errors.NewValidationError("Invalid ID token format") } // Decode payload (second part) payload, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, errors.NewValidationError("Failed to decode ID token payload").WithInternal(err) } var claims map[string]interface{} if err := json.Unmarshal(payload, &claims); err != nil { return nil, errors.NewValidationError("Failed to parse ID token claims").WithInternal(err) } // Extract basic claims sub, _ := claims["sub"].(string) email, _ := claims["email"].(string) name, _ := claims["name"].(string) if sub == "" { return nil, errors.NewValidationError("ID token missing subject claim") } authContext := &domain.AuthContext{ UserID: sub, TokenType: domain.TokenTypeUser, Claims: map[string]string{ "sub": sub, "email": email, "name": name, }, Permissions: []string{}, // Will be populated based on user roles/groups } p.logger.Debug("ID token validated successfully", zap.String("sub", sub), zap.String("email", email)) return authContext, nil } // generateCodeVerifier generates a PKCE code verifier func (p *OAuth2Provider) generateCodeVerifier() (string, error) { bytes := make([]byte, 32) if _, err := rand.Read(bytes); err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(bytes), nil } // generateCodeChallenge generates a PKCE code challenge from verifier func (p *OAuth2Provider) generateCodeChallenge(verifier string) string { // For S256 method, we would hash the verifier with SHA256 // For simplicity, we'll use the verifier as-is (plain method) // In production, implement proper S256 challenge generation return verifier } // RefreshAccessToken refreshes an access token using refresh token func (p *OAuth2Provider) RefreshAccessToken(ctx context.Context, refreshToken string) (*TokenResponse, error) { discovery, err := p.GetDiscoveryDocument(ctx) if err != nil { return nil, err } clientID := p.config.GetString("SSO_CLIENT_ID") clientSecret := p.config.GetString("SSO_CLIENT_SECRET") data := url.Values{ "grant_type": {"refresh_token"}, "refresh_token": {refreshToken}, "client_id": {clientID}, "client_secret": {clientSecret}, } p.logger.Debug("Refreshing access token") req, err := http.NewRequestWithContext(ctx, "POST", discovery.TokenEndpoint, strings.NewReader(data.Encode())) if err != nil { return nil, errors.NewInternalError("Failed to create refresh request").WithInternal(err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") resp, err := p.httpClient.Do(req) if err != nil { return nil, errors.NewInternalError("Failed to refresh token").WithInternal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, errors.NewInternalError("Failed to read refresh response").WithInternal(err) } if resp.StatusCode != http.StatusOK { p.logger.Error("Token refresh failed", zap.Int("status_code", resp.StatusCode), zap.String("response", string(body))) return nil, errors.NewAuthenticationError("Failed to refresh access token") } var tokenResp TokenResponse if err := json.Unmarshal(body, &tokenResp); err != nil { return nil, errors.NewInternalError("Failed to parse refresh response").WithInternal(err) } p.logger.Debug("Successfully refreshed access token") return &tokenResp, nil }