package handlers

import (
	"encoding/json"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gorilla/mux"
	"go.uber.org/zap"

	"github.com/kms/api-key-service/internal/auth"
	"github.com/kms/api-key-service/internal/config"
	"github.com/kms/api-key-service/internal/domain"
	"github.com/kms/api-key-service/internal/errors"
	"github.com/kms/api-key-service/internal/services"
)

// SAMLHandler handles SAML authentication endpoints.
type SAMLHandler struct {
	samlProvider   *auth.SAMLProvider
	sessionService services.SessionService
	authService    services.AuthenticationService
	tokenService   services.TokenService
	config         config.ConfigProvider
	logger         *zap.Logger
}

// NewSAMLHandler creates a new SAML handler.
//
// It builds the SAML provider from cfg and returns the provider's error
// unchanged if construction fails. All supplied services are stored on the
// handler, including tokenService.
func NewSAMLHandler(
	cfg config.ConfigProvider,
	sessionService services.SessionService,
	authService services.AuthenticationService,
	tokenService services.TokenService,
	logger *zap.Logger,
) (*SAMLHandler, error) {
	samlProvider, err := auth.NewSAMLProvider(cfg, logger)
	if err != nil {
		return nil, err
	}

	return &SAMLHandler{
		samlProvider:   samlProvider,
		sessionService: sessionService,
		authService:    authService,
		tokenService:   tokenService,
		config:         cfg,
		logger:         logger,
	}, nil
}

// RegisterRoutes registers the SAML routes on router.
func (h *SAMLHandler) RegisterRoutes(router *mux.Router) {
	router.HandleFunc("/auth/saml/login", h.InitiateSAMLLogin).Methods(http.MethodGet)
	router.HandleFunc("/auth/saml/acs", h.HandleSAMLResponse).Methods(http.MethodPost)
	router.HandleFunc("/auth/saml/metadata", h.GetServiceProviderMetadata).Methods(http.MethodGet)
	router.HandleFunc("/auth/saml/slo", h.HandleSingleLogout).Methods(http.MethodGet, http.MethodPost)
}

// InitiateSAMLLogin initiates SAML authentication.
//
// It requires the "app_id" query parameter and accepts an optional
// "redirect_url". Both are packed into the relay state as
// "app_id|redirect_url" (or just "app_id"), an authentication request is
// generated by the SAML provider, and the client is redirected (302) to the
// returned URL.
func (h *SAMLHandler) InitiateSAMLLogin(w http.ResponseWriter, r *http.Request) {
	if !h.requireSAMLEnabled(w) {
		return
	}

	params := r.URL.Query()
	appID := params.Get("app_id")
	redirectURL := params.Get("redirect_url")

	if appID == "" {
		h.writeErrorResponse(w, errors.NewValidationError("app_id parameter is required"))
		return
	}
	// '|' is the relay-state separator; an app_id containing it would be
	// split in the wrong place when the relay state is parsed on the way back.
	if strings.Contains(appID, "|") {
		h.writeErrorResponse(w, errors.NewValidationError("app_id parameter must not contain '|'"))
		return
	}

	relayState := appID
	if redirectURL != "" {
		relayState += "|" + redirectURL
	}

	h.logger.Debug("Initiating SAML login",
		zap.String("app_id", appID),
		zap.String("redirect_url", redirectURL))

	authURL, requestID, err := h.samlProvider.GenerateAuthRequest(r.Context(), relayState)
	if err != nil {
		h.logger.Error("Failed to generate SAML auth request", zap.Error(err))
		h.writeErrorResponse(w, err)
		return
	}

	// NOTE(review): requestID is only logged here. It should be persisted
	// (session/cache) so the ACS handler can validate the response against it.
	h.logger.Debug("Generated SAML auth request",
		zap.String("request_id", requestID),
		zap.String("auth_url", authURL))

	http.Redirect(w, r, authURL, http.StatusFound)
}

