Files
skybridge/kms/internal/auth/oauth2.go
2025-08-26 19:29:41 -04:00

406 lines
13 KiB
Go

package auth
import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/RyanCopley/skybridge/kms/internal/config"
	"github.com/RyanCopley/skybridge/kms/internal/domain"
	"github.com/RyanCopley/skybridge/kms/internal/errors"
	"go.uber.org/zap"
)
// OAuth2Provider represents an OAuth2/OIDC provider client.
//
// Settings (SSO_PROVIDER_URL, SSO_CLIENT_ID, SSO_CLIENT_SECRET) are read from
// config on every call rather than captured at construction. The struct holds
// no cache, so the discovery document is re-fetched on each operation.
type OAuth2Provider struct {
// config supplies the SSO_* settings via GetString.
config config.ConfigProvider
// logger receives debug traces and token-endpoint failure details.
logger *zap.Logger
// httpClient performs every outbound request (discovery, token, userinfo).
httpClient *http.Client
}
// NewOAuth2Provider returns an OAuth2Provider that reads its SSO_* settings
// from cfg and logs through logger. Every outbound request to the identity
// provider goes through one shared http.Client with a 30-second overall
// request timeout.
func NewOAuth2Provider(cfg config.ConfigProvider, logger *zap.Logger) *OAuth2Provider {
	const requestTimeout = 30 * time.Second

	provider := new(OAuth2Provider)
	provider.config = cfg
	provider.logger = logger
	provider.httpClient = &http.Client{Timeout: requestTimeout}
	return provider
}
// OIDCDiscoveryDocument represents the OIDC discovery document (provider
// metadata) returned by GetDiscoveryDocument. Only the fields below are
// decoded; other metadata in the response is ignored, and absent entries stay
// at their zero values.
type OIDCDiscoveryDocument struct {
// Issuer is the provider's issuer identifier as published in the document.
Issuer string `json:"issuer"`
// AuthorizationEndpoint is where the browser is sent to start the code flow.
AuthorizationEndpoint string `json:"authorization_endpoint"`
// TokenEndpoint receives the authorization_code and refresh_token grants.
TokenEndpoint string `json:"token_endpoint"`
// UserInfoEndpoint may be empty; GetUserInfo reports that as a configuration error.
UserInfoEndpoint string `json:"userinfo_endpoint"`
// JWKSUri locates the provider's signing keys. NOTE(review): not referenced by
// the code visible here, so ID-token signatures are not verified against it.
JWKSUri string `json:"jwks_uri"`
ScopesSupported []string `json:"scopes_supported"`
ResponseTypesSupported []string `json:"response_types_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
}
// TokenResponse represents the OAuth2 token-endpoint response (RFC 6749 §5.1).
//
// ExpiresIn is the access-token lifetime in seconds. It is left at 0 when the
// provider omits expires_in.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	RefreshToken string `json:"refresh_token,omitempty"`
	IDToken      string `json:"id_token,omitempty"`
	Scope        string `json:"scope,omitempty"`
}

// UnmarshalJSON decodes a token response. It accepts expires_in either as a
// JSON number (as RFC 6749 specifies) or as a quoted decimal string, which
// some providers return. Decoding straight into an int field would otherwise
// reject the entire response over that one quoted value. Non-numeric
// expires_in values are still an error.
func (tr *TokenResponse) UnmarshalJSON(data []byte) error {
	// tokenResponseAlias has TokenResponse's fields but none of its methods,
	// so decoding into it does not recurse back into UnmarshalJSON.
	type tokenResponseAlias TokenResponse
	aux := struct {
		*tokenResponseAlias
		// Shallower fields win in encoding/json, so this shadows the
		// embedded int ExpiresIn for the "expires_in" key.
		ExpiresIn json.Number `json:"expires_in"`
	}{tokenResponseAlias: (*tokenResponseAlias)(tr)}

	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	// An absent or null key leaves ExpiresIn untouched, as encoding/json
	// would for a plain int field.
	if aux.ExpiresIn == "" {
		return nil
	}
	seconds, err := aux.ExpiresIn.Int64()
	if err != nil {
		return fmt.Errorf("token response: invalid expires_in %q: %w", string(aux.ExpiresIn), err)
	}
	tr.ExpiresIn = int(seconds)
	return nil
}
// UserInfo represents user information from the provider: the claims returned
// by the UserInfo endpoint and decoded by GetUserInfo. Claims not listed here
// are ignored, and absent claims are left at their zero values ("" / false).
type UserInfo struct {
// Sub is the subject ("sub") claim identifying the user at the provider.
Sub string `json:"sub"`
Email string `json:"email"`
// EmailVerified decodes only a JSON boolean. NOTE(review): a provider that
// sends it as a string would make decoding fail — confirm against the target IdP.
EmailVerified bool `json:"email_verified"`
Name string `json:"name"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
// Picture is expected to hold a URL to the user's avatar — presumably; not validated here.
Picture string `json:"picture"`
PreferredUsername string `json:"preferred_username"`
}
// GetDiscoveryDocument fetches the OIDC discovery document for the provider
// configured in SSO_PROVIDER_URL.
//
// The document is requested from
// <SSO_PROVIDER_URL>/.well-known/openid-configuration. Note the hyphen: that
// is the path defined by OpenID Connect Discovery 1.0 §4. An underscore
// ("openid_configuration") does not match it. Nothing is cached, so every call
// performs a fresh GET bound to ctx.
//
// Errors: ConfigurationError when SSO_PROVIDER_URL is unset. InternalError when
// the request cannot be built or sent, the endpoint answers with a non-200
// status, or the body cannot be read or decoded.
func (p *OAuth2Provider) GetDiscoveryDocument(ctx context.Context) (*OIDCDiscoveryDocument, error) {
	// Discovery documents are small. Cap the read so a misbehaving endpoint
	// cannot make us buffer an unbounded body.
	const maxDiscoveryBytes = 1 << 20

	providerURL := p.config.GetString("SSO_PROVIDER_URL")
	if providerURL == "" {
		return nil, errors.NewConfigurationError("SSO_PROVIDER_URL not configured")
	}

	discoveryURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
	p.logger.Debug("Fetching OIDC discovery document", zap.String("url", discoveryURL))

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
	if err != nil {
		return nil, errors.NewInternalError("Failed to create discovery request").WithInternal(err)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, errors.NewInternalError("Failed to fetch discovery document").WithInternal(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, errors.NewInternalError(fmt.Sprintf("Discovery endpoint returned status %d", resp.StatusCode))
	}

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxDiscoveryBytes))
	if err != nil {
		return nil, errors.NewInternalError("Failed to read discovery response").WithInternal(err)
	}

	var discovery OIDCDiscoveryDocument
	if err := json.Unmarshal(body, &discovery); err != nil {
		return nil, errors.NewInternalError("Failed to parse discovery document").WithInternal(err)
	}

	p.logger.Debug("OIDC discovery document fetched successfully",
		zap.String("issuer", discovery.Issuer),
		zap.String("auth_endpoint", discovery.AuthorizationEndpoint),
		zap.String("token_endpoint", discovery.TokenEndpoint))
	return &discovery, nil
}
// GenerateAuthURL returns the authorization-endpoint URL that starts an OAuth2
// authorization-code flow. The URL requests scopes "openid profile email", uses
// redirectURI as the callback, and carries the caller's opaque state value.
//
// No PKCE parameters are attached. A code_challenge is only useful if its
// verifier later reaches ExchangeCodeForToken, and this signature cannot
// return one. Previously a verifier was generated and then dropped, so any
// provider enforcing the challenge (RFC 7636 §4.6) rejected the code exchange.
// Callers that can persist a verifier should use GenerateAuthURLWithPKCE.
func (p *OAuth2Provider) GenerateAuthURL(ctx context.Context, state, redirectURI string) (string, error) {
	return p.buildAuthorizationURL(ctx, state, redirectURI, "")
}

