Files
skybridge/internal/handlers/oauth2.go
2025-08-22 17:32:57 -04:00

395 lines
12 KiB
Go

package handlers
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/auth"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
"github.com/kms/api-key-service/internal/services"
)
// OAuth2Handler serves the HTTP endpoints for the OAuth2/OIDC authentication
// flows. Build one with NewOAuth2Handler so that the provider is initialised.
type OAuth2Handler struct {
	// config supplies runtime settings; it is read when minting internal tokens.
	config config.ConfigProvider
	// logger receives the structured log output of every endpoint.
	logger *zap.Logger
	// oauth2Provider performs the calls to the external OAuth2 provider
	// (authorization URL, code exchange, refresh, user info).
	oauth2Provider *auth.OAuth2Provider
	// authService generates and validates the service's internal JWTs.
	authService services.AuthenticationService
}
// NewOAuth2Handler returns an OAuth2Handler wired to the given configuration,
// logger and authentication service. The OAuth2 provider is constructed here
// from the same config and logger, so callers do not supply it.
func NewOAuth2Handler(
	config config.ConfigProvider,
	logger *zap.Logger,
	authService services.AuthenticationService,
) *OAuth2Handler {
	handler := &OAuth2Handler{
		config:      config,
		logger:      logger,
		authService: authService,
	}
	handler.oauth2Provider = auth.NewOAuth2Provider(config, logger)
	return handler
}
// AuthorizeRequest is the JSON body accepted by the authorize endpoint.
type AuthorizeRequest struct {
	// RedirectURI is forwarded to the provider when building the
	// authorization URL. Tagged as a required URL.
	RedirectURI string `json:"redirect_uri" validate:"required,url"`
	// State is the optional caller-chosen state value; the handler generates
	// a random one when it is empty.
	State string `json:"state,omitempty"`
}
// AuthorizeResponse is the JSON body returned by the authorize endpoint.
type AuthorizeResponse struct {
	// AuthURL is the provider authorization URL the client should visit.
	AuthURL string `json:"auth_url"`
	// State is the state value in effect for this flow (supplied or generated).
	State string `json:"state"`
	// CodeVerifier is the PKCE code verifier. In production, this should be
	// stored securely rather than handed back to the client.
	CodeVerifier string `json:"code_verifier"`
}
// CallbackRequest is the JSON body accepted by the callback endpoint.
type CallbackRequest struct {
	// Code is the authorization code to exchange for tokens. Tagged required.
	Code string `json:"code" validate:"required"`
	// State is the optional state value echoed back by the provider.
	State string `json:"state,omitempty"`
	// RedirectURI is passed to the provider with the code. Tagged as a
	// required URL.
	RedirectURI string `json:"redirect_uri" validate:"required,url"`
	// CodeVerifier is the PKCE code verifier sent with the code exchange.
	// Tagged required.
	CodeVerifier string `json:"code_verifier" validate:"required"`
}
// CallbackResponse is the JSON body returned by the callback endpoint after a
// successful code exchange.
type CallbackResponse struct {
	// AccessToken is the provider access token from the code exchange.
	AccessToken string `json:"access_token"`
	// TokenType is the token type reported by the provider.
	TokenType string `json:"token_type"`
	// ExpiresIn is the lifetime reported by the provider
	// (presumably seconds, as in OAuth2 token responses; confirm).
	ExpiresIn int `json:"expires_in"`
	// RefreshToken is the provider refresh token; omitted from JSON when empty.
	RefreshToken string `json:"refresh_token,omitempty"`
	// UserInfo is the profile fetched from the provider with the access token.
	UserInfo *auth.UserInfo `json:"user_info"`
	// JWTToken is the internal JWT minted by this service for the user.
	JWTToken string `json:"jwt_token"`
}
// RefreshRequest is the JSON body accepted by the refresh endpoint.
type RefreshRequest struct {
	// RefreshToken is the provider refresh token to redeem. Tagged required.
	RefreshToken string `json:"refresh_token" validate:"required"`
}
// RefreshResponse is the JSON body returned by the refresh endpoint.
type RefreshResponse struct {
	// AccessToken is the new provider access token.
	AccessToken string `json:"access_token"`
	// TokenType is the token type reported by the provider.
	TokenType string `json:"token_type"`
	// ExpiresIn is the lifetime reported by the provider
	// (presumably seconds, as in OAuth2 token responses; confirm).
	ExpiresIn int `json:"expires_in"`
	// RefreshToken is the refresh token from the provider's response; omitted
	// from JSON when empty.
	RefreshToken string `json:"refresh_token,omitempty"`
	// JWTToken is the newly minted internal JWT for the user.
	JWTToken string `json:"jwt_token"`
}
// RegisterRoutes mounts the OAuth2 endpoints on a "/oauth2" subrouter of the
// given router: POST /authorize, POST /callback, POST /refresh and
// GET /userinfo.
func (h *OAuth2Handler) RegisterRoutes(router *mux.Router) {
	sub := router.PathPrefix("/oauth2").Subrouter()

	// Registration order matches the table order.
	routes := []struct {
		path    string
		method  string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{"/authorize", "POST", h.Authorize},
		{"/callback", "POST", h.Callback},
		{"/refresh", "POST", h.Refresh},
		{"/userinfo", "GET", h.GetUserInfo},
	}
	for _, route := range routes {
		sub.HandleFunc(route.path, route.handler).Methods(route.method)
	}
}
// Authorize initiates the OAuth2 authorization flow.
//
// It decodes an AuthorizeRequest from the JSON body, generates a random state
// when the caller did not supply one, asks the OAuth2 provider for the
// authorization URL, and responds with an AuthorizeResponse carrying the URL,
// the state in effect and a freshly generated PKCE code verifier.
//
// Responses:
//   - 400 when the body is not valid JSON, exceeds the size cap, or has an
//     empty redirect_uri.
//   - The provider's own status and message when it fails with an
//     *errors.AppError.
//   - 500 for any other failure.
func (h *OAuth2Handler) Authorize(w http.ResponseWriter, r *http.Request) {
	// The body comes from an untrusted client; cap how much the decoder reads.
	const maxBodyBytes = 1 << 20

	ctx := r.Context()
	h.logger.Debug("Processing OAuth2 authorization request")

	var req AuthorizeRequest
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		h.logger.Warn("Invalid authorization request", zap.Error(err))
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// The struct tag marks redirect_uri as required, but no validator runs on
	// the decoded value, so enforce it here before calling the provider.
	if req.RedirectURI == "" {
		h.logger.Warn("Missing required field in authorization request",
			zap.String("field", "redirect_uri"))
		http.Error(w, "redirect_uri is required", http.StatusBadRequest)
		return
	}

	// Generate state if not provided
	if req.State == "" {
		state, err := h.generateState()
		if err != nil {
			h.logger.Error("Failed to generate state", zap.Error(err))
			http.Error(w, "Internal server error", http.StatusInternalServerError)
			return
		}
		req.State = state
	}

	// Generate authorization URL
	authURL, err := h.oauth2Provider.GenerateAuthURL(ctx, req.State, req.RedirectURI)
	if err != nil {
		h.logger.Error("Failed to generate authorization URL", zap.Error(err))
		if appErr, ok := err.(*errors.AppError); ok {
			http.Error(w, appErr.Message, appErr.StatusCode)
			return
		}
		http.Error(w, "Failed to generate authorization URL", http.StatusInternalServerError)
		return
	}

	// In production, store the code verifier securely (e.g., in session or cache).
	// For now it is returned in the response.
	// NOTE(review): the verifier is generated after GenerateAuthURL and is never
	// passed to it, so nothing visible here ties a code challenge in the URL to
	// this verifier. Confirm how the provider derives the challenge.
	codeVerifier, err := h.generateCodeVerifier()
	if err != nil {
		h.logger.Error("Failed to generate code verifier", zap.Error(err))
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	response := AuthorizeResponse{
		AuthURL:      authURL,
		State:        req.State,
		CodeVerifier: codeVerifier,
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		h.logger.Error("Failed to encode authorization response", zap.Error(err))
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	h.logger.Debug("Authorization URL generated successfully",
		zap.String("state", req.State),
		zap.String("redirect_uri", req.RedirectURI))
}
// Callback handles the OAuth2 callback and exchanges the authorization code
// for tokens.
//
// It decodes a CallbackRequest from the JSON body, exchanges the code (with
// its redirect URI and PKCE verifier) at the provider, fetches the user's
// profile with the returned access token, mints an internal JWT for that user
// and responds with a CallbackResponse.
//
// Responses:
//   - 400 when the body is not valid JSON, exceeds the size cap, or has an
//     empty code, redirect_uri or code_verifier.
//   - The provider's own status and message when the exchange or the user-info
//     lookup fails with an *errors.AppError.
//   - 500 for any other failure.
func (h *OAuth2Handler) Callback(w http.ResponseWriter, r *http.Request) {
	// The body comes from an untrusted client; cap how much the decoder reads.
	const maxBodyBytes = 1 << 20

	ctx := r.Context()
	h.logger.Debug("Processing OAuth2 callback")

	var req CallbackRequest
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		h.logger.Warn("Invalid callback request", zap.Error(err))
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// The struct tags mark these fields as required, but no validator runs on
	// the decoded value; reject empty ones before any provider call is made.
	var missing string
	switch {
	case req.Code == "":
		missing = "code"
	case req.RedirectURI == "":
		missing = "redirect_uri"
	case req.CodeVerifier == "":
		missing = "code_verifier"
	}
	if missing != "" {
		h.logger.Warn("Missing required field in callback request",
			zap.String("field", missing))
		http.Error(w, missing+" is required", http.StatusBadRequest)
		return
	}

	// Exchange authorization code for tokens
	tokenResp, err := h.oauth2Provider.ExchangeCodeForToken(ctx, req.Code, req.RedirectURI, req.CodeVerifier)
	if err != nil {
		h.logger.Error("Failed to exchange code for token", zap.Error(err))
		if appErr, ok := err.(*errors.AppError); ok {
			http.Error(w, appErr.Message, appErr.StatusCode)
			return
		}
		http.Error(w, "Failed to exchange authorization code", http.StatusInternalServerError)
		return
	}

	// Get user information
	userInfo, err := h.oauth2Provider.GetUserInfo(ctx, tokenResp.AccessToken)
	if err != nil {
		h.logger.Error("Failed to get user info", zap.Error(err))
		if appErr, ok := err.(*errors.AppError); ok {
			http.Error(w, appErr.Message, appErr.StatusCode)
			return
		}
		http.Error(w, "Failed to get user information", http.StatusInternalServerError)
		return
	}

	// Generate internal JWT token for the user
	jwtToken, err := h.generateInternalJWTToken(ctx, userInfo)
	if err != nil {
		h.logger.Error("Failed to generate internal JWT token", zap.Error(err))
		http.Error(w, "Failed to generate authentication token", http.StatusInternalServerError)
		return
	}

	response := CallbackResponse{
		AccessToken:  tokenResp.AccessToken,
		TokenType:    tokenResp.TokenType,
		ExpiresIn:    tokenResp.ExpiresIn,
		RefreshToken: tokenResp.RefreshToken,
		UserInfo:     userInfo,
		JWTToken:     jwtToken,
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		h.logger.Error("Failed to encode callback response", zap.Error(err))
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	h.logger.Info("OAuth2 callback processed successfully",
		zap.String("user_id", userInfo.Sub),
		zap.String("email", userInfo.Email))
}
// Refresh exchanges a refresh token for a new access token.
//
// It decodes a RefreshRequest from the JSON body, redeems the refresh token at
// the provider, re-fetches the user's profile with the new access token, mints
// a new internal JWT and responds with a RefreshResponse.
//
// Responses:
//   - 400 when the body is not valid JSON, exceeds the size cap, or has an
//     empty refresh_token.
//   - The provider's own status and message when the refresh or the user-info
//     lookup fails with an *errors.AppError.
//   - 500 for any other failure.
func (h *OAuth2Handler) Refresh(w http.ResponseWriter, r *http.Request) {
	// The body comes from an untrusted client; cap how much the decoder reads.
	const maxBodyBytes = 1 << 20

	ctx := r.Context()
	h.logger.Debug("Processing token refresh request")

	var req RefreshRequest
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		h.logger.Warn("Invalid refresh request", zap.Error(err))
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// The struct tag marks refresh_token as required, but no validator runs on
	// the decoded value, so enforce it here before calling the provider.
	if req.RefreshToken == "" {
		h.logger.Warn("Missing required field in refresh request",
			zap.String("field", "refresh_token"))
		http.Error(w, "refresh_token is required", http.StatusBadRequest)
		return
	}

	// Refresh the access token
	tokenResp, err := h.oauth2Provider.RefreshAccessToken(ctx, req.RefreshToken)
	if err != nil {
		h.logger.Error("Failed to refresh access token", zap.Error(err))
		if appErr, ok := err.(*errors.AppError); ok {
			http.Error(w, appErr.Message, appErr.StatusCode)
			return
		}
		http.Error(w, "Failed to refresh access token", http.StatusInternalServerError)
		return
	}

	// Get updated user information
	userInfo, err := h.oauth2Provider.GetUserInfo(ctx, tokenResp.AccessToken)
	if err != nil {
		h.logger.Error("Failed to get user info during refresh", zap.Error(err))
		if appErr, ok := err.(*errors.AppError); ok {
			http.Error(w, appErr.Message, appErr.StatusCode)
			return
		}
		http.Error(w, "Failed to get user information", http.StatusInternalServerError)
		return
	}

	// Generate new internal JWT token
	jwtToken, err := h.generateInternalJWTToken(ctx, userInfo)
	if err != nil {
		h.logger.Error("Failed to generate internal JWT token during refresh", zap.Error(err))
		http.Error(w, "Failed to generate authentication token", http.StatusInternalServerError)
		return
	}

	response := RefreshResponse{
		AccessToken:  tokenResp.AccessToken,
		TokenType:    tokenResp.TokenType,
		ExpiresIn:    tokenResp.ExpiresIn,
		RefreshToken: tokenResp.RefreshToken,
		JWTToken:     jwtToken,
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		h.logger.Error("Failed to encode refresh response", zap.Error(err))
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	h.logger.Debug("Token refresh completed successfully",
		zap.String("user_id", userInfo.Sub))
}
// GetUserInfo returns the caller's identity as recorded in their internal JWT.
//
// The token is read from the Authorization header. A leading "Bearer " is
// stripped when present; otherwise the whole header value is treated as the
// token. The token is validated by the authentication service and the reply
// is a JSON object with the keys sub, email, name, permissions and app_id.
//
// Responses: 401 when the header is absent or the token fails validation,
// 500 when the reply cannot be encoded.
func (h *OAuth2Handler) GetUserInfo(w http.ResponseWriter, r *http.Request) {
	const bearerPrefix = "Bearer "

	h.logger.Debug("Processing user info request")

	header := r.Header.Get("Authorization")
	if header == "" {
		http.Error(w, "Authorization header required", http.StatusUnauthorized)
		return
	}

	// Strip the scheme only when something follows it; a bare "Bearer " is
	// passed through unchanged and left for validation to reject.
	token := header
	if len(header) > len(bearerPrefix) && header[:len(bearerPrefix)] == bearerPrefix {
		token = header[len(bearerPrefix):]
	}

	authContext, err := h.authService.ValidateJWTToken(r.Context(), token)
	if err != nil {
		h.logger.Warn("Invalid JWT token in user info request", zap.Error(err))
		http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
		return
	}

	// The reply is assembled purely from the validated token's contents.
	payload := make(map[string]interface{}, 5)
	payload["sub"] = authContext.UserID
	payload["email"] = authContext.Claims["email"]
	payload["name"] = authContext.Claims["name"]
	payload["permissions"] = authContext.Permissions
	payload["app_id"] = authContext.AppID

	w.Header().Set("Content-Type", "application/json")
	encodeErr := json.NewEncoder(w).Encode(payload)
	if encodeErr != nil {
		h.logger.Error("Failed to encode user info response", zap.Error(encodeErr))
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	h.logger.Debug("User info request completed successfully",
		zap.String("user_id", authContext.UserID))
}
// generateState returns a random OAuth2 state value: 32 bytes from
// crypto/rand, encoded as padded URL-safe base64 (44 characters). It returns
// an error only when the random source fails.
func (h *OAuth2Handler) generateState() (string, error) {
	var raw [32]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}
	encoded := base64.URLEncoding.EncodeToString(raw[:])
	return encoded, nil
}
// generateCodeVerifier returns a PKCE code verifier: 32 bytes from
// crypto/rand, encoded as unpadded URL-safe base64 (43 characters). It
// returns an error only when the random source fails.
func (h *OAuth2Handler) generateCodeVerifier() (string, error) {
	var raw [32]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(raw[:])
	return encoded, nil
}
// generateInternalJWTToken mints the service's own JWT for a user who has
// authenticated at the OAuth2 provider.
//
// The token is issued for the app named by the INTERNAL_APP_ID config value,
// carries the provider's sub, email, name and email_verified as string claims,
// expires after 24 hours and has a maximum validity of 7 days. Signing is
// delegated to the authentication service; its result and error are returned
// unchanged.
//
// userInfo must be non-nil.
func (h *OAuth2Handler) generateInternalJWTToken(ctx context.Context, userInfo *auth.UserInfo) (string, error) {
	// Read the clock once so IssuedAt, ExpiresAt and MaxValidAt share the same
	// base instant and the intervals between them are exact.
	now := time.Now()

	// Claims are string-valued, so the boolean is rendered as "true"/"false".
	emailVerified := "false"
	if userInfo.EmailVerified {
		emailVerified = "true"
	}

	// Create user token with information from OAuth2 provider
	userToken := &domain.UserToken{
		AppID:       h.config.GetString("INTERNAL_APP_ID"),
		UserID:      userInfo.Sub,
		Permissions: []string{"read", "write"}, // Default permissions, should be based on user roles
		IssuedAt:    now,
		ExpiresAt:   now.Add(24 * time.Hour),     // 24 hour expiration
		MaxValidAt:  now.Add(7 * 24 * time.Hour), // 7 days max validity
		TokenType:   domain.TokenTypeUser,
		Claims: map[string]string{
			"sub":            userInfo.Sub,
			"email":          userInfo.Email,
			"name":           userInfo.Name,
			"email_verified": emailVerified,
		},
	}

	// Generate JWT token using authentication service
	return h.authService.GenerateJWTToken(ctx, userToken)
}