406 lines
13 KiB
Go
406 lines
13 KiB
Go
package auth
|
|
|
|
import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"go.uber.org/zap"

	"github.com/kms/api-key-service/internal/config"
	"github.com/kms/api-key-service/internal/domain"
	"github.com/kms/api-key-service/internal/errors"
)
|
|
|
|
// OAuth2Provider represents an OAuth2/OIDC provider
|
|
type OAuth2Provider struct {
|
|
config config.ConfigProvider
|
|
logger *zap.Logger
|
|
httpClient *http.Client
|
|
}
|
|
|
|
// NewOAuth2Provider creates a new OAuth2 provider
|
|
func NewOAuth2Provider(config config.ConfigProvider, logger *zap.Logger) *OAuth2Provider {
|
|
return &OAuth2Provider{
|
|
config: config,
|
|
logger: logger,
|
|
httpClient: &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
},
|
|
}
|
|
}
|
|
|
|
// OIDCDiscoveryDocument holds the subset of the provider's OIDC discovery
// metadata that this package decodes from JSON.
type OIDCDiscoveryDocument struct {
	// Issuer is the "issuer" value advertised by the provider.
	Issuer string `json:"issuer"`
	// AuthorizationEndpoint is the URL users are sent to in order to log in.
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	// TokenEndpoint is the URL used for code exchange and token refresh.
	TokenEndpoint string `json:"token_endpoint"`
	// UserInfoEndpoint is the URL queried for the user's profile.
	UserInfoEndpoint string `json:"userinfo_endpoint"`
	// JWKSUri is the "jwks_uri" value advertised by the provider.
	JWKSUri string `json:"jwks_uri"`
	// ScopesSupported lists the "scopes_supported" values.
	ScopesSupported []string `json:"scopes_supported"`
	// ResponseTypesSupported lists the "response_types_supported" values.
	ResponseTypesSupported []string `json:"response_types_supported"`
	// GrantTypesSupported lists the "grant_types_supported" values.
	GrantTypesSupported []string `json:"grant_types_supported"`
}
|
|
|
|
// TokenResponse is the JSON body returned by the provider's token endpoint
// for both the authorization-code exchange and the refresh-token grant.
type TokenResponse struct {
	// AccessToken is the issued access token.
	AccessToken string `json:"access_token"`
	// TokenType is the token type reported by the provider.
	TokenType string `json:"token_type"`
	// ExpiresIn is the token lifetime reported by the provider.
	ExpiresIn int `json:"expires_in"`
	// RefreshToken is the refresh token, when one was issued.
	RefreshToken string `json:"refresh_token,omitempty"`
	// IDToken is the OIDC ID token, when one was issued.
	IDToken string `json:"id_token,omitempty"`
	// Scope is the granted scope string, when the provider reports it.
	Scope string `json:"scope,omitempty"`
}
|
|
|
|
// UserInfo is the profile returned by the provider's UserInfo endpoint, as
// decoded from its JSON response.
type UserInfo struct {
	// Sub is the provider's subject identifier for the user.
	Sub string `json:"sub"`
	// Email is the user's e-mail address.
	Email string `json:"email"`
	// EmailVerified is the provider's "email_verified" flag.
	EmailVerified bool `json:"email_verified"`
	// Name is the user's display name.
	Name string `json:"name"`
	// GivenName is the user's given name.
	GivenName string `json:"given_name"`
	// FamilyName is the user's family name.
	FamilyName string `json:"family_name"`
	// Picture is the "picture" value reported by the provider.
	Picture string `json:"picture"`
	// PreferredUsername is the "preferred_username" value reported by the provider.
	PreferredUsername string `json:"preferred_username"`
}
|
|
|
|
// GetDiscoveryDocument fetches the OIDC discovery document
|
|
func (p *OAuth2Provider) GetDiscoveryDocument(ctx context.Context) (*OIDCDiscoveryDocument, error) {
|
|
providerURL := p.config.GetString("SSO_PROVIDER_URL")
|
|
if providerURL == "" {
|
|
return nil, errors.NewConfigurationError("SSO_PROVIDER_URL not configured")
|
|
}
|
|
|
|
// Construct discovery URL
|
|
discoveryURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid_configuration"
|
|
|
|
p.logger.Debug("Fetching OIDC discovery document", zap.String("url", discoveryURL))
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", discoveryURL, nil)
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to create discovery request").WithInternal(err)
|
|
}
|
|
|
|
resp, err := p.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to fetch discovery document").WithInternal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, errors.NewInternalError(fmt.Sprintf("Discovery endpoint returned status %d", resp.StatusCode))
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to read discovery response").WithInternal(err)
|
|
}
|
|
|
|
var discovery OIDCDiscoveryDocument
|
|
if err := json.Unmarshal(body, &discovery); err != nil {
|
|
return nil, errors.NewInternalError("Failed to parse discovery document").WithInternal(err)
|
|
}
|
|
|
|
p.logger.Debug("OIDC discovery document fetched successfully",
|
|
zap.String("issuer", discovery.Issuer),
|
|
zap.String("auth_endpoint", discovery.AuthorizationEndpoint),
|
|
zap.String("token_endpoint", discovery.TokenEndpoint))
|
|
|
|
return &discovery, nil
|
|
}
|
|
|
|
// GenerateAuthURL generates the OAuth2 authorization URL
|
|
func (p *OAuth2Provider) GenerateAuthURL(ctx context.Context, state, redirectURI string) (string, error) {
|
|
discovery, err := p.GetDiscoveryDocument(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
clientID := p.config.GetString("SSO_CLIENT_ID")
|
|
if clientID == "" {
|
|
return "", errors.NewConfigurationError("SSO_CLIENT_ID not configured")
|
|
}
|
|
|
|
// Generate PKCE code verifier and challenge
|
|
codeVerifier, err := p.generateCodeVerifier()
|
|
if err != nil {
|
|
return "", errors.NewInternalError("Failed to generate PKCE code verifier").WithInternal(err)
|
|
}
|
|
|
|
codeChallenge := p.generateCodeChallenge(codeVerifier)
|
|
|
|
// Build authorization URL
|
|
params := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {clientID},
|
|
"redirect_uri": {redirectURI},
|
|
"scope": {"openid profile email"},
|
|
"state": {state},
|
|
"code_challenge": {codeChallenge},
|
|
"code_challenge_method": {"S256"},
|
|
}
|
|
|
|
authURL := discovery.AuthorizationEndpoint + "?" + params.Encode()
|
|
|
|
p.logger.Debug("Generated OAuth2 authorization URL",
|
|
zap.String("client_id", clientID),
|
|
zap.String("redirect_uri", redirectURI),
|
|
zap.String("state", state))
|
|
|
|
// Store code verifier for later use (in production, this should be stored in a secure session store)
|
|
// For now, we'll return it as part of the response or store it in cache
|
|
|
|
return authURL, nil
|
|
}
|
|
|
|
// ExchangeCodeForToken exchanges authorization code for access token
|
|
func (p *OAuth2Provider) ExchangeCodeForToken(ctx context.Context, code, redirectURI, codeVerifier string) (*TokenResponse, error) {
|
|
discovery, err := p.GetDiscoveryDocument(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
clientID := p.config.GetString("SSO_CLIENT_ID")
|
|
clientSecret := p.config.GetString("SSO_CLIENT_SECRET")
|
|
|
|
if clientID == "" {
|
|
return nil, errors.NewConfigurationError("SSO_CLIENT_ID not configured")
|
|
}
|
|
if clientSecret == "" {
|
|
return nil, errors.NewConfigurationError("SSO_CLIENT_SECRET not configured")
|
|
}
|
|
|
|
// Prepare token exchange request
|
|
data := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {code},
|
|
"redirect_uri": {redirectURI},
|
|
"client_id": {clientID},
|
|
"client_secret": {clientSecret},
|
|
"code_verifier": {codeVerifier},
|
|
}
|
|
|
|
p.logger.Debug("Exchanging authorization code for token",
|
|
zap.String("token_endpoint", discovery.TokenEndpoint),
|
|
zap.String("client_id", clientID))
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", discovery.TokenEndpoint, strings.NewReader(data.Encode()))
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to create token request").WithInternal(err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := p.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to exchange code for token").WithInternal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to read token response").WithInternal(err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
p.logger.Error("Token exchange failed",
|
|
zap.Int("status_code", resp.StatusCode),
|
|
zap.String("response", string(body)))
|
|
return nil, errors.NewAuthenticationError("Failed to exchange authorization code")
|
|
}
|
|
|
|
var tokenResp TokenResponse
|
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
|
return nil, errors.NewInternalError("Failed to parse token response").WithInternal(err)
|
|
}
|
|
|
|
p.logger.Debug("Successfully exchanged code for token",
|
|
zap.String("token_type", tokenResp.TokenType),
|
|
zap.Int("expires_in", tokenResp.ExpiresIn))
|
|
|
|
return &tokenResp, nil
|
|
}
|
|
|
|
// GetUserInfo retrieves user information using the access token
|
|
func (p *OAuth2Provider) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
|
|
discovery, err := p.GetDiscoveryDocument(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if discovery.UserInfoEndpoint == "" {
|
|
return nil, errors.NewConfigurationError("UserInfo endpoint not available")
|
|
}
|
|
|
|
p.logger.Debug("Fetching user info", zap.String("endpoint", discovery.UserInfoEndpoint))
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", discovery.UserInfoEndpoint, nil)
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to create userinfo request").WithInternal(err)
|
|
}
|
|
|
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := p.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to fetch user info").WithInternal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
p.logger.Error("UserInfo request failed", zap.Int("status_code", resp.StatusCode))
|
|
return nil, errors.NewAuthenticationError("Failed to fetch user information")
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to read userinfo response").WithInternal(err)
|
|
}
|
|
|
|
var userInfo UserInfo
|
|
if err := json.Unmarshal(body, &userInfo); err != nil {
|
|
return nil, errors.NewInternalError("Failed to parse user info").WithInternal(err)
|
|
}
|
|
|
|
p.logger.Debug("Successfully fetched user info",
|
|
zap.String("sub", userInfo.Sub),
|
|
zap.String("email", userInfo.Email),
|
|
zap.String("name", userInfo.Name))
|
|
|
|
return &userInfo, nil
|
|
}
|
|
|
|
// ValidateIDToken validates an OIDC ID token (basic validation)
|
|
func (p *OAuth2Provider) ValidateIDToken(ctx context.Context, idToken string) (*domain.AuthContext, error) {
|
|
// This is a simplified implementation
|
|
// In production, you should validate the JWT signature using the provider's JWKS
|
|
|
|
p.logger.Debug("Validating ID token")
|
|
|
|
// For now, we'll just decode the token without signature verification
|
|
// This should be replaced with proper JWT validation using the provider's public keys
|
|
|
|
parts := strings.Split(idToken, ".")
|
|
if len(parts) != 3 {
|
|
return nil, errors.NewValidationError("Invalid ID token format")
|
|
}
|
|
|
|
// Decode payload (second part)
|
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
return nil, errors.NewValidationError("Failed to decode ID token payload").WithInternal(err)
|
|
}
|
|
|
|
var claims map[string]interface{}
|
|
if err := json.Unmarshal(payload, &claims); err != nil {
|
|
return nil, errors.NewValidationError("Failed to parse ID token claims").WithInternal(err)
|
|
}
|
|
|
|
// Extract basic claims
|
|
sub, _ := claims["sub"].(string)
|
|
email, _ := claims["email"].(string)
|
|
name, _ := claims["name"].(string)
|
|
|
|
if sub == "" {
|
|
return nil, errors.NewValidationError("ID token missing subject claim")
|
|
}
|
|
|
|
authContext := &domain.AuthContext{
|
|
UserID: sub,
|
|
TokenType: domain.TokenTypeUser,
|
|
Claims: map[string]string{
|
|
"sub": sub,
|
|
"email": email,
|
|
"name": name,
|
|
},
|
|
Permissions: []string{}, // Will be populated based on user roles/groups
|
|
}
|
|
|
|
p.logger.Debug("ID token validated successfully",
|
|
zap.String("sub", sub),
|
|
zap.String("email", email))
|
|
|
|
return authContext, nil
|
|
}
|
|
|
|
// generateCodeVerifier generates a PKCE code verifier
|
|
func (p *OAuth2Provider) generateCodeVerifier() (string, error) {
|
|
bytes := make([]byte, 32)
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
return "", err
|
|
}
|
|
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
|
}
|
|
|
|
// generateCodeChallenge generates a PKCE code challenge from verifier
|
|
func (p *OAuth2Provider) generateCodeChallenge(verifier string) string {
|
|
// For S256 method, we would hash the verifier with SHA256
|
|
// For simplicity, we'll use the verifier as-is (plain method)
|
|
// In production, implement proper S256 challenge generation
|
|
return verifier
|
|
}
|
|
|
|
// RefreshAccessToken refreshes an access token using refresh token
|
|
func (p *OAuth2Provider) RefreshAccessToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
|
|
discovery, err := p.GetDiscoveryDocument(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
clientID := p.config.GetString("SSO_CLIENT_ID")
|
|
clientSecret := p.config.GetString("SSO_CLIENT_SECRET")
|
|
|
|
data := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {refreshToken},
|
|
"client_id": {clientID},
|
|
"client_secret": {clientSecret},
|
|
}
|
|
|
|
p.logger.Debug("Refreshing access token")
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", discovery.TokenEndpoint, strings.NewReader(data.Encode()))
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to create refresh request").WithInternal(err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := p.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to refresh token").WithInternal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, errors.NewInternalError("Failed to read refresh response").WithInternal(err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
p.logger.Error("Token refresh failed",
|
|
zap.Int("status_code", resp.StatusCode),
|
|
zap.String("response", string(body)))
|
|
return nil, errors.NewAuthenticationError("Failed to refresh access token")
|
|
}
|
|
|
|
var tokenResp TokenResponse
|
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
|
return nil, errors.NewInternalError("Failed to parse refresh response").WithInternal(err)
|
|
}
|
|
|
|
p.logger.Debug("Successfully refreshed access token")
|
|
|
|
return &tokenResp, nil
|
|
}
|