package middleware import ( "crypto/hmac" "crypto/rand" "crypto/sha256" "encoding/hex" "net/http" "strconv" "strings" "time" "go.uber.org/zap" "github.com/RyanCopley/skybridge/kms/internal/config" ) // CSRFMiddleware provides CSRF protection type CSRFMiddleware struct { config config.ConfigProvider logger *zap.Logger } // NewCSRFMiddleware creates a new CSRF middleware func NewCSRFMiddleware(config config.ConfigProvider, logger *zap.Logger) *CSRFMiddleware { return &CSRFMiddleware{ config: config, logger: logger, } } // CSRFProtection implements CSRF protection for state-changing operations func (cm *CSRFMiddleware) CSRFProtection(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip CSRF protection for safe methods if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" { next.ServeHTTP(w, r) return } // Skip CSRF protection for specific endpoints that use other authentication if cm.shouldSkipCSRF(r) { next.ServeHTTP(w, r) return } // Get CSRF token from header csrfToken := r.Header.Get("X-CSRF-Token") if csrfToken == "" { cm.logger.Warn("Missing CSRF token", zap.String("path", r.URL.Path), zap.String("method", r.Method), zap.String("remote_addr", r.RemoteAddr)) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusForbidden) w.Write([]byte(`{"error":"csrf_token_missing","message":"CSRF token required"}`)) return } // Validate CSRF token if !cm.validateCSRFToken(csrfToken, r) { cm.logger.Warn("Invalid CSRF token", zap.String("path", r.URL.Path), zap.String("method", r.Method), zap.String("remote_addr", r.RemoteAddr)) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusForbidden) w.Write([]byte(`{"error":"csrf_token_invalid","message":"Invalid CSRF token"}`)) return } cm.logger.Debug("CSRF token validated successfully", zap.String("path", r.URL.Path)) next.ServeHTTP(w, r) }) } // GenerateCSRFToken generates a new CSRF token for a user session func (cm *CSRFMiddleware) GenerateCSRFToken(userID string) (string, error) { // Generate random bytes for token tokenBytes := make([]byte, 32) if _, err := rand.Read(tokenBytes); err != nil { cm.logger.Error("Failed to generate CSRF token", zap.Error(err)) return "", err } // Create timestamp timestamp := time.Now().Unix() // Create token data tokenData := hex.EncodeToString(tokenBytes) // Create signing string: userID:timestamp:tokenData timestampStr := strconv.FormatInt(timestamp, 10) signingString := userID + ":" + timestampStr + ":" + tokenData // Sign the token with HMAC signature := cm.signData(signingString) // Return encoded token: tokenData.timestamp.signature token := tokenData + "." + timestampStr + "." + signature return token, nil } // validateCSRFToken validates a CSRF token func (cm *CSRFMiddleware) validateCSRFToken(token string, r *http.Request) bool { // Parse token parts parts := strings.Split(token, ".") if len(parts) != 3 { cm.logger.Debug("Invalid CSRF token format") return false } tokenData, timestampStr, signature := parts[0], parts[1], parts[2] // Get user ID from request context or headers userID := cm.getUserIDFromRequest(r) if userID == "" { cm.logger.Debug("No user ID found for CSRF validation") return false } // Recreate signing string signingString := userID + ":" + timestampStr + ":" + tokenData // Verify signature expectedSignature := cm.signData(signingString) if !hmac.Equal([]byte(signature), []byte(expectedSignature)) { cm.logger.Debug("CSRF token signature verification failed") return false } // Parse timestamp timestampInt, err := strconv.ParseInt(timestampStr, 10, 64) if err != nil { cm.logger.Debug("Invalid timestamp in CSRF token", zap.Error(err)) return false } timestamp := time.Unix(timestampInt, 0) // Check if token is expired (valid for 1 hour by default) maxAge := cm.config.GetDuration("CSRF_TOKEN_MAX_AGE") if maxAge <= 0 { maxAge = 1 * time.Hour } if time.Since(timestamp) > maxAge { cm.logger.Debug("CSRF token expired", zap.Time("timestamp", timestamp), zap.Duration("age", time.Since(timestamp)), zap.Duration("max_age", maxAge)) return false } return true } // signData signs data with HMAC func (cm *CSRFMiddleware) signData(data string) string { // Use the same signing key as for authentication signingKey := cm.config.GetString("AUTH_SIGNING_KEY") if signingKey == "" { cm.logger.Error("AUTH_SIGNING_KEY not configured for CSRF protection") return "" } mac := hmac.New(sha256.New, []byte(signingKey)) mac.Write([]byte(data)) return hex.EncodeToString(mac.Sum(nil)) } // getUserIDFromRequest extracts user ID from request func (cm *CSRFMiddleware) getUserIDFromRequest(r *http.Request) string { // Try to get from X-User-Email header userEmail := r.Header.Get(cm.config.GetString("AUTH_HEADER_USER_EMAIL")) if userEmail != "" { return userEmail } // Try to get from context (if set by authentication middleware) if userID := r.Context().Value("user_id"); userID != nil { if id, ok := userID.(string); ok { return id } } return "" } // shouldSkipCSRF determines if CSRF protection should be skipped for a request func (cm *CSRFMiddleware) shouldSkipCSRF(r *http.Request) bool { // Skip for API endpoints that use API key authentication if strings.HasPrefix(r.URL.Path, "/api/verify") { return true } // Skip for health check endpoints if r.URL.Path == "/health" || r.URL.Path == "/ready" { return true } // Skip for webhook endpoints (if any) if strings.HasPrefix(r.URL.Path, "/webhook/") { return true } return false } // SetCSRFCookie sets a secure CSRF token cookie func (cm *CSRFMiddleware) SetCSRFCookie(w http.ResponseWriter, token string) { cookie := &http.Cookie{ Name: "csrf_token", Value: token, Path: "/", MaxAge: 3600, // 1 hour HttpOnly: false, // JavaScript needs to read this for AJAX requests Secure: true, // HTTPS only SameSite: http.SameSiteStrictMode, } http.SetCookie(w, cookie) } // GetCSRFTokenFromCookie gets CSRF token from cookie func (cm *CSRFMiddleware) GetCSRFTokenFromCookie(r *http.Request) string { cookie, err := r.Cookie("csrf_token") if err != nil { return "" } return cookie.Value }