// GenerateAuthURLWithPKCE is like GenerateAuthURL but also attaches a PKCE
// challenge (code_challenge_method=S256). The challenge is computed by
// generateCodeChallenge from a freshly generated verifier.
//
// The caller must store the returned codeVerifier, for example in the session
// keyed by state. It must then pass it unchanged to ExchangeCodeForToken when
// the provider redirects back.
func (p *OAuth2Provider) GenerateAuthURLWithPKCE(ctx context.Context, state, redirectURI string) (authURL, codeVerifier string, err error) {
	codeVerifier, err = p.generateCodeVerifier()
	if err != nil {
		return "", "", errors.NewInternalError("Failed to generate PKCE code verifier").WithInternal(err)
	}
	authURL, err = p.buildAuthorizationURL(ctx, state, redirectURI, p.generateCodeChallenge(codeVerifier))
	if err != nil {
		return "", "", err
	}
	return authURL, codeVerifier, nil
}

// buildAuthorizationURL resolves the authorization endpoint and encodes the
// authorization request onto it.
//
// PKCE parameters are included only when codeChallenge is non-empty. A query
// string already present on the endpoint is preserved and merged, rather than
// being broken by a second "?".
func (p *OAuth2Provider) buildAuthorizationURL(ctx context.Context, state, redirectURI, codeChallenge string) (string, error) {
	discovery, err := p.GetDiscoveryDocument(ctx)
	if err != nil {
		return "", err
	}

	clientID := p.config.GetString("SSO_CLIENT_ID")
	if clientID == "" {
		return "", errors.NewConfigurationError("SSO_CLIENT_ID not configured")
	}

	endpoint, err := url.Parse(discovery.AuthorizationEndpoint)
	if err != nil {
		return "", errors.NewInternalError("Invalid authorization endpoint in discovery document").WithInternal(err)
	}
	if endpoint.Scheme == "" || endpoint.Host == "" {
		// Without this check an empty endpoint would yield a bare "?..." URL.
		return "", errors.NewInternalError("Discovery document has no usable authorization endpoint")
	}

	params := endpoint.Query()
	params.Set("response_type", "code")
	params.Set("client_id", clientID)
	params.Set("redirect_uri", redirectURI)
	params.Set("scope", "openid profile email")
	params.Set("state", state)
	if codeChallenge != "" {
		params.Set("code_challenge", codeChallenge)
		params.Set("code_challenge_method", "S256")
	}
	endpoint.RawQuery = params.Encode()

	p.logger.Debug("Generated OAuth2 authorization URL",
		zap.String("client_id", clientID),
		zap.String("redirect_uri", redirectURI),
		zap.String("state", state),
		zap.Bool("pkce", codeChallenge != ""))
	return endpoint.String(), nil
}
// ExchangeCodeForToken redeems an authorization code at the provider's token
// endpoint (authorization_code grant) and returns the decoded token response.
//
// The client authenticates with SSO_CLIENT_ID and SSO_CLIENT_SECRET sent in
// the form body, and both must be configured. codeVerifier is the PKCE
// verifier matching the code_challenge sent in the authorization request.
// Pass "" when no challenge was sent; the code_verifier parameter is then
// omitted, since a verifier without a matching challenge may be rejected by
// the provider.
//
// Errors:
//   - ConfigurationError for missing client credentials.
//   - AuthenticationError for a non-200 response or a response without an
//     access token.
//   - InternalError for transport and decoding failures.
func (p *OAuth2Provider) ExchangeCodeForToken(ctx context.Context, code, redirectURI, codeVerifier string) (*TokenResponse, error) {
	// Token responses are small JSON objects. Cap the read so a misbehaving
	// endpoint cannot make us buffer an arbitrarily large body.
	const maxTokenResponseBytes = 1 << 20

	discovery, err := p.GetDiscoveryDocument(ctx)
	if err != nil {
		return nil, err
	}

	clientID := p.config.GetString("SSO_CLIENT_ID")
	clientSecret := p.config.GetString("SSO_CLIENT_SECRET")
	if clientID == "" {
		return nil, errors.NewConfigurationError("SSO_CLIENT_ID not configured")
	}
	if clientSecret == "" {
		return nil, errors.NewConfigurationError("SSO_CLIENT_SECRET not configured")
	}

	data := url.Values{
		"grant_type":    {"authorization_code"},
		"code":          {code},
		"redirect_uri":  {redirectURI},
		"client_id":     {clientID},
		"client_secret": {clientSecret},
	}
	if codeVerifier != "" {
		data.Set("code_verifier", codeVerifier)
	}

	p.logger.Debug("Exchanging authorization code for token",
		zap.String("token_endpoint", discovery.TokenEndpoint),
		zap.String("client_id", clientID))

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, discovery.TokenEndpoint, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, errors.NewInternalError("Failed to create token request").WithInternal(err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, errors.NewInternalError("Failed to exchange code for token").WithInternal(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxTokenResponseBytes))
	if err != nil {
		return nil, errors.NewInternalError("Failed to read token response").WithInternal(err)
	}

	if resp.StatusCode != http.StatusOK {
		p.logger.Error("Token exchange failed",
			zap.Int("status_code", resp.StatusCode),
			zap.String("response", string(body)))
		return nil, errors.NewAuthenticationError("Failed to exchange authorization code")
	}

	var tokenResp TokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, errors.NewInternalError("Failed to parse token response").WithInternal(err)
	}
	// access_token is REQUIRED in a successful response (RFC 6749 §5.1). A 200
	// without one must not be reported as a successful login.
	if tokenResp.AccessToken == "" {
		return nil, errors.NewAuthenticationError("Token response did not include an access token")
	}

	p.logger.Debug("Successfully exchanged code for token",
		zap.String("token_type", tokenResp.TokenType),
		zap.Int("expires_in", tokenResp.ExpiresIn))
	return &tokenResp, nil
}
// GetUserInfo calls the provider's UserInfo endpoint, sending accessToken as a
// Bearer credential, and decodes the returned claims.
//
// Errors:
//   - ConfigurationError when discovery advertises no UserInfo endpoint.
//   - AuthenticationError for a non-200 response.
//   - InternalError for transport or decoding failures, and for a response
//     without a "sub" claim. OpenID Connect Core 1.0 §5.3.2 requires sub in
//     every UserInfo response, and an empty subject cannot identify a user.
func (p *OAuth2Provider) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
	// UserInfo payloads are small. Bound the read against misbehaving endpoints.
	const maxUserInfoBytes = 1 << 20

	discovery, err := p.GetDiscoveryDocument(ctx)
	if err != nil {
		return nil, err
	}
	if discovery.UserInfoEndpoint == "" {
		return nil, errors.NewConfigurationError("UserInfo endpoint not available")
	}

	p.logger.Debug("Fetching user info", zap.String("endpoint", discovery.UserInfoEndpoint))

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, discovery.UserInfoEndpoint, nil)
	if err != nil {
		return nil, errors.NewInternalError("Failed to create userinfo request").WithInternal(err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	req.Header.Set("Accept", "application/json")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, errors.NewInternalError("Failed to fetch user info").WithInternal(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		p.logger.Error("UserInfo request failed", zap.Int("status_code", resp.StatusCode))
		return nil, errors.NewAuthenticationError("Failed to fetch user information")
	}

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxUserInfoBytes))
	if err != nil {
		return nil, errors.NewInternalError("Failed to read userinfo response").WithInternal(err)
	}

	var userInfo UserInfo
	if err := json.Unmarshal(body, &userInfo); err != nil {
		return nil, errors.NewInternalError("Failed to parse user info").WithInternal(err)
	}
	if userInfo.Sub == "" {
		return nil, errors.NewInternalError("UserInfo response missing subject claim")
	}

	p.logger.Debug("Successfully fetched user info",
		zap.String("sub", userInfo.Sub),
		zap.String("email", userInfo.Email),
		zap.String("name", userInfo.Name))
	return &userInfo, nil
}
// ValidateIDToken decodes an OIDC ID token, checks its time and audience
// claims, and returns a user AuthContext for the token's subject.
//
// SECURITY: the JWS signature is NOT verified against the provider's JWKS. Use
// this only on ID tokens received directly from the token endpoint over TLS,
// where TLS server validation may stand in for the signature check (OpenID
// Connect Core 1.0 §3.1.3.7, item 6). Never use it on tokens presented by
// clients.
//
// Checks performed (all failures are ValidationError):
//   - three dot-separated segments with a base64url JSON payload;
//   - "sub" present;
//   - "exp" present and not in the past, allowing one minute of clock skew;
//   - "nbf", when present, not in the future (same skew);
//   - "aud" contains SSO_CLIENT_ID, whenever a client ID is configured.
//
// Permissions in the returned context are empty.
func (p *OAuth2Provider) ValidateIDToken(ctx context.Context, idToken string) (*domain.AuthContext, error) {
	// Tolerated clock difference between this host and the provider.
	const clockSkew = time.Minute

	p.logger.Debug("Validating ID token")

	parts := strings.Split(idToken, ".")
	if len(parts) != 3 {
		return nil, errors.NewValidationError("Invalid ID token format")
	}

	// JWS segments are unpadded base64url; tolerate stray "=" padding.
	payload, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(parts[1], "="))
	if err != nil {
		return nil, errors.NewValidationError("Failed to decode ID token payload").WithInternal(err)
	}

	var claims map[string]interface{}
	if err := json.Unmarshal(payload, &claims); err != nil {
		return nil, errors.NewValidationError("Failed to parse ID token claims").WithInternal(err)
	}

	sub, _ := claims["sub"].(string)
	email, _ := claims["email"].(string)
	name, _ := claims["name"].(string)
	if sub == "" {
		return nil, errors.NewValidationError("ID token missing subject claim")
	}

	// exp is REQUIRED for ID tokens. encoding/json decodes NumericDate values
	// into float64 when the target is interface{}.
	now := time.Now()
	exp, ok := claims["exp"].(float64)
	if !ok {
		return nil, errors.NewValidationError("ID token missing expiration claim")
	}
	if now.After(time.Unix(int64(exp), 0).Add(clockSkew)) {
		return nil, errors.NewValidationError("ID token has expired")
	}
	if nbf, ok := claims["nbf"].(float64); ok && now.Add(clockSkew).Before(time.Unix(int64(nbf), 0)) {
		return nil, errors.NewValidationError("ID token is not yet valid")
	}

	// aud may be a single string or an array of strings. Either way it must
	// name this client, or the token was issued to someone else.
	if clientID := p.config.GetString("SSO_CLIENT_ID"); clientID != "" {
		audienceOK := false
		switch aud := claims["aud"].(type) {
		case string:
			audienceOK = aud == clientID
		case []interface{}:
			for _, entry := range aud {
				if s, isString := entry.(string); isString && s == clientID {
					audienceOK = true
					break
				}
			}
		}
		if !audienceOK {
			return nil, errors.NewValidationError("ID token audience does not include this client")
		}
	}

	authContext := &domain.AuthContext{
		UserID:    sub,
		TokenType: domain.TokenTypeUser,
		Claims: map[string]string{
			"sub":   sub,
			"email": email,
			"name":  name,
		},
		Permissions: []string{}, // Will be populated based on user roles/groups
	}

	p.logger.Debug("ID token validated successfully",
		zap.String("sub", sub),
		zap.String("email", email))
	return authContext, nil
}
// generateCodeVerifier returns a fresh PKCE code verifier. It reads 32 bytes
// from crypto/rand and encodes them as unpadded base64url, giving 43 URL-safe
// characters. The only error it returns is a failure of the random source.
func (p *OAuth2Provider) generateCodeVerifier() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(entropy[:]), nil
}
// generateCodeChallenge derives the PKCE code challenge for verifier using the
// S256 method of RFC 7636 §4.2: BASE64URL(SHA256(ASCII(verifier))), without
// padding.
//
// Authorization requests in this file advertise code_challenge_method=S256,
// and the provider recomputes this hash from the verifier at the token
// endpoint. Returning the verifier unchanged ("plain") therefore guarantees a
// failed exchange.
func (p *OAuth2Provider) generateCodeChallenge(verifier string) string {
	digest := sha256.Sum256([]byte(verifier))
	return base64.RawURLEncoding.EncodeToString(digest[:])
}
// RefreshAccessToken obtains a new access token with the refresh_token grant
// (RFC 6749 §6). Like the authorization-code exchange, it authenticates with
// SSO_CLIENT_ID and SSO_CLIENT_SECRET in the form body; both must be
// configured.
//
// The provider may omit refresh_token when it does not rotate it, which §6
// permits. In that case the returned TokenResponse carries the refreshToken
// that was passed in. Callers that persist TokenResponse.RefreshToken
// therefore never overwrite a still-valid token with "".
//
// Errors:
//   - ValidationError for an empty refreshToken.
//   - ConfigurationError for missing client credentials.
//   - AuthenticationError for a non-200 response or a response without an
//     access token.
//   - InternalError for transport and decoding failures.
func (p *OAuth2Provider) RefreshAccessToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
	// Token responses are small. Bound the read against misbehaving endpoints.
	const maxTokenResponseBytes = 1 << 20

	// Reject the obviously invalid request before any network round trips.
	if refreshToken == "" {
		return nil, errors.NewValidationError("Refresh token is required")
	}

	discovery, err := p.GetDiscoveryDocument(ctx)
	if err != nil {
		return nil, err
	}

	clientID := p.config.GetString("SSO_CLIENT_ID")
	clientSecret := p.config.GetString("SSO_CLIENT_SECRET")
	if clientID == "" {
		return nil, errors.NewConfigurationError("SSO_CLIENT_ID not configured")
	}
	if clientSecret == "" {
		return nil, errors.NewConfigurationError("SSO_CLIENT_SECRET not configured")
	}

	data := url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {refreshToken},
		"client_id":     {clientID},
		"client_secret": {clientSecret},
	}

	p.logger.Debug("Refreshing access token")

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, discovery.TokenEndpoint, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, errors.NewInternalError("Failed to create refresh request").WithInternal(err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, errors.NewInternalError("Failed to refresh token").WithInternal(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxTokenResponseBytes))
	if err != nil {
		return nil, errors.NewInternalError("Failed to read refresh response").WithInternal(err)
	}

	if resp.StatusCode != http.StatusOK {
		p.logger.Error("Token refresh failed",
			zap.Int("status_code", resp.StatusCode),
			zap.String("response", string(body)))
		return nil, errors.NewAuthenticationError("Failed to refresh access token")
	}

	var tokenResp TokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, errors.NewInternalError("Failed to parse refresh response").WithInternal(err)
	}
	if tokenResp.AccessToken == "" {
		return nil, errors.NewAuthenticationError("Token refresh response did not include an access token")
	}
	if tokenResp.RefreshToken == "" {
		tokenResp.RefreshToken = refreshToken
	}

	p.logger.Debug("Successfully refreshed access token")
	return &tokenResp, nil
}