-
This commit is contained in:
394
internal/handlers/oauth2.go
Normal file
394
internal/handlers/oauth2.go
Normal file
@ -0,0 +1,394 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
"github.com/kms/api-key-service/internal/services"
|
||||
)
|
||||
|
||||
// OAuth2Handler handles OAuth2/OIDC authentication flows
|
||||
type OAuth2Handler struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
oauth2Provider *auth.OAuth2Provider
|
||||
authService services.AuthenticationService
|
||||
}
|
||||
|
||||
// NewOAuth2Handler creates a new OAuth2 handler
|
||||
func NewOAuth2Handler(
|
||||
config config.ConfigProvider,
|
||||
logger *zap.Logger,
|
||||
authService services.AuthenticationService,
|
||||
) *OAuth2Handler {
|
||||
oauth2Provider := auth.NewOAuth2Provider(config, logger)
|
||||
|
||||
return &OAuth2Handler{
|
||||
config: config,
|
||||
logger: logger,
|
||||
oauth2Provider: oauth2Provider,
|
||||
authService: authService,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthorizeRequest represents the OAuth2 authorization request
|
||||
type AuthorizeRequest struct {
|
||||
RedirectURI string `json:"redirect_uri" validate:"required,url"`
|
||||
State string `json:"state,omitempty"`
|
||||
}
|
||||
|
||||
// AuthorizeResponse represents the OAuth2 authorization response
|
||||
type AuthorizeResponse struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"` // In production, this should be stored securely
|
||||
}
|
||||
|
||||
// CallbackRequest represents the OAuth2 callback request
|
||||
type CallbackRequest struct {
|
||||
Code string `json:"code" validate:"required"`
|
||||
State string `json:"state,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri" validate:"required,url"`
|
||||
CodeVerifier string `json:"code_verifier" validate:"required"`
|
||||
}
|
||||
|
||||
// CallbackResponse represents the OAuth2 callback response
|
||||
type CallbackResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
UserInfo *auth.UserInfo `json:"user_info"`
|
||||
JWTToken string `json:"jwt_token"`
|
||||
}
|
||||
|
||||
// RefreshRequest represents the token refresh request
|
||||
type RefreshRequest struct {
|
||||
RefreshToken string `json:"refresh_token" validate:"required"`
|
||||
}
|
||||
|
||||
// RefreshResponse represents the token refresh response
|
||||
type RefreshResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
JWTToken string `json:"jwt_token"`
|
||||
}
|
||||
|
||||
// RegisterRoutes registers OAuth2 routes
|
||||
func (h *OAuth2Handler) RegisterRoutes(router *mux.Router) {
|
||||
oauth2Router := router.PathPrefix("/oauth2").Subrouter()
|
||||
|
||||
oauth2Router.HandleFunc("/authorize", h.Authorize).Methods("POST")
|
||||
oauth2Router.HandleFunc("/callback", h.Callback).Methods("POST")
|
||||
oauth2Router.HandleFunc("/refresh", h.Refresh).Methods("POST")
|
||||
oauth2Router.HandleFunc("/userinfo", h.GetUserInfo).Methods("GET")
|
||||
}
|
||||
|
||||
// Authorize initiates the OAuth2 authorization flow
|
||||
func (h *OAuth2Handler) Authorize(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
h.logger.Debug("Processing OAuth2 authorization request")
|
||||
|
||||
var req AuthorizeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.logger.Warn("Invalid authorization request", zap.Error(err))
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate state if not provided
|
||||
if req.State == "" {
|
||||
state, err := h.generateState()
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate state", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.State = state
|
||||
}
|
||||
|
||||
// Generate authorization URL
|
||||
authURL, err := h.oauth2Provider.GenerateAuthURL(ctx, req.State, req.RedirectURI)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate authorization URL", zap.Error(err))
|
||||
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
http.Error(w, appErr.Message, appErr.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to generate authorization URL", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// In production, store the code verifier securely (e.g., in session or cache)
|
||||
// For now, we'll return it in the response
|
||||
codeVerifier, err := h.generateCodeVerifier()
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate code verifier", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response := AuthorizeResponse{
|
||||
AuthURL: authURL,
|
||||
State: req.State,
|
||||
CodeVerifier: codeVerifier,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
h.logger.Error("Failed to encode authorization response", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("Authorization URL generated successfully",
|
||||
zap.String("state", req.State),
|
||||
zap.String("redirect_uri", req.RedirectURI))
|
||||
}
|
||||
|
||||
// Callback handles the OAuth2 callback and exchanges code for tokens
|
||||
func (h *OAuth2Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
h.logger.Debug("Processing OAuth2 callback")
|
||||
|
||||
var req CallbackRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.logger.Warn("Invalid callback request", zap.Error(err))
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange authorization code for tokens
|
||||
tokenResp, err := h.oauth2Provider.ExchangeCodeForToken(ctx, req.Code, req.RedirectURI, req.CodeVerifier)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to exchange code for token", zap.Error(err))
|
||||
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
http.Error(w, appErr.Message, appErr.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to exchange authorization code", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user information
|
||||
userInfo, err := h.oauth2Provider.GetUserInfo(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to get user info", zap.Error(err))
|
||||
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
http.Error(w, appErr.Message, appErr.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to get user information", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate internal JWT token for the user
|
||||
jwtToken, err := h.generateInternalJWTToken(ctx, userInfo)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate internal JWT token", zap.Error(err))
|
||||
http.Error(w, "Failed to generate authentication token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response := CallbackResponse{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
UserInfo: userInfo,
|
||||
JWTToken: jwtToken,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
h.logger.Error("Failed to encode callback response", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("OAuth2 callback processed successfully",
|
||||
zap.String("user_id", userInfo.Sub),
|
||||
zap.String("email", userInfo.Email))
|
||||
}
|
||||
|
||||
// Refresh refreshes an access token using refresh token
|
||||
func (h *OAuth2Handler) Refresh(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
h.logger.Debug("Processing token refresh request")
|
||||
|
||||
var req RefreshRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.logger.Warn("Invalid refresh request", zap.Error(err))
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh the access token
|
||||
tokenResp, err := h.oauth2Provider.RefreshAccessToken(ctx, req.RefreshToken)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to refresh access token", zap.Error(err))
|
||||
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
http.Error(w, appErr.Message, appErr.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to refresh access token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get updated user information
|
||||
userInfo, err := h.oauth2Provider.GetUserInfo(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to get user info during refresh", zap.Error(err))
|
||||
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
http.Error(w, appErr.Message, appErr.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to get user information", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate new internal JWT token
|
||||
jwtToken, err := h.generateInternalJWTToken(ctx, userInfo)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate internal JWT token during refresh", zap.Error(err))
|
||||
http.Error(w, "Failed to generate authentication token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response := RefreshResponse{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
JWTToken: jwtToken,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
h.logger.Error("Failed to encode refresh response", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("Token refresh completed successfully",
|
||||
zap.String("user_id", userInfo.Sub))
|
||||
}
|
||||
|
||||
// GetUserInfo retrieves user information from the current session
|
||||
func (h *OAuth2Handler) GetUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
h.logger.Debug("Processing user info request")
|
||||
|
||||
// Extract JWT token from Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
http.Error(w, "Authorization header required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove "Bearer " prefix
|
||||
tokenString := authHeader
|
||||
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||||
tokenString = authHeader[7:]
|
||||
}
|
||||
|
||||
// Validate JWT token
|
||||
authContext, err := h.authService.ValidateJWTToken(ctx, tokenString)
|
||||
if err != nil {
|
||||
h.logger.Warn("Invalid JWT token in user info request", zap.Error(err))
|
||||
http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Return user information from JWT claims
|
||||
userInfo := map[string]interface{}{
|
||||
"sub": authContext.UserID,
|
||||
"email": authContext.Claims["email"],
|
||||
"name": authContext.Claims["name"],
|
||||
"permissions": authContext.Permissions,
|
||||
"app_id": authContext.AppID,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(userInfo); err != nil {
|
||||
h.logger.Error("Failed to encode user info response", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("User info request completed successfully",
|
||||
zap.String("user_id", authContext.UserID))
|
||||
}
|
||||
|
||||
// generateState generates a random state parameter for OAuth2
|
||||
func (h *OAuth2Handler) generateState() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// generateCodeVerifier generates a PKCE code verifier
|
||||
func (h *OAuth2Handler) generateCodeVerifier() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// generateInternalJWTToken generates an internal JWT token for authenticated users
|
||||
func (h *OAuth2Handler) generateInternalJWTToken(ctx context.Context, userInfo *auth.UserInfo) (string, error) {
|
||||
// Create user token with information from OAuth2 provider
|
||||
userToken := &domain.UserToken{
|
||||
AppID: h.config.GetString("INTERNAL_APP_ID"),
|
||||
UserID: userInfo.Sub,
|
||||
Permissions: []string{"read", "write"}, // Default permissions, should be based on user roles
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour), // 24 hour expiration
|
||||
MaxValidAt: time.Now().Add(7 * 24 * time.Hour), // 7 days max validity
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"sub": userInfo.Sub,
|
||||
"email": userInfo.Email,
|
||||
"name": userInfo.Name,
|
||||
"email_verified": func() string {
|
||||
if userInfo.EmailVerified {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}(),
|
||||
},
|
||||
}
|
||||
|
||||
// Generate JWT token using authentication service
|
||||
return h.authService.GenerateJWTToken(ctx, userToken)
|
||||
}
|
||||
Reference in New Issue
Block a user