// HandleSAMLResponse handles the SAML assertion consumer service (ACS).
//
// It reads "SAMLResponse" and "RelayState" from the form, has the SAML
// provider process the response, creates an 8-hour web session, and issues a
// JWT for it. If the relay state carries a redirect URL the client is
// redirected (302) there with the token in the "token" query parameter;
// otherwise a JSON document with the token, user and session data is returned.
func (h *SAMLHandler) HandleSAMLResponse(w http.ResponseWriter, r *http.Request) {
	const sessionLifetime = 8 * time.Hour

	if !h.requireSAMLEnabled(w) {
		return
	}

	h.logger.Debug("Handling SAML response")

	if err := r.ParseForm(); err != nil {
		h.writeErrorResponse(w, errors.NewValidationError("Failed to parse form data").WithInternal(err))
		return
	}

	samlResponse := r.FormValue("SAMLResponse")
	relayState := r.FormValue("RelayState")

	if samlResponse == "" {
		h.writeErrorResponse(w, errors.NewValidationError("SAMLResponse is required"))
		return
	}

	h.logger.Debug("Processing SAML response", zap.String("relay_state", relayState))

	// NOTE(review): the original request ID is not available here, so an
	// empty value is passed; it should be retrieved and validated once
	// InitiateSAMLLogin stores it.
	authContext, err := h.samlProvider.ProcessSAMLResponse(r.Context(), samlResponse, "")
	if err != nil {
		h.logger.Error("Failed to process SAML response", zap.Error(err))
		h.writeErrorResponse(w, err)
		return
	}

	appID, redirectURL := h.parseRelayState(relayState)
	if appID == "" {
		h.writeErrorResponse(w, errors.NewValidationError("Invalid relay state: missing app_id"))
		return
	}

	// Parse the redirect target before creating any session so a malformed
	// URL is rejected without side effects.
	// NOTE(review): RelayState is client-controlled, so this is an open
	// redirect that also carries the token; consider checking the target
	// against an allow-list of hosts for the application.
	var redirectTarget *url.URL
	if redirectURL != "" {
		redirectTarget, err = url.Parse(redirectURL)
		if err != nil {
			h.writeErrorResponse(w, errors.NewValidationError("Invalid relay state: malformed redirect_url").WithInternal(err))
			return
		}
	}

	sessionReq := &domain.CreateSessionRequest{
		UserID:      authContext.UserID,
		AppID:       appID,
		SessionType: domain.SessionTypeWeb,
		IPAddress:   h.getClientIP(r),
		UserAgent:   r.UserAgent(),
		ExpiresAt:   time.Now().Add(sessionLifetime),
		Permissions: authContext.Permissions,
		Claims:      authContext.Claims,
	}

	session, err := h.sessionService.CreateSession(r.Context(), sessionReq)
	if err != nil {
		h.logger.Error("Failed to create session", zap.Error(err))
		h.writeErrorResponse(w, err)
		return
	}

	// The token shares the session's expiry so both become invalid together.
	userToken := &domain.UserToken{
		AppID:       appID,
		UserID:      authContext.UserID,
		Permissions: authContext.Permissions,
		IssuedAt:    time.Now(),
		ExpiresAt:   session.ExpiresAt,
		MaxValidAt:  session.ExpiresAt,
		TokenType:   domain.TokenTypeUser,
		Claims:      authContext.Claims,
	}

	tokenString, err := h.authService.GenerateJWTToken(r.Context(), userToken)
	if err != nil {
		h.logger.Error("Failed to create JWT token", zap.Error(err))
		h.writeErrorResponse(w, err)
		return
	}

	h.logger.Debug("SAML authentication successful",
		zap.String("user_id", authContext.UserID),
		zap.String("session_id", session.ID.String()))

	if redirectTarget != nil {
		// Merge the token into any existing query string instead of blindly
		// appending "?token=", which breaks URLs that already have a query
		// or a fragment.
		query := redirectTarget.Query()
		query.Set("token", tokenString)
		redirectTarget.RawQuery = query.Encode()
		http.Redirect(w, r, redirectTarget.String(), http.StatusFound)
		return
	}

	h.writeJSON(w, http.StatusOK, map[string]interface{}{
		"success": true,
		"token":   tokenString,
		"user": map[string]interface{}{
			"id":    authContext.UserID,
			"email": authContext.Claims["email"],
			"name":  authContext.Claims["name"],
		},
		"session_id": session.ID.String(),
		"expires_at": session.ExpiresAt,
	})
}

// GetServiceProviderMetadata returns the SP metadata XML produced by the
// SAML provider with Content-Type "application/xml".
func (h *SAMLHandler) GetServiceProviderMetadata(w http.ResponseWriter, r *http.Request) {
	if !h.requireSAMLEnabled(w) {
		return
	}

	h.logger.Debug("Generating SP metadata")

	metadata, err := h.samlProvider.GenerateServiceProviderMetadata()
	if err != nil {
		h.logger.Error("Failed to generate SP metadata", zap.Error(err))
		h.writeErrorResponse(w, err)
		return
	}

	w.Header().Set("Content-Type", "application/xml")
	if _, err := w.Write([]byte(metadata)); err != nil {
		h.logger.Error("Failed to write SP metadata", zap.Error(err))
	}
}

