package handlers

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	stderrors "errors"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/gorilla/mux"
	"go.uber.org/zap"

	"github.com/kms/api-key-service/internal/auth"
	"github.com/kms/api-key-service/internal/config"
	"github.com/kms/api-key-service/internal/domain"
	"github.com/kms/api-key-service/internal/errors"
	"github.com/kms/api-key-service/internal/services"
)

// OAuth2Handler handles OAuth2/OIDC authentication flows.
type OAuth2Handler struct {
	config         config.ConfigProvider
	logger         *zap.Logger
	oauth2Provider *auth.OAuth2Provider
	authService    services.AuthenticationService
}

// NewOAuth2Handler creates a new OAuth2 handler. The OAuth2 provider client is
// built from the same config and logger that the handler itself uses.
func NewOAuth2Handler(
	config config.ConfigProvider,
	logger *zap.Logger,
	authService services.AuthenticationService,
) *OAuth2Handler {
	oauth2Provider := auth.NewOAuth2Provider(config, logger)

	return &OAuth2Handler{
		config:         config,
		logger:         logger,
		oauth2Provider: oauth2Provider,
		authService:    authService,
	}
}

// AuthorizeRequest represents the OAuth2 authorization request.
type AuthorizeRequest struct {
	RedirectURI string `json:"redirect_uri" validate:"required,url"`
	State       string `json:"state,omitempty"`
}

// AuthorizeResponse represents the OAuth2 authorization response.
type AuthorizeResponse struct {
	AuthURL      string `json:"auth_url"`
	State        string `json:"state"`
	CodeVerifier string `json:"code_verifier"` // In production, this should be stored securely
}

// CallbackRequest represents the OAuth2 callback request.
type CallbackRequest struct {
	Code         string `json:"code" validate:"required"`
	State        string `json:"state,omitempty"`
	RedirectURI  string `json:"redirect_uri" validate:"required,url"`
	CodeVerifier string `json:"code_verifier" validate:"required"`
}

// CallbackResponse represents the OAuth2 callback response.
type CallbackResponse struct {
	AccessToken  string         `json:"access_token"`
	TokenType    string         `json:"token_type"`
	ExpiresIn    int            `json:"expires_in"`
	RefreshToken string         `json:"refresh_token,omitempty"`
	UserInfo     *auth.UserInfo `json:"user_info"`
	JWTToken     string         `json:"jwt_token"`
}

// RefreshRequest represents the token refresh request.
type RefreshRequest struct {
	RefreshToken string `json:"refresh_token" validate:"required"`
}

// RefreshResponse represents the token refresh response.
type RefreshResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	RefreshToken string `json:"refresh_token,omitempty"`
	JWTToken     string `json:"jwt_token"`
}

// RegisterRoutes mounts the OAuth2 endpoints under /oauth2 on router:
// POST /authorize, POST /callback, POST /refresh and GET /userinfo.
func (h *OAuth2Handler) RegisterRoutes(router *mux.Router) {
	oauth2Router := router.PathPrefix("/oauth2").Subrouter()

	oauth2Router.HandleFunc("/authorize", h.Authorize).Methods("POST")
	oauth2Router.HandleFunc("/callback", h.Callback).Methods("POST")
	oauth2Router.HandleFunc("/refresh", h.Refresh).Methods("POST")
	oauth2Router.HandleFunc("/userinfo", h.GetUserInfo).Methods("GET")
}

// Authorize initiates the OAuth2 authorization flow.
//
// It decodes an AuthorizeRequest, generates a random state when the client
// did not supply one, asks the provider for the authorization URL, and
// responds with that URL, the state, and a freshly generated PKCE code
// verifier. redirect_uri is required (400 if missing).
func (h *OAuth2Handler) Authorize(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	h.logger.Debug("Processing OAuth2 authorization request")

	var req AuthorizeRequest
	if !h.decodeJSON(w, r, &req, "Invalid authorization request") ||
		!h.requireField(w, "Invalid authorization request", "redirect_uri", req.RedirectURI) {
		return
	}

	// Generate state if not provided.
	if req.State == "" {
		state, err := h.generateState()
		if err != nil {
			h.logger.Error("Failed to generate state", zap.Error(err))
			http.Error(w, "Internal server error", http.StatusInternalServerError)
			return
		}
		req.State = state
	}

	authURL, err := h.oauth2Provider.GenerateAuthURL(ctx, req.State, req.RedirectURI)
	if err != nil {
		h.logger.Error("Failed to generate authorization URL", zap.Error(err))
		h.writeProviderError(w, err, "Failed to generate authorization URL")
		return
	}

	// In production, store the code verifier securely (e.g., in session or
	// cache); for now it is returned to the client.
	// NOTE(review): this verifier is not passed to GenerateAuthURL, so any
	// code_challenge embedded in authURL is not derived from it. Confirm how
	// the provider handles PKCE before relying on it.
	codeVerifier, err := h.generateCodeVerifier()
	if err != nil {
		h.logger.Error("Failed to generate code verifier", zap.Error(err))
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	response := AuthorizeResponse{
		AuthURL:      authURL,
		State:        req.State,
		CodeVerifier: codeVerifier,
	}
	if !h.writeJSON(w, response, "Failed to encode authorization response") {
		return
	}

	h.logger.Debug("Authorization URL generated successfully",
		zap.String("state", req.State),
		zap.String("redirect_uri", req.RedirectURI))
}

// Callback handles the OAuth2 callback. It exchanges the authorization code
// for provider tokens, fetches the user's profile, and issues an internal JWT.
// code, redirect_uri and code_verifier are required (400 if any is missing).
//
// NOTE(review): req.State is not verified here; CSRF protection currently
// relies on the client comparing it with the state returned by Authorize.
func (h *OAuth2Handler) Callback(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	h.logger.Debug("Processing OAuth2 callback")

	const invalidMsg = "Invalid callback request"
	var req CallbackRequest
	if !h.decodeJSON(w, r, &req, invalidMsg) ||
		!h.requireField(w, invalidMsg, "code", req.Code) ||
		!h.requireField(w, invalidMsg, "redirect_uri", req.RedirectURI) ||
		!h.requireField(w, invalidMsg, "code_verifier", req.CodeVerifier) {
		return
	}

	tokenResp, err := h.oauth2Provider.ExchangeCodeForToken(ctx, req.Code, req.RedirectURI, req.CodeVerifier)
	if err != nil {
		h.logger.Error("Failed to exchange code for token", zap.Error(err))
		h.writeProviderError(w, err, "Failed to exchange authorization code")
		return
	}

	userInfo, err := h.oauth2Provider.GetUserInfo(ctx, tokenResp.AccessToken)
	if err != nil {
		h.logger.Error("Failed to get user info", zap.Error(err))
		h.writeProviderError(w, err, "Failed to get user information")
		return
	}

	jwtToken, err := h.generateInternalJWTToken(ctx, userInfo)
	if err != nil {
		h.logger.Error("Failed to generate internal JWT token", zap.Error(err))
		http.Error(w, "Failed to generate authentication token", http.StatusInternalServerError)
		return
	}

	response := CallbackResponse{
		AccessToken:  tokenResp.AccessToken,
		TokenType:    tokenResp.TokenType,
		ExpiresIn:    tokenResp.ExpiresIn,
		RefreshToken: tokenResp.RefreshToken,
		UserInfo:     userInfo,
		JWTToken:     jwtToken,
	}
	if !h.writeJSON(w, response, "Failed to encode callback response") {
		return
	}

	h.logger.Info("OAuth2 callback processed successfully",
		zap.String("user_id", userInfo.Sub),
		zap.String("email", userInfo.Email))
}

// Refresh exchanges a provider refresh token for a new access token.
// It re-reads the user's profile and issues a new internal JWT.
// refresh_token is required (400 if missing).
func (h *OAuth2Handler) Refresh(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	h.logger.Debug("Processing token refresh request")

	var req RefreshRequest
	if !h.decodeJSON(w, r, &req, "Invalid refresh request") ||
		!h.requireField(w, "Invalid refresh request", "refresh_token", req.RefreshToken) {
		return
	}

	tokenResp, err := h.oauth2Provider.RefreshAccessToken(ctx, req.RefreshToken)
	if err != nil {
		h.logger.Error("Failed to refresh access token", zap.Error(err))
		h.writeProviderError(w, err, "Failed to refresh access token")
		return
	}

	userInfo, err := h.oauth2Provider.GetUserInfo(ctx, tokenResp.AccessToken)
	if err != nil {
		h.logger.Error("Failed to get user info during refresh", zap.Error(err))
		h.writeProviderError(w, err, "Failed to get user information")
		return
	}

	jwtToken, err := h.generateInternalJWTToken(ctx, userInfo)
	if err != nil {
		h.logger.Error("Failed to generate internal JWT token during refresh", zap.Error(err))
		http.Error(w, "Failed to generate authentication token", http.StatusInternalServerError)
		return
	}

	response := RefreshResponse{
		AccessToken:  tokenResp.AccessToken,
		TokenType:    tokenResp.TokenType,
		ExpiresIn:    tokenResp.ExpiresIn,
		RefreshToken: tokenResp.RefreshToken,
		JWTToken:     jwtToken,
	}
	if !h.writeJSON(w, response, "Failed to encode refresh response") {
		return
	}

	h.logger.Debug("Token refresh completed successfully",
		zap.String("user_id", userInfo.Sub))
}

// GetUserInfo returns identity information taken from the caller's internal
// JWT, which is read from the Authorization header. A missing header or a
// token rejected by the auth service yields 401.
func (h *OAuth2Handler) GetUserInfo(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	h.logger.Debug("Processing user info request")

	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		// RFC 6750 §3: 401 responses to bearer-protected resources carry a challenge.
		w.Header().Set("WWW-Authenticate", "Bearer")
		http.Error(w, "Authorization header required", http.StatusUnauthorized)
		return
	}

	authContext, err := h.authService.ValidateJWTToken(ctx, h.bearerToken(authHeader))
	if err != nil {
		h.logger.Warn("Invalid JWT token in user info request", zap.Error(err))
		w.Header().Set("WWW-Authenticate", `Bearer error="invalid_token"`)
		http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
		return
	}

	userInfo := map[string]interface{}{
		"sub":         authContext.UserID,
		"email":       authContext.Claims["email"],
		"name":        authContext.Claims["name"],
		"permissions": authContext.Permissions,
		"app_id":      authContext.AppID,
	}
	if !h.writeJSON(w, userInfo, "Failed to encode user info response") {
		return
	}

	h.logger.Debug("User info request completed successfully",
		zap.String("user_id", authContext.UserID))
}

// decodeJSON decodes the request body into dst. The body is capped at 1 MiB
// because these endpoints accept input from unauthenticated clients. On
// failure it logs logMsg, writes 400 "Invalid request body", and returns false.
func (h *OAuth2Handler) decodeJSON(w http.ResponseWriter, r *http.Request, dst interface{}, logMsg string) bool {
	const maxBodyBytes = 1 << 20
	body := http.MaxBytesReader(w, r.Body, maxBodyBytes)
	if err := json.NewDecoder(body).Decode(dst); err != nil {
		h.logger.Warn(logMsg, zap.Error(err))
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return false
	}
	return true
}

// requireField enforces a `validate:"required"` field. It returns true when
// value is non-empty. Otherwise it logs logMsg, writes a 400 naming the
// missing JSON field, and returns false.
func (h *OAuth2Handler) requireField(w http.ResponseWriter, logMsg, field, value string) bool {
	if value != "" {
		return true
	}
	h.logger.Warn(logMsg, zap.String("missing_field", field))
	http.Error(w, field+" is required", http.StatusBadRequest)
	return false
}

// writeProviderError turns an OAuth2 provider failure into an HTTP response.
// If err's chain holds an *errors.AppError, its message and status are used.
// Otherwise the response is 500 with fallbackMsg.
func (h *OAuth2Handler) writeProviderError(w http.ResponseWriter, err error, fallbackMsg string) {
	var appErr *errors.AppError
	// errors.As (unlike a bare type assertion) also finds wrapped AppErrors.
	// The status range check keeps a malformed StatusCode (e.g. zero) away
	// from WriteHeader, which panics on codes outside 100-999.
	if stderrors.As(err, &appErr) && appErr != nil &&
		appErr.StatusCode >= 400 && appErr.StatusCode <= 599 {
		http.Error(w, appErr.Message, appErr.StatusCode)
		return
	}
	http.Error(w, fallbackMsg, http.StatusInternalServerError)
}

// writeJSON sends v as an application/json response and reports success.
// v is marshaled before anything is written. That way a serialization
// failure can still become a clean 500, instead of an error page appended
// to headers and a body that were already sent.
func (h *OAuth2Handler) writeJSON(w http.ResponseWriter, v interface{}, failureMsg string) bool {
	payload, err := json.Marshal(v)
	if err != nil {
		h.logger.Error(failureMsg, zap.Error(err))
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return false
	}
	// json.Encoder terminated its output with a newline; keep the body byte-compatible.
	payload = append(payload, '\n')

	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(payload); err != nil {
		// Headers are already on the wire, so nothing more can be sent; just record it.
		h.logger.Error(failureMsg, zap.Error(err))
		return false
	}
	return true
}

// bearerToken extracts the credential from an Authorization header value.
// The "Bearer" scheme is matched case-insensitively (RFC 7235 §2.1). A value
// without that scheme is returned unchanged, so bare tokens are still
// accepted as before.
func (h *OAuth2Handler) bearerToken(header string) string {
	const scheme = "Bearer "
	if len(header) > len(scheme) && strings.EqualFold(header[:len(scheme)], scheme) {
		return strings.TrimSpace(header[len(scheme):])
	}
	return header
}

// generateState returns a random OAuth2 state value: 32 bytes from
// crypto/rand, encoded as unpadded URL-safe base64. No padding means no '='
// characters that would need escaping in a query string; this matches
// generateCodeVerifier.
func (h *OAuth2Handler) generateState() (string, error) {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(buf), nil
}

// generateCodeVerifier returns a PKCE code verifier: 32 bytes from
// crypto/rand, encoded as unpadded URL-safe base64 (43 characters).
func (h *OAuth2Handler) generateCodeVerifier() (string, error) {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(buf), nil
}

// generateInternalJWTToken issues an internal user JWT for an identity
// returned by the OAuth2 provider.
//
// The token is bound to the INTERNAL_APP_ID config value and granted the
// default "read"/"write" permissions. It expires 24h after issue and has a
// 7-day maximum validity. All three timestamps are derived from a single
// time.Now() reading so they are mutually consistent. It returns an error
// if userInfo is nil.
func (h *OAuth2Handler) generateInternalJWTToken(ctx context.Context, userInfo *auth.UserInfo) (string, error) {
	if userInfo == nil {
		return "", stderrors.New("oauth2: cannot issue internal token without user info")
	}

	const (
		tokenTTL    = 24 * time.Hour
		maxValidity = 7 * 24 * time.Hour
	)
	now := time.Now()

	userToken := &domain.UserToken{
		AppID:  h.config.GetString("INTERNAL_APP_ID"),
		UserID: userInfo.Sub,
		// Default permissions; these should eventually be derived from user roles.
		Permissions: []string{"read", "write"},
		IssuedAt:    now,
		ExpiresAt:   now.Add(tokenTTL),
		MaxValidAt:  now.Add(maxValidity),
		TokenType:   domain.TokenTypeUser,
		Claims: map[string]string{
			"sub":            userInfo.Sub,
			"email":          userInfo.Email,
			"name":           userInfo.Name,
			"email_verified": strconv.FormatBool(userInfo.EmailVerified),
		},
	}

	return h.authService.GenerateJWTToken(ctx, userToken)
}