353 lines
9.9 KiB
Go
353 lines
9.9 KiB
Go
package handlers
|
|
|
|
import (
	"encoding/json"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gorilla/mux"
	"go.uber.org/zap"

	"github.com/kms/api-key-service/internal/auth"
	"github.com/kms/api-key-service/internal/config"
	"github.com/kms/api-key-service/internal/domain"
	"github.com/kms/api-key-service/internal/errors"
	"github.com/kms/api-key-service/internal/services"
)
|
|
|
|
// SAMLHandler handles SAML authentication endpoints
|
|
type SAMLHandler struct {
|
|
samlProvider *auth.SAMLProvider
|
|
sessionService services.SessionService
|
|
authService services.AuthenticationService
|
|
tokenService services.TokenService
|
|
config config.ConfigProvider
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewSAMLHandler creates a new SAML handler
|
|
func NewSAMLHandler(
|
|
config config.ConfigProvider,
|
|
sessionService services.SessionService,
|
|
authService services.AuthenticationService,
|
|
tokenService services.TokenService,
|
|
logger *zap.Logger,
|
|
) (*SAMLHandler, error) {
|
|
samlProvider, err := auth.NewSAMLProvider(config, logger)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &SAMLHandler{
|
|
samlProvider: samlProvider,
|
|
sessionService: sessionService,
|
|
authService: authService,
|
|
config: config,
|
|
logger: logger,
|
|
}, nil
|
|
}
|
|
|
|
// RegisterRoutes mounts the SAML endpoints on router, in this order:
//
//	GET        /auth/saml/login     -> InitiateSAMLLogin
//	POST       /auth/saml/acs       -> HandleSAMLResponse
//	GET        /auth/saml/metadata  -> GetServiceProviderMetadata
//	GET, POST  /auth/saml/slo       -> HandleSingleLogout
func (h *SAMLHandler) RegisterRoutes(router *mux.Router) {
	endpoints := []struct {
		path    string
		handle  http.HandlerFunc
		methods []string
	}{
		{path: "/auth/saml/login", handle: h.InitiateSAMLLogin, methods: []string{"GET"}},
		{path: "/auth/saml/acs", handle: h.HandleSAMLResponse, methods: []string{"POST"}},
		{path: "/auth/saml/metadata", handle: h.GetServiceProviderMetadata, methods: []string{"GET"}},
		{path: "/auth/saml/slo", handle: h.HandleSingleLogout, methods: []string{"GET", "POST"}},
	}

	for _, ep := range endpoints {
		router.HandleFunc(ep.path, ep.handle).Methods(ep.methods...)
	}
}
|
|
|
|
// InitiateSAMLLogin starts an SP-initiated SAML login (GET /auth/saml/login).
//
// Query parameters:
//   - app_id (required): application the resulting session is bound to. It
//     must not contain '|', which separates fields inside the RelayState.
//   - redirect_url (optional): where HandleSAMLResponse sends the browser
//     after authentication.
//
// The values are packed into the RelayState as "app_id" or
// "app_id|redirect_url", an authentication request is generated by the SAML
// provider, and the client is redirected (302) to the returned IdP URL.
// Failures are reported through writeErrorResponse.
func (h *SAMLHandler) InitiateSAMLLogin(w http.ResponseWriter, r *http.Request) {
	if !h.config.GetBool("SAML_ENABLED") {
		h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
		return
	}

	query := r.URL.Query()
	appID := query.Get("app_id")
	redirectURL := query.Get("redirect_url")

	if appID == "" {
		h.writeErrorResponse(w, errors.NewValidationError("app_id parameter is required"))
		return
	}

	// '|' is the RelayState field separator. An app_id containing it would be
	// split at the ACS into a different app_id and a spurious redirect_url.
	if strings.ContainsRune(appID, '|') {
		h.writeErrorResponse(w, errors.NewValidationError("app_id parameter must not contain '|'"))
		return
	}

	// The RelayState layout must stay in sync with parseRelayState.
	// NOTE(review): the SAML bindings spec limits RelayState to 80 bytes and
	// recommends integrity protection. redirect_url is carried verbatim and
	// unsigned here, so long URLs may be truncated or rejected by the IdP.
	relayState := appID
	if redirectURL != "" {
		relayState += "|" + redirectURL
	}

	h.logger.Debug("Initiating SAML login",
		zap.String("app_id", appID),
		zap.String("redirect_url", redirectURL))

	authURL, requestID, err := h.samlProvider.GenerateAuthRequest(r.Context(), relayState)
	if err != nil {
		h.logger.Error("Failed to generate SAML auth request", zap.Error(err))
		h.writeErrorResponse(w, err)
		return
	}

	// TODO: persist requestID (short-lived, bound to this browser) so the ACS
	// can correlate the IdP response with this request. It is only logged today.
	h.logger.Debug("Generated SAML auth request",
		zap.String("request_id", requestID),
		zap.String("auth_url", authURL))

	http.Redirect(w, r, authURL, http.StatusFound)
}
|
|
|
|
// HandleSAMLResponse handles SAML assertion consumer service (ACS)
|
|
func (h *SAMLHandler) HandleSAMLResponse(w http.ResponseWriter, r *http.Request) {
|
|
if !h.config.GetBool("SAML_ENABLED") {
|
|
h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
|
|
return
|
|
}
|
|
|
|
h.logger.Debug("Handling SAML response")
|
|
|
|
// Parse form data
|
|
if err := r.ParseForm(); err != nil {
|
|
h.writeErrorResponse(w, errors.NewValidationError("Failed to parse form data").WithInternal(err))
|
|
return
|
|
}
|
|
|
|
samlResponse := r.FormValue("SAMLResponse")
|
|
relayState := r.FormValue("RelayState")
|
|
|
|
if samlResponse == "" {
|
|
h.writeErrorResponse(w, errors.NewValidationError("SAMLResponse is required"))
|
|
return
|
|
}
|
|
|
|
h.logger.Debug("Processing SAML response", zap.String("relay_state", relayState))
|
|
|
|
// Process SAML response
|
|
// In production, you should retrieve and validate the original request ID
|
|
authContext, err := h.samlProvider.ProcessSAMLResponse(r.Context(), samlResponse, "")
|
|
if err != nil {
|
|
h.logger.Error("Failed to process SAML response", zap.Error(err))
|
|
h.writeErrorResponse(w, err)
|
|
return
|
|
}
|
|
|
|
// Parse relay state to get app_id and redirect_url
|
|
appID, redirectURL := h.parseRelayState(relayState)
|
|
if appID == "" {
|
|
h.writeErrorResponse(w, errors.NewValidationError("Invalid relay state: missing app_id"))
|
|
return
|
|
}
|
|
|
|
// Create user session
|
|
sessionReq := &domain.CreateSessionRequest{
|
|
UserID: authContext.UserID,
|
|
AppID: appID,
|
|
SessionType: domain.SessionTypeWeb,
|
|
IPAddress: h.getClientIP(r),
|
|
UserAgent: r.UserAgent(),
|
|
ExpiresAt: time.Now().Add(8 * time.Hour), // 8 hour session
|
|
Permissions: authContext.Permissions,
|
|
Claims: authContext.Claims,
|
|
}
|
|
|
|
session, err := h.sessionService.CreateSession(r.Context(), sessionReq)
|
|
if err != nil {
|
|
h.logger.Error("Failed to create session", zap.Error(err))
|
|
h.writeErrorResponse(w, err)
|
|
return
|
|
}
|
|
|
|
// Generate JWT token for the session using the existing token service
|
|
userToken := &domain.UserToken{
|
|
AppID: appID,
|
|
UserID: authContext.UserID,
|
|
Permissions: authContext.Permissions,
|
|
IssuedAt: time.Now(),
|
|
ExpiresAt: session.ExpiresAt,
|
|
MaxValidAt: session.ExpiresAt,
|
|
TokenType: domain.TokenTypeUser,
|
|
Claims: authContext.Claims,
|
|
}
|
|
|
|
tokenString, err := h.authService.GenerateJWTToken(r.Context(), userToken)
|
|
if err != nil {
|
|
h.logger.Error("Failed to create JWT token", zap.Error(err))
|
|
h.writeErrorResponse(w, err)
|
|
return
|
|
}
|
|
|
|
h.logger.Debug("SAML authentication successful",
|
|
zap.String("user_id", authContext.UserID),
|
|
zap.String("session_id", session.ID.String()))
|
|
|
|
// If redirect URL is provided, redirect with token
|
|
if redirectURL != "" {
|
|
// Add token as query parameter or fragment
|
|
redirectURL += "?token=" + tokenString
|
|
http.Redirect(w, r, redirectURL, http.StatusFound)
|
|
return
|
|
}
|
|
|
|
// Otherwise, return JSON response
|
|
response := map[string]interface{}{
|
|
"success": true,
|
|
"token": tokenString,
|
|
"user": map[string]interface{}{
|
|
"id": authContext.UserID,
|
|
"email": authContext.Claims["email"],
|
|
"name": authContext.Claims["name"],
|
|
},
|
|
"session_id": session.ID.String(),
|
|
"expires_at": session.ExpiresAt,
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(response)
|
|
}
|
|
|
|
// GetServiceProviderMetadata serves this service provider's SAML metadata
// document (GET /auth/saml/metadata) with Content-Type application/xml.
//
// The document comes from the SAML provider. A generation failure is reported
// through writeErrorResponse. A failure while writing the body is only logged,
// because the first Write has already committed a 200 status.
func (h *SAMLHandler) GetServiceProviderMetadata(w http.ResponseWriter, r *http.Request) {
	if !h.config.GetBool("SAML_ENABLED") {
		h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
		return
	}

	h.logger.Debug("Generating SP metadata")

	metadata, err := h.samlProvider.GenerateServiceProviderMetadata()
	if err != nil {
		h.logger.Error("Failed to generate SP metadata", zap.Error(err))
		h.writeErrorResponse(w, err)
		return
	}

	w.Header().Set("Content-Type", "application/xml")
	if _, err := w.Write([]byte(metadata)); err != nil {
		h.logger.Error("Failed to write SP metadata", zap.Error(err))
	}
}
|
|
|
|
// HandleSingleLogout is the single-logout endpoint (GET/POST /auth/saml/slo).
//
// It reads an optional session_id from the query string or, for POST
// requests, from the form body. It logs that ID and answers with a JSON
// success message. A POST body that cannot be parsed is rejected with a
// validation error.
//
// NOTE(review): this is a stub. No SAML LogoutRequest is parsed or validated,
// no session is revoked and no LogoutResponse is returned to the IdP, yet the
// response reports "Logout successful". Remaining work:
//  1. Parse the SAML LogoutRequest
//  2. Validate the request
//  3. Revoke the user's sessions
//  4. Generate a LogoutResponse
//  5. Redirect back to the IdP
func (h *SAMLHandler) HandleSingleLogout(w http.ResponseWriter, r *http.Request) {
	if !h.config.GetBool("SAML_ENABLED") {
		h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
		return
	}

	h.logger.Debug("Handling SAML single logout")

	sessionID := r.URL.Query().Get("session_id")
	if sessionID == "" && r.Method == http.MethodPost {
		// Previously a malformed body was silently ignored; surface it instead.
		if err := r.ParseForm(); err != nil {
			h.writeErrorResponse(w, errors.NewValidationError("Failed to parse form data").WithInternal(err))
			return
		}
		sessionID = r.FormValue("session_id")
	}

	if sessionID != "" {
		// TODO: revoke this session through h.sessionService once lookup by
		// session ID is defined; for now it is only logged.
		h.logger.Debug("Revoking session", zap.String("session_id", sessionID))
	}

	response := map[string]interface{}{
		"success": true,
		"message": "Logout successful",
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		h.logger.Error("Failed to encode logout response", zap.Error(err))
	}
}
|
|
|
|
// parseRelayState splits a RelayState produced by InitiateSAMLLogin into its
// app_id and optional redirect_url.
//
// The format is "app_id" or "app_id|redirect_url". Only the first '|' acts as
// a separator, so the redirect URL may itself contain '|'.
//
// An empty input, or one whose app_id segment is empty (e.g. "|https://x"),
// yields an empty appID, which callers treat as invalid. Previously a leading
// '|' suppressed the split, so the whole string was returned as a bogus,
// non-empty app_id.
func (h *SAMLHandler) parseRelayState(relayState string) (appID, redirectURL string) {
	appID, redirectURL, _ = strings.Cut(relayState, "|")
	return appID, redirectURL
}
|
|
|
|
// getClientIP returns a best-effort client IP for r, used for session
// bookkeeping.
//
// Sources are tried in this order:
//  1. The first entry of X-Forwarded-For.
//  2. X-Real-IP.
//  3. The host part of r.RemoteAddr.
//
// Header values are whitespace-trimmed, and a blank value falls through to
// the next source. If RemoteAddr has no port, it is returned unchanged.
//
// NOTE(review): both headers are client-supplied unless a trusted proxy
// overwrites them, so the result must not be used for authorization.
func (h *SAMLHandler) getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	// RemoteAddr is normally "host:port"; strip the port so all three sources
	// return the same shape.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|
|
|
|
// writeErrorResponse writes err as a JSON error envelope:
//
//	{"success": false, "error": {"code": "...", "message": "..."}}
//
// Errors classified by the internal errors package are mapped as follows, and
// err.Error() is echoed as the message:
//   - validation errors: 400
//   - authentication errors: 401
//   - configuration errors: 503
//
// Any other error becomes a 500 with a generic message, so that internal
// details (driver or library text) do not reach the client. Callers in this
// handler log such errors before calling. A failure to encode the body is
// logged, since the status has already been written.
func (h *SAMLHandler) writeErrorResponse(w http.ResponseWriter, err error) {
	var statusCode int
	var errorCode string
	message := err.Error()

	switch {
	case errors.IsValidationError(err):
		statusCode = http.StatusBadRequest
		errorCode = "VALIDATION_ERROR"
	case errors.IsAuthenticationError(err):
		statusCode = http.StatusUnauthorized
		errorCode = "AUTHENTICATION_ERROR"
	case errors.IsConfigurationError(err):
		statusCode = http.StatusServiceUnavailable
		errorCode = "CONFIGURATION_ERROR"
	default:
		statusCode = http.StatusInternalServerError
		errorCode = "INTERNAL_ERROR"
		// Unclassified errors may carry sensitive internals; do not echo them.
		message = "Internal server error"
	}

	response := map[string]interface{}{
		"success": false,
		"error": map[string]interface{}{
			"code":    errorCode,
			"message": message,
		},
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	if encErr := json.NewEncoder(w).Encode(response); encErr != nil {
		h.logger.Error("Failed to encode error response", zap.Error(encErr))
	}
}
|