// HandleSingleLogout handles SAML single logout.
//
// The session ID is read from the "session_id" query parameter or, for POST
// requests, from the form body.
//
// NOTE(review): this is a stub. It logs the session ID and reports success
// but does not revoke anything. A full implementation would parse and
// validate the SAML LogoutRequest, revoke the user's sessions, generate a
// LogoutResponse and redirect back to the IdP.
func (h *SAMLHandler) HandleSingleLogout(w http.ResponseWriter, r *http.Request) {
	if !h.requireSAMLEnabled(w) {
		return
	}

	h.logger.Debug("Handling SAML single logout")

	sessionID := r.URL.Query().Get("session_id")
	if sessionID == "" && r.Method == http.MethodPost {
		if err := r.ParseForm(); err != nil {
			h.writeErrorResponse(w, errors.NewValidationError("Failed to parse form data").WithInternal(err))
			return
		}
		sessionID = r.FormValue("session_id")
	}

	if sessionID != "" {
		// Revocation is not implemented yet; the ID is only logged.
		h.logger.Debug("Revoking session", zap.String("session_id", sessionID))
	}

	h.writeJSON(w, http.StatusOK, map[string]interface{}{
		"success": true,
		"message": "Logout successful",
	})
}

// requireSAMLEnabled reports whether the "SAML_ENABLED" config flag is set.
// When it is not, a configuration error response is written to w and false
// is returned, so callers should simply return.
func (h *SAMLHandler) requireSAMLEnabled(w http.ResponseWriter) bool {
	if h.config.GetBool("SAML_ENABLED") {
		return true
	}
	h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
	return false
}

// parseRelayState parses the relay state to extract app_id and redirect_url.
//
// The format is "app_id|redirect_url" or just "app_id"; the string is split
// on the first '|' only, so the redirect URL may itself contain '|'. A relay
// state that is empty or starts with '|' yields an empty appID.
func (h *SAMLHandler) parseRelayState(relayState string) (appID, redirectURL string) {
	parts := strings.SplitN(relayState, "|", 2)
	appID = parts[0]
	if len(parts) == 2 {
		redirectURL = parts[1]
	}
	return appID, redirectURL
}

// getClientIP extracts the client IP address from the request.
//
// Sources are consulted in order: the first entry of X-Forwarded-For, then
// X-Real-IP, then the host part of RemoteAddr (the raw RemoteAddr is returned
// if it cannot be split into host and port).
//
// NOTE(review): both headers are client-supplied and can be spoofed unless a
// trusted proxy in front of this service overwrites them.
func (h *SAMLHandler) getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := xff
		if i := strings.IndexByte(xff, ','); i >= 0 {
			first = xff[:i]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	// RemoteAddr is "host:port"; only the host is an IP address.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}

// writeErrorResponse writes err as a JSON error response.
//
// Validation errors map to 400, authentication errors to 401, configuration
// errors to 503, and everything else to 500.
//
// NOTE(review): err.Error() is sent to the client for every category,
// including internal errors; confirm that it never carries sensitive detail.
func (h *SAMLHandler) writeErrorResponse(w http.ResponseWriter, err error) {
	var (
		statusCode int
		errorCode  string
	)

	switch {
	case errors.IsValidationError(err):
		statusCode, errorCode = http.StatusBadRequest, "VALIDATION_ERROR"
	case errors.IsAuthenticationError(err):
		statusCode, errorCode = http.StatusUnauthorized, "AUTHENTICATION_ERROR"
	case errors.IsConfigurationError(err):
		statusCode, errorCode = http.StatusServiceUnavailable, "CONFIGURATION_ERROR"
	default:
		statusCode, errorCode = http.StatusInternalServerError, "INTERNAL_ERROR"
	}

	h.writeJSON(w, statusCode, map[string]interface{}{
		"success": false,
		"error": map[string]interface{}{
			"code":    errorCode,
			"message": err.Error(),
		},
	})
}

// writeJSON sends payload as a JSON body with the given status code. An
// encoding or write failure is logged, since the status line has already
// been sent by then.
func (h *SAMLHandler) writeJSON(w http.ResponseWriter, statusCode int, payload interface{}) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	if err := json.NewEncoder(w).Encode(payload); err != nil {
		h.logger.Error("Failed to write JSON response", zap.Error(err))
	}
}