This commit is contained in:
2025-08-23 22:31:47 -04:00
parent 9ca9c53baf
commit e5bccc85c2
22 changed files with 2405 additions and 209 deletions

View File

@ -0,0 +1,171 @@
package auth
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/errors"
)
// HeaderValidator provides secure validation of authentication headers
type HeaderValidator struct {
config config.ConfigProvider
logger *zap.Logger
}
// NewHeaderValidator creates a new header validator
func NewHeaderValidator(config config.ConfigProvider, logger *zap.Logger) *HeaderValidator {
return &HeaderValidator{
config: config,
logger: logger,
}
}
// ValidatedUserContext holds validated user information
type ValidatedUserContext struct {
UserID string
Email string
Timestamp time.Time
Signature string
}
// ValidateAuthenticationHeaders validates user authentication headers with HMAC signature
func (hv *HeaderValidator) ValidateAuthenticationHeaders(r *http.Request) (*ValidatedUserContext, error) {
userEmail := r.Header.Get(hv.config.GetString("AUTH_HEADER_USER_EMAIL"))
timestamp := r.Header.Get("X-Auth-Timestamp")
signature := r.Header.Get("X-Auth-Signature")
if userEmail == "" {
hv.logger.Warn("Missing user email header")
return nil, errors.NewAuthenticationError("User authentication required")
}
if timestamp == "" || signature == "" {
hv.logger.Warn("Missing authentication signature headers",
zap.String("user_email", userEmail))
return nil, errors.NewAuthenticationError("Authentication signature required")
}
// Validate timestamp (prevent replay attacks)
timestampInt, err := strconv.ParseInt(timestamp, 10, 64)
if err != nil {
hv.logger.Warn("Invalid timestamp format",
zap.String("timestamp", timestamp),
zap.String("user_email", userEmail))
return nil, errors.NewAuthenticationError("Invalid timestamp format")
}
timestampTime := time.Unix(timestampInt, 0)
now := time.Now()
// Allow 5 minutes clock skew
maxAge := 5 * time.Minute
if now.Sub(timestampTime) > maxAge || timestampTime.After(now.Add(1*time.Minute)) {
hv.logger.Warn("Timestamp outside acceptable window",
zap.Time("timestamp", timestampTime),
zap.Time("now", now),
zap.String("user_email", userEmail))
return nil, errors.NewAuthenticationError("Request timestamp outside acceptable window")
}
// Validate HMAC signature
if !hv.validateSignature(userEmail, timestamp, signature) {
hv.logger.Warn("Invalid authentication signature",
zap.String("user_email", userEmail))
return nil, errors.NewAuthenticationError("Invalid authentication signature")
}
// Validate email format
if !hv.isValidEmail(userEmail) {
hv.logger.Warn("Invalid email format",
zap.String("user_email", userEmail))
return nil, errors.NewAuthenticationError("Invalid email format")
}
hv.logger.Debug("Authentication headers validated successfully",
zap.String("user_email", userEmail))
return &ValidatedUserContext{
UserID: userEmail,
Email: userEmail,
Timestamp: timestampTime,
Signature: signature,
}, nil
}
// validateSignature validates the HMAC signature
func (hv *HeaderValidator) validateSignature(userEmail, timestamp, signature string) bool {
// Get the signing key from config
signingKey := hv.config.GetString("AUTH_SIGNING_KEY")
if signingKey == "" {
hv.logger.Error("AUTH_SIGNING_KEY not configured")
return false
}
// Create the signing string
signingString := fmt.Sprintf("%s:%s", userEmail, timestamp)
// Calculate expected signature
mac := hmac.New(sha256.New, []byte(signingKey))
mac.Write([]byte(signingString))
expectedSignature := hex.EncodeToString(mac.Sum(nil))
// Use constant-time comparison to prevent timing attacks
return hmac.Equal([]byte(signature), []byte(expectedSignature))
}
// isValidEmail performs basic email validation
func (hv *HeaderValidator) isValidEmail(email string) bool {
if len(email) == 0 || len(email) > 254 {
return false
}
// Basic email validation - contains @ and has valid structure
parts := strings.Split(email, "@")
if len(parts) != 2 {
return false
}
local, domain := parts[0], parts[1]
// Local part validation
if len(local) == 0 || len(local) > 64 {
return false
}
// Domain part validation
if len(domain) == 0 || len(domain) > 253 {
return false
}
if !strings.Contains(domain, ".") {
return false
}
// Check for invalid characters (basic check)
invalidChars := []string{" ", "..", "@@", "<", ">", "\"", "'"}
for _, char := range invalidChars {
if strings.Contains(email, char) {
return false
}
}
return true
}
// GenerateSignatureExample generates an example signature for documentation
func (hv *HeaderValidator) GenerateSignatureExample(userEmail string, timestamp string, signingKey string) string {
signingString := fmt.Sprintf("%s:%s", userEmail, timestamp)
mac := hmac.New(sha256.New, []byte(signingKey))
mac.Write([]byte(signingString))
return hex.EncodeToString(mac.Sum(nil))
}

View File

@ -57,6 +57,12 @@ func (j *JWTManager) GenerateToken(userToken *domain.UserToken) (string, error)
return "", errors.NewValidationError("JWT secret not configured")
}
// Generate secure JWT ID
jti := j.generateJTI()
if jti == "" {
return "", errors.NewInternalError("Failed to generate secure JWT ID - cryptographic random number generation failed")
}
// Create custom claims
claims := CustomClaims{
UserID: userToken.UserID,
@ -72,7 +78,7 @@ func (j *JWTManager) GenerateToken(userToken *domain.UserToken) (string, error)
ExpiresAt: jwt.NewNumericDate(userToken.ExpiresAt),
IssuedAt: jwt.NewNumericDate(userToken.IssuedAt),
NotBefore: jwt.NewNumericDate(userToken.IssuedAt),
ID: j.generateJTI(),
ID: jti,
},
}
@ -272,8 +278,10 @@ func (j *JWTManager) IsTokenRevoked(tokenString string) (bool, error) {
func (j *JWTManager) generateJTI() string {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
// Fallback to timestamp-based ID if random generation fails
return fmt.Sprintf("jti_%d", time.Now().UnixNano())
// Log the error and fail securely - do not generate predictable fallback IDs
j.logger.Error("Cryptographic random number generation failed - cannot generate secure JWT ID", zap.Error(err))
// Return an error indicator that will cause token generation to fail
return ""
}
return base64.URLEncoding.EncodeToString(bytes)
}

View File

@ -0,0 +1,353 @@
package authorization
import (
"context"
"fmt"
"strings"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
)
// ResourceType represents different types of resources
type ResourceType string
const (
ResourceTypeApplication ResourceType = "application"
ResourceTypeToken ResourceType = "token"
ResourceTypePermission ResourceType = "permission"
ResourceTypeUser ResourceType = "user"
)
// Action represents different actions that can be performed
type Action string
const (
ActionRead Action = "read"
ActionWrite Action = "write"
ActionDelete Action = "delete"
ActionCreate Action = "create"
)
// AuthorizationContext holds context for authorization decisions
type AuthorizationContext struct {
UserID string
UserEmail string
ResourceType ResourceType
ResourceID string
Action Action
OwnerInfo *domain.Owner
}
// AuthorizationService provides role-based access control
type AuthorizationService struct {
logger *zap.Logger
}
// NewAuthorizationService creates a new authorization service
func NewAuthorizationService(logger *zap.Logger) *AuthorizationService {
return &AuthorizationService{
logger: logger,
}
}
// AuthorizeResourceAccess checks if a user can perform an action on a resource
func (a *AuthorizationService) AuthorizeResourceAccess(ctx context.Context, authCtx *AuthorizationContext) error {
if authCtx == nil {
return errors.NewForbiddenError("Authorization context is required")
}
a.logger.Debug("Authorizing resource access",
zap.String("user_id", authCtx.UserID),
zap.String("resource_type", string(authCtx.ResourceType)),
zap.String("resource_id", authCtx.ResourceID),
zap.String("action", string(authCtx.Action)))
// Check if user is a system admin
if a.isSystemAdmin(authCtx.UserID) {
a.logger.Debug("System admin access granted", zap.String("user_id", authCtx.UserID))
return nil
}
// Check resource ownership
if authCtx.OwnerInfo != nil {
if a.isResourceOwner(authCtx, authCtx.OwnerInfo) {
a.logger.Debug("Resource owner access granted",
zap.String("user_id", authCtx.UserID),
zap.String("resource_id", authCtx.ResourceID))
return nil
}
}
// Check specific resource-action combinations
switch authCtx.ResourceType {
case ResourceTypeApplication:
return a.authorizeApplicationAccess(authCtx)
case ResourceTypeToken:
return a.authorizeTokenAccess(authCtx)
case ResourceTypePermission:
return a.authorizePermissionAccess(authCtx)
case ResourceTypeUser:
return a.authorizeUserAccess(authCtx)
default:
return errors.NewForbiddenError(fmt.Sprintf("Unknown resource type: %s", authCtx.ResourceType))
}
}
// AuthorizeApplicationOwnership checks if a user owns an application
func (a *AuthorizationService) AuthorizeApplicationOwnership(userID string, app *domain.Application) error {
if app == nil {
return errors.NewValidationError("Application is required")
}
// System admins can access any application
if a.isSystemAdmin(userID) {
return nil
}
// Check if user is the owner
if a.isOwner(userID, &app.Owner) {
return nil
}
a.logger.Warn("Application ownership authorization failed",
zap.String("user_id", userID),
zap.String("app_id", app.AppID),
zap.String("owner_type", string(app.Owner.Type)),
zap.String("owner_name", app.Owner.Name))
return errors.NewForbiddenError("You do not have permission to access this application")
}
// AuthorizeTokenOwnership checks if a user owns a token
func (a *AuthorizationService) AuthorizeTokenOwnership(userID string, token interface{}) error {
// System admins can access any token
if a.isSystemAdmin(userID) {
return nil
}
// Extract owner information based on token type
var owner *domain.Owner
var tokenID string
switch t := token.(type) {
case *domain.StaticToken:
owner = &t.Owner
tokenID = t.ID.String()
case *domain.UserToken:
// For user tokens, the user ID should match
if t.UserID == userID {
return nil
}
tokenID = "user_token"
default:
return errors.NewValidationError("Unknown token type")
}
// Check ownership
if owner != nil && a.isOwner(userID, owner) {
return nil
}
a.logger.Warn("Token ownership authorization failed",
zap.String("user_id", userID),
zap.String("token_id", tokenID))
return errors.NewForbiddenError("You do not have permission to access this token")
}
// isSystemAdmin checks if a user is a system administrator
func (a *AuthorizationService) isSystemAdmin(userID string) bool {
// System admin users - this should be configurable
systemAdmins := []string{
"admin@example.com",
"system@internal.com",
}
for _, admin := range systemAdmins {
if userID == admin {
return true
}
}
return false
}
// isResourceOwner checks if the user is the owner of a resource
func (a *AuthorizationService) isResourceOwner(authCtx *AuthorizationContext, owner *domain.Owner) bool {
return a.isOwner(authCtx.UserID, owner)
}
// isOwner checks if a user is the owner based on owner information
func (a *AuthorizationService) isOwner(userID string, owner *domain.Owner) bool {
switch owner.Type {
case domain.OwnerTypeIndividual:
// For individual ownership, check if the user ID matches the owner name
return userID == owner.Name || userID == owner.Owner
case domain.OwnerTypeTeam:
// For team ownership, this would typically require a team membership check
// For now, we'll check if the user is the team owner
return userID == owner.Owner || a.isTeamMember(userID, owner.Name)
default:
return false
}
}
// isTeamMember checks if a user is a member of a team (placeholder implementation)
func (a *AuthorizationService) isTeamMember(userID, teamName string) bool {
// In a real implementation, this would check team membership in a database
// For now, we'll use a simple heuristic based on email domains
if !strings.Contains(userID, "@") {
return false
}
userDomain := strings.Split(userID, "@")[1]
teamDomain := strings.ToLower(teamName)
// Simple check: if team name looks like a domain and user's domain matches
if strings.Contains(teamDomain, ".") && strings.Contains(userDomain, teamDomain) {
return true
}
// Additional team membership logic would go here
return false
}
// authorizeApplicationAccess handles application-specific authorization
func (a *AuthorizationService) authorizeApplicationAccess(authCtx *AuthorizationContext) error {
switch authCtx.Action {
case ActionRead:
// Users can read applications they have some relationship with
// This could be expanded to check for shared access, etc.
return errors.NewForbiddenError("You do not have permission to read this application")
case ActionWrite:
// Only owners can modify applications
return errors.NewForbiddenError("You do not have permission to modify this application")
case ActionDelete:
// Only owners can delete applications
return errors.NewForbiddenError("You do not have permission to delete this application")
case ActionCreate:
// Most users can create applications (with rate limiting)
return nil
default:
return errors.NewForbiddenError(fmt.Sprintf("Unknown action: %s", authCtx.Action))
}
}
// authorizeTokenAccess handles token-specific authorization
func (a *AuthorizationService) authorizeTokenAccess(authCtx *AuthorizationContext) error {
switch authCtx.Action {
case ActionRead:
return errors.NewForbiddenError("You do not have permission to read this token")
case ActionWrite:
return errors.NewForbiddenError("You do not have permission to modify this token")
case ActionDelete:
return errors.NewForbiddenError("You do not have permission to delete this token")
case ActionCreate:
return errors.NewForbiddenError("You do not have permission to create tokens for this application")
default:
return errors.NewForbiddenError(fmt.Sprintf("Unknown action: %s", authCtx.Action))
}
}
// authorizePermissionAccess handles permission-specific authorization
func (a *AuthorizationService) authorizePermissionAccess(authCtx *AuthorizationContext) error {
switch authCtx.Action {
case ActionRead:
// Users can read permissions they have
return nil
case ActionWrite:
// Only admins can modify permissions
return errors.NewForbiddenError("You do not have permission to modify permissions")
case ActionDelete:
// Only admins can delete permissions
return errors.NewForbiddenError("You do not have permission to delete permissions")
case ActionCreate:
// Only admins can create permissions
return errors.NewForbiddenError("You do not have permission to create permissions")
default:
return errors.NewForbiddenError(fmt.Sprintf("Unknown action: %s", authCtx.Action))
}
}
// authorizeUserAccess handles user-specific authorization
func (a *AuthorizationService) authorizeUserAccess(authCtx *AuthorizationContext) error {
switch authCtx.Action {
case ActionRead:
// Users can read their own information
if authCtx.ResourceID == authCtx.UserID {
return nil
}
return errors.NewForbiddenError("You do not have permission to read this user's information")
case ActionWrite:
// Users can modify their own information
if authCtx.ResourceID == authCtx.UserID {
return nil
}
return errors.NewForbiddenError("You do not have permission to modify this user's information")
case ActionDelete:
// Users can delete their own account, admins can delete any
if authCtx.ResourceID == authCtx.UserID {
return nil
}
return errors.NewForbiddenError("You do not have permission to delete this user")
default:
return errors.NewForbiddenError(fmt.Sprintf("Unknown action: %s", authCtx.Action))
}
}
// AuthorizeListAccess checks if a user can list resources of a specific type
func (a *AuthorizationService) AuthorizeListAccess(ctx context.Context, userID string, resourceType ResourceType) error {
a.logger.Debug("Authorizing list access",
zap.String("user_id", userID),
zap.String("resource_type", string(resourceType)))
// System admins can list anything
if a.isSystemAdmin(userID) {
return nil
}
// For now, allow users to list their own resources
// This would be refined based on business requirements
switch resourceType {
case ResourceTypeApplication:
return nil // Users can list applications (filtered by ownership)
case ResourceTypeToken:
return nil // Users can list their own tokens
case ResourceTypePermission:
return nil // Users can list available permissions
case ResourceTypeUser:
// Only admins can list users
return errors.NewForbiddenError("You do not have permission to list users")
default:
return errors.NewForbiddenError(fmt.Sprintf("Unknown resource type: %s", resourceType))
}
}
// GetUserResourceFilter returns a filter for resources that a user can access
func (a *AuthorizationService) GetUserResourceFilter(userID string, resourceType ResourceType) map[string]interface{} {
filter := make(map[string]interface{})
// System admins see everything
if a.isSystemAdmin(userID) {
return filter // Empty filter means no restrictions
}
// Filter by ownership
switch resourceType {
case ResourceTypeApplication, ResourceTypeToken:
// Users can only see resources they own
filter["owner_email"] = userID
case ResourceTypePermission:
// Users can see all permissions (they're not user-specific)
return filter
case ResourceTypeUser:
// Users can only see themselves
filter["user_id"] = userID
}
return filter
}

View File

@ -36,6 +36,9 @@ type ConfigProvider interface {
// GetDatabaseDSN constructs and returns the database connection string
GetDatabaseDSN() string
// GetDatabaseDSNForLogging returns a sanitized database connection string safe for logging
GetDatabaseDSNForLogging() string
// GetServerAddress returns the server address in host:port format
GetServerAddress() string
@ -104,17 +107,20 @@ func (c *Config) setDefaults() {
"RATE_LIMIT_ENABLED": "true",
"RATE_LIMIT_RPS": "100",
"RATE_LIMIT_BURST": "200",
"AUTH_RATE_LIMIT_RPS": "5",
"AUTH_RATE_LIMIT_BURST": "10",
"CACHE_ENABLED": "false",
"CACHE_TTL": "1h",
"JWT_ISSUER": "api-key-service",
"JWT_SECRET": "bootstrap-jwt-secret-change-in-production",
"JWT_SECRET": "", // Must be set via environment variable
"AUTH_PROVIDER": "header", // header or sso
"AUTH_HEADER_USER_EMAIL": "X-User-Email",
"AUTH_SIGNING_KEY": "", // Must be set via environment variable
"SSO_PROVIDER_URL": "",
"SSO_CLIENT_ID": "",
"SSO_CLIENT_SECRET": "",
"INTERNAL_APP_ID": "internal.api-key-service",
"INTERNAL_HMAC_KEY": "bootstrap-hmac-key-change-in-production",
"INTERNAL_HMAC_KEY": "", // Must be set via environment variable
"METRICS_ENABLED": "false",
"METRICS_PORT": "9090",
"REDIS_ENABLED": "false",
@ -131,6 +137,8 @@ func (c *Config) setDefaults() {
"AUTH_FAILURE_WINDOW": "15m",
"IP_BLOCK_DURATION": "1h",
"REQUEST_MAX_AGE": "5m",
"CSRF_TOKEN_MAX_AGE": "1h",
"BCRYPT_COST": "14",
"IP_WHITELIST": "",
"SAML_ENABLED": "false",
"SAML_IDP_METADATA_URL": "",
@ -212,6 +220,7 @@ func (c *Config) Validate() error {
"INTERNAL_APP_ID",
"INTERNAL_HMAC_KEY",
"JWT_SECRET",
"AUTH_SIGNING_KEY",
}
var missing []string
@ -225,6 +234,22 @@ func (c *Config) Validate() error {
return fmt.Errorf("missing required configuration keys: %s", strings.Join(missing, ", "))
}
// Validate that production secrets are not using default values
jwtSecret := c.GetString("JWT_SECRET")
if jwtSecret == "bootstrap-jwt-secret-change-in-production" || len(jwtSecret) < 32 {
return fmt.Errorf("JWT_SECRET must be set to a secure value (minimum 32 characters)")
}
hmacKey := c.GetString("INTERNAL_HMAC_KEY")
if hmacKey == "bootstrap-hmac-key-change-in-production" || len(hmacKey) < 32 {
return fmt.Errorf("INTERNAL_HMAC_KEY must be set to a secure value (minimum 32 characters)")
}
authSigningKey := c.GetString("AUTH_SIGNING_KEY")
if len(authSigningKey) < 32 {
return fmt.Errorf("AUTH_SIGNING_KEY must be set to a secure value (minimum 32 characters)")
}
// Validate specific values
if c.GetInt("DB_PORT") <= 0 || c.GetInt("DB_PORT") > 65535 {
return fmt.Errorf("DB_PORT must be a valid port number")
@ -278,6 +303,27 @@ func (c *Config) GetDatabaseDSN() string {
)
}
// GetDatabaseDSNForLogging returns a sanitized database connection string safe for logging
func (c *Config) GetDatabaseDSNForLogging() string {
password := c.GetString("DB_PASSWORD")
maskedPassword := "***MASKED***"
if len(password) > 0 {
// Show first and last character with masking for debugging
if len(password) >= 4 {
maskedPassword = string(password[0]) + "***" + string(password[len(password)-1])
}
}
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
c.GetString("DB_HOST"),
c.GetInt("DB_PORT"),
c.GetString("DB_USER"),
maskedPassword,
c.GetString("DB_NAME"),
c.GetString("DB_SSLMODE"),
)
}
// GetServerAddress returns the server address in host:port format
func (c *Config) GetServerAddress() string {
return fmt.Sprintf("%s:%d", c.GetString("SERVER_HOST"), c.GetInt("SERVER_PORT"))

View File

@ -18,17 +18,42 @@ const (
TokenLength = 32
// TokenPrefix is prepended to all tokens for identification
TokenPrefix = "kms_"
// BcryptCost defines the bcrypt cost for 2025 security standards (minimum 14)
BcryptCost = 14
)
// TokenGenerator provides secure token generation and validation
type TokenGenerator struct {
hmacKey []byte
hmacKey []byte
bcryptCost int
}
// NewTokenGenerator creates a new token generator with the provided HMAC key
func NewTokenGenerator(hmacKey string) *TokenGenerator {
return &TokenGenerator{
hmacKey: []byte(hmacKey),
hmacKey: []byte(hmacKey),
bcryptCost: BcryptCost,
}
}
// NewTokenGeneratorWithCost creates a new token generator with custom bcrypt cost
func NewTokenGeneratorWithCost(hmacKey string, bcryptCost int) *TokenGenerator {
// Validate bcrypt cost (must be between 4 and 31)
if bcryptCost < 4 {
bcryptCost = 4
} else if bcryptCost > 31 {
bcryptCost = 31
}
// Warn if cost is too low for production
if bcryptCost < 12 {
// This should log a warning, but we don't have logger here
// In a real implementation, you'd pass a logger or use a global one
}
return &TokenGenerator{
hmacKey: []byte(hmacKey),
bcryptCost: bcryptCost,
}
}
@ -69,10 +94,10 @@ func (tg *TokenGenerator) GenerateSecureTokenWithPrefix(appPrefix string, tokenT
// HashToken creates a secure hash of the token for storage
func (tg *TokenGenerator) HashToken(token string) (string, error) {
// Use bcrypt for secure password-like hashing
hash, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost)
// Use bcrypt with configured cost
hash, err := bcrypt.GenerateFromPassword([]byte(token), tg.bcryptCost)
if err != nil {
return "", fmt.Errorf("failed to hash token: %w", err)
return "", fmt.Errorf("failed to hash token with bcrypt cost %d: %w", tg.bcryptCost, err)
}
return string(hash), nil

View File

@ -4,12 +4,8 @@ import (
"context"
"database/sql"
"fmt"
"path/filepath"
"time"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
_ "github.com/lib/pq"
"github.com/kms/api-key-service/internal/repository"
@ -17,7 +13,8 @@ import (
// PostgresProvider implements the DatabaseProvider interface
type PostgresProvider struct {
db *sql.DB
db *sql.DB
dsn string
}
// NewPostgresProvider creates a new PostgreSQL database provider
@ -44,7 +41,7 @@ func NewPostgresProvider(dsn string, maxOpenConns, maxIdleConns int, maxLifetime
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return &PostgresProvider{db: db}, nil
return &PostgresProvider{db: db, dsn: dsn}, nil
}
// GetDB returns the underlying database connection
@ -81,51 +78,7 @@ func (p *PostgresProvider) BeginTx(ctx context.Context) (repository.TransactionP
return &PostgresTransaction{tx: tx}, nil
}
// Migrate runs database migrations
func (p *PostgresProvider) Migrate(ctx context.Context, migrationPath string) error {
// Create a separate connection for migrations to avoid interfering with the main connection
migrationDB, err := sql.Open("postgres", p.getDSN())
if err != nil {
return fmt.Errorf("failed to open migration database connection: %w", err)
}
defer migrationDB.Close()
driver, err := postgres.WithInstance(migrationDB, &postgres.Config{})
if err != nil {
return fmt.Errorf("failed to create postgres driver: %w", err)
}
// Convert relative path to file URL
absPath, err := filepath.Abs(migrationPath)
if err != nil {
return fmt.Errorf("failed to get absolute path: %w", err)
}
m, err := migrate.NewWithDatabaseInstance(
fmt.Sprintf("file://%s", absPath),
"postgres",
driver,
)
if err != nil {
return fmt.Errorf("failed to create migrate instance: %w", err)
}
defer m.Close()
// Run migrations
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
return fmt.Errorf("failed to run migrations: %w", err)
}
return nil
}
// getDSN reconstructs the DSN from the current connection
// This is a workaround since we don't store the original DSN
func (p *PostgresProvider) getDSN() string {
// For now, we'll use the default values from config
// In a production system, we'd store the original DSN
return "host=localhost port=5432 user=postgres password=postgres dbname=kms sslmode=disable"
}
// PostgresTransaction implements the TransactionProvider interface
type PostgresTransaction struct {

View File

@ -0,0 +1,245 @@
package errors
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// SecureErrorResponse represents a sanitized error response for clients
type SecureErrorResponse struct {
Error string `json:"error"`
Message string `json:"message"`
RequestID string `json:"request_id,omitempty"`
Code int `json:"code"`
}
// ErrorHandler provides secure error handling for HTTP responses
type ErrorHandler struct {
logger *zap.Logger
}
// NewErrorHandler creates a new secure error handler
func NewErrorHandler(logger *zap.Logger) *ErrorHandler {
return &ErrorHandler{
logger: logger,
}
}
// HandleError handles errors securely by logging detailed information and returning sanitized responses
func (eh *ErrorHandler) HandleError(c *gin.Context, err error, userMessage string) {
requestID := eh.getOrGenerateRequestID(c)
// Log detailed error information for internal debugging
eh.logger.Error("HTTP request error",
zap.String("request_id", requestID),
zap.String("path", c.Request.URL.Path),
zap.String("method", c.Request.Method),
zap.String("user_agent", c.Request.UserAgent()),
zap.String("remote_addr", c.ClientIP()),
zap.Error(err),
)
// Determine appropriate HTTP status code and error type
statusCode, errorType := eh.determineErrorResponse(err)
// Create sanitized response
response := SecureErrorResponse{
Error: errorType,
Message: eh.sanitizeErrorMessage(userMessage, err),
RequestID: requestID,
Code: statusCode,
}
c.JSON(statusCode, response)
}
// HandleValidationError handles input validation errors
func (eh *ErrorHandler) HandleValidationError(c *gin.Context, field string, message string) {
requestID := eh.getOrGenerateRequestID(c)
eh.logger.Warn("Validation error",
zap.String("request_id", requestID),
zap.String("field", field),
zap.String("path", c.Request.URL.Path),
zap.String("method", c.Request.Method),
)
response := SecureErrorResponse{
Error: "validation_error",
Message: "Invalid input provided",
RequestID: requestID,
Code: http.StatusBadRequest,
}
c.JSON(http.StatusBadRequest, response)
}
// HandleAuthenticationError handles authentication failures
func (eh *ErrorHandler) HandleAuthenticationError(c *gin.Context, err error) {
requestID := eh.getOrGenerateRequestID(c)
eh.logger.Warn("Authentication error",
zap.String("request_id", requestID),
zap.String("path", c.Request.URL.Path),
zap.String("method", c.Request.Method),
zap.String("remote_addr", c.ClientIP()),
zap.Error(err),
)
response := SecureErrorResponse{
Error: "authentication_failed",
Message: "Authentication required",
RequestID: requestID,
Code: http.StatusUnauthorized,
}
c.JSON(http.StatusUnauthorized, response)
}
// HandleAuthorizationError handles authorization failures
func (eh *ErrorHandler) HandleAuthorizationError(c *gin.Context, resource string) {
requestID := eh.getOrGenerateRequestID(c)
eh.logger.Warn("Authorization error",
zap.String("request_id", requestID),
zap.String("resource", resource),
zap.String("path", c.Request.URL.Path),
zap.String("method", c.Request.Method),
zap.String("remote_addr", c.ClientIP()),
)
response := SecureErrorResponse{
Error: "access_denied",
Message: "Insufficient permissions",
RequestID: requestID,
Code: http.StatusForbidden,
}
c.JSON(http.StatusForbidden, response)
}
// HandleInternalError handles internal server errors
func (eh *ErrorHandler) HandleInternalError(c *gin.Context, err error) {
requestID := eh.getOrGenerateRequestID(c)
eh.logger.Error("Internal server error",
zap.String("request_id", requestID),
zap.String("path", c.Request.URL.Path),
zap.String("method", c.Request.Method),
zap.String("remote_addr", c.ClientIP()),
zap.Error(err),
)
response := SecureErrorResponse{
Error: "internal_error",
Message: "An internal error occurred",
RequestID: requestID,
Code: http.StatusInternalServerError,
}
c.JSON(http.StatusInternalServerError, response)
}
// determineErrorResponse determines the appropriate HTTP status and error type
func (eh *ErrorHandler) determineErrorResponse(err error) (int, string) {
if appErr, ok := err.(*AppError); ok {
return appErr.StatusCode, eh.getErrorTypeFromCode(appErr.Code)
}
// For unknown errors, log as internal error but don't expose details
return http.StatusInternalServerError, "internal_error"
}
// sanitizeErrorMessage removes sensitive information from error messages
func (eh *ErrorHandler) sanitizeErrorMessage(userMessage string, err error) string {
if userMessage != "" {
return userMessage
}
// Provide generic messages for different error types
if appErr, ok := err.(*AppError); ok {
return eh.getGenericMessageFromCode(appErr.Code)
}
return "An error occurred"
}
// getErrorTypeFromCode converts an error code to a sanitized error type string
func (eh *ErrorHandler) getErrorTypeFromCode(code ErrorCode) string {
switch code {
case ErrValidationFailed, ErrInvalidInput, ErrMissingField, ErrInvalidFormat:
return "validation_error"
case ErrUnauthorized, ErrInvalidToken, ErrTokenExpired, ErrInvalidCredentials:
return "authentication_failed"
case ErrForbidden, ErrInsufficientPermissions:
return "access_denied"
case ErrNotFound, ErrApplicationNotFound, ErrTokenNotFound, ErrPermissionNotFound:
return "resource_not_found"
case ErrAlreadyExists, ErrConflict:
return "resource_conflict"
case ErrRateLimit:
return "rate_limit_exceeded"
case ErrTimeout:
return "timeout"
default:
return "internal_error"
}
}
// getGenericMessageFromCode provides generic user-safe messages for error codes
func (eh *ErrorHandler) getGenericMessageFromCode(code ErrorCode) string {
switch code {
case ErrValidationFailed, ErrInvalidInput, ErrMissingField, ErrInvalidFormat:
return "Invalid input provided"
case ErrUnauthorized, ErrInvalidToken, ErrTokenExpired, ErrInvalidCredentials:
return "Authentication required"
case ErrForbidden, ErrInsufficientPermissions:
return "Access denied"
case ErrNotFound, ErrApplicationNotFound, ErrTokenNotFound, ErrPermissionNotFound:
return "Resource not found"
case ErrAlreadyExists, ErrConflict:
return "Resource conflict"
case ErrRateLimit:
return "Rate limit exceeded"
case ErrTimeout:
return "Request timeout"
default:
return "An error occurred"
}
}
// getOrGenerateRequestID gets or generates a request ID for tracking
func (eh *ErrorHandler) getOrGenerateRequestID(c *gin.Context) string {
// Try to get existing request ID from context
if requestID, exists := c.Get("request_id"); exists {
if id, ok := requestID.(string); ok {
return id
}
}
// Try to get from header
requestID := c.GetHeader("X-Request-ID")
if requestID != "" {
return requestID
}
// Generate a simple request ID (in production, use a proper UUID library)
return generateSimpleID()
}
// generateSimpleID generates a simple request ID
func generateSimpleID() string {
// Simple implementation - in production use proper UUID generation
bytes := make([]byte, 8)
if _, err := rand.Read(bytes); err != nil {
// Fallback to timestamp-based ID
return fmt.Sprintf("req_%d", time.Now().UnixNano())
}
return "req_" + hex.EncodeToString(bytes)
}

View File

@ -7,15 +7,21 @@ import (
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/authorization"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
"github.com/kms/api-key-service/internal/services"
"github.com/kms/api-key-service/internal/validation"
)
// ApplicationHandler handles application-related HTTP requests
type ApplicationHandler struct {
appService services.ApplicationService
authService services.AuthenticationService
logger *zap.Logger
appService services.ApplicationService
authService services.AuthenticationService
authzService *authorization.AuthorizationService
validator *validation.Validator
errorHandler *errors.ErrorHandler
logger *zap.Logger
}
// NewApplicationHandler creates a new application handler
@ -25,9 +31,12 @@ func NewApplicationHandler(
logger *zap.Logger,
) *ApplicationHandler {
return &ApplicationHandler{
appService: appService,
authService: authService,
logger: logger,
appService: appService,
authService: authService,
authzService: authorization.NewAuthorizationService(logger),
validator: validation.NewValidator(logger),
errorHandler: errors.NewErrorHandler(logger),
logger: logger,
}
}
@ -35,57 +44,99 @@ func NewApplicationHandler(
func (h *ApplicationHandler) Create(c *gin.Context) {
var req domain.CreateApplicationRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("Invalid request body", zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{
"error": "Bad Request",
"message": "Invalid request body: " + err.Error(),
})
h.errorHandler.HandleValidationError(c, "request_body", "Invalid application request format")
return
}
// Get user ID from context
userID, exists := c.Get("user_id")
if !exists {
h.logger.Error("User ID not found in context")
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Internal Server Error",
"message": "Authentication context not found",
})
// Get user ID from authenticated context
userID := h.getUserIDFromContext(c)
if userID == "" {
h.errorHandler.HandleAuthenticationError(c, errors.NewUnauthorizedError("User authentication required"))
return
}
app, err := h.appService.Create(c.Request.Context(), &req, userID.(string))
// Validate input
validationErrors := h.validator.ValidateApplicationRequest(req.AppID, req.AppLink, req.CallbackURL, []string{})
if len(validationErrors) > 0 {
h.logger.Warn("Application validation failed",
zap.String("user_id", userID),
zap.Any("errors", validationErrors))
h.errorHandler.HandleValidationError(c, "validation", "Invalid application data")
return
}
// Check authorization for creating applications
authCtx := &authorization.AuthorizationContext{
UserID: userID,
ResourceType: authorization.ResourceTypeApplication,
Action: authorization.ActionCreate,
}
if err := h.authzService.AuthorizeResourceAccess(c.Request.Context(), authCtx); err != nil {
h.errorHandler.HandleAuthorizationError(c, "application creation")
return
}
// Create the application
app, err := h.appService.Create(c.Request.Context(), &req, userID)
if err != nil {
h.logger.Error("Failed to create application", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Internal Server Error",
"message": "Failed to create application",
})
h.errorHandler.HandleInternalError(c, err)
return
}
h.logger.Info("Application created", zap.String("app_id", app.AppID))
h.logger.Info("Application created successfully",
zap.String("app_id", app.AppID),
zap.String("user_id", userID))
c.JSON(http.StatusCreated, app)
}
// getUserIDFromContext extracts user ID from Gin context
func (h *ApplicationHandler) getUserIDFromContext(c *gin.Context) string {
// Try to get from Gin context first (set by middleware)
if userID, exists := c.Get("user_id"); exists {
if id, ok := userID.(string); ok {
return id
}
}
// Fallback to header (for compatibility)
userEmail := c.GetHeader("X-User-Email")
if userEmail != "" {
return userEmail
}
return ""
}
// GetByID handles GET /applications/:id
func (h *ApplicationHandler) GetByID(c *gin.Context) {
appID := c.Param("id")
if appID == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Bad Request",
"message": "Application ID is required",
})
// Get user ID from context
userID := h.getUserIDFromContext(c)
if userID == "" {
h.errorHandler.HandleAuthenticationError(c, errors.NewUnauthorizedError("User authentication required"))
return
}
// Validate app ID
if result := h.validator.ValidateAppID(appID); !result.Valid {
h.errorHandler.HandleValidationError(c, "app_id", "Invalid application ID")
return
}
// Get the application first
app, err := h.appService.GetByID(c.Request.Context(), appID)
if err != nil {
h.logger.Error("Failed to get application", zap.Error(err), zap.String("app_id", appID))
c.JSON(http.StatusNotFound, gin.H{
"error": "Not Found",
"message": "Application not found",
})
h.errorHandler.HandleError(c, err, "Application not found")
return
}
// Check authorization for reading this application
if err := h.authzService.AuthorizeApplicationOwnership(userID, app); err != nil {
h.errorHandler.HandleAuthorizationError(c, "application access")
return
}

View File

@ -1,32 +1,48 @@
package handlers
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/auth"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
"github.com/kms/api-key-service/internal/services"
)
// AuthHandler handles authentication-related HTTP requests
type AuthHandler struct {
authService services.AuthenticationService
tokenService services.TokenService
logger *zap.Logger
authService services.AuthenticationService
tokenService services.TokenService
headerValidator *auth.HeaderValidator
config config.ConfigProvider
errorHandler *errors.ErrorHandler
logger *zap.Logger
}
// NewAuthHandler creates a new auth handler
func NewAuthHandler(
authService services.AuthenticationService,
tokenService services.TokenService,
config config.ConfigProvider,
logger *zap.Logger,
) *AuthHandler {
return &AuthHandler{
authService: authService,
tokenService: tokenService,
logger: logger,
authService: authService,
tokenService: tokenService,
headerValidator: auth.NewHeaderValidator(config, logger),
config: config,
errorHandler: errors.NewErrorHandler(logger),
logger: logger,
}
}
@ -34,58 +50,81 @@ func NewAuthHandler(
func (h *AuthHandler) Login(c *gin.Context) {
var req domain.LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("Invalid login request", zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{
"error": "Bad Request",
"message": "Invalid request body: " + err.Error(),
})
h.errorHandler.HandleValidationError(c, "request_body", "Invalid login request format")
return
}
// For now, we'll extract user ID from headers since we're using HeaderAuthenticationProvider
userID := c.GetHeader("X-User-Email")
if userID == "" {
h.logger.Warn("User email not found in headers")
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized",
"message": "User authentication required",
})
// Validate authentication headers with HMAC signature
userContext, err := h.headerValidator.ValidateAuthenticationHeaders(c.Request)
if err != nil {
h.errorHandler.HandleAuthenticationError(c, err)
return
}
h.logger.Info("Processing login request", zap.String("user_id", userID), zap.String("app_id", req.AppID))
h.logger.Info("Processing login request", zap.String("user_id", userContext.UserID), zap.String("app_id", req.AppID))
// Generate user token
token, err := h.tokenService.GenerateUserToken(c.Request.Context(), req.AppID, userID, req.Permissions)
token, err := h.tokenService.GenerateUserToken(c.Request.Context(), req.AppID, userContext.UserID, req.Permissions)
if err != nil {
h.logger.Error("Failed to generate user token", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Internal Server Error",
"message": "Failed to generate token",
})
h.errorHandler.HandleInternalError(c, err)
return
}
// For now, we'll just return the token directly
// In a real implementation, this would redirect to the callback URL
response := domain.LoginResponse{
RedirectURL: req.RedirectURI + "?token=" + token,
}
if req.RedirectURI == "" {
// If no redirect URI, return token directly
// If no redirect URI, return token directly via secure response body
c.JSON(http.StatusOK, gin.H{
"token": token,
"user_id": userID,
"user_id": userContext.UserID,
"app_id": req.AppID,
"expires_in": 604800, // 7 days in seconds
})
return
}
// For redirect flows, use secure cookie-based token delivery
// Set secure cookie with the token
c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(
"auth_token", // name
token, // value
604800, // maxAge (7 days)
"/", // path
"", // domain (empty for current domain)
true, // secure (HTTPS only)
true, // httpOnly (no JavaScript access)
)
// Generate a secure state parameter for CSRF protection
state := h.generateSecureState(userContext.UserID, req.AppID)
// Redirect without token in URL
response := domain.LoginResponse{
RedirectURL: req.RedirectURI + "?state=" + state,
}
c.JSON(http.StatusOK, response)
}
// generateSecureState generates a secure state parameter for OAuth flows
func (h *AuthHandler) generateSecureState(userID, appID string) string {
// Generate random bytes for state
stateBytes := make([]byte, 16)
if _, err := rand.Read(stateBytes); err != nil {
h.logger.Error("Failed to generate random state", zap.Error(err))
// Fallback to less secure but functional state
return fmt.Sprintf("state_%s_%s_%d", userID, appID, time.Now().UnixNano())
}
// Create HMAC signature to prevent tampering
stateData := fmt.Sprintf("%s:%s:%x", userID, appID, stateBytes)
mac := hmac.New(sha256.New, []byte(h.config.GetString("AUTH_SIGNING_KEY")))
mac.Write([]byte(stateData))
signature := hex.EncodeToString(mac.Sum(nil))
// Return base64-encoded state with signature
return hex.EncodeToString([]byte(fmt.Sprintf("%s.%s", stateData, signature)))
}
// Verify handles POST /verify
func (h *AuthHandler) Verify(c *gin.Context) {
var req domain.VerifyRequest

235
internal/middleware/csrf.go Normal file
View File

@ -0,0 +1,235 @@
package middleware
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"net/http"
"strconv"
"strings"
"time"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/config"
)
// CSRFMiddleware provides CSRF protection
type CSRFMiddleware struct {
config config.ConfigProvider
logger *zap.Logger
}
// NewCSRFMiddleware creates a new CSRF middleware
func NewCSRFMiddleware(config config.ConfigProvider, logger *zap.Logger) *CSRFMiddleware {
return &CSRFMiddleware{
config: config,
logger: logger,
}
}
// CSRFProtection implements CSRF protection for state-changing operations
func (cm *CSRFMiddleware) CSRFProtection(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip CSRF protection for safe methods
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
next.ServeHTTP(w, r)
return
}
// Skip CSRF protection for specific endpoints that use other authentication
if cm.shouldSkipCSRF(r) {
next.ServeHTTP(w, r)
return
}
// Get CSRF token from header
csrfToken := r.Header.Get("X-CSRF-Token")
if csrfToken == "" {
cm.logger.Warn("Missing CSRF token",
zap.String("path", r.URL.Path),
zap.String("method", r.Method),
zap.String("remote_addr", r.RemoteAddr))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error":"csrf_token_missing","message":"CSRF token required"}`))
return
}
// Validate CSRF token
if !cm.validateCSRFToken(csrfToken, r) {
cm.logger.Warn("Invalid CSRF token",
zap.String("path", r.URL.Path),
zap.String("method", r.Method),
zap.String("remote_addr", r.RemoteAddr))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error":"csrf_token_invalid","message":"Invalid CSRF token"}`))
return
}
cm.logger.Debug("CSRF token validated successfully",
zap.String("path", r.URL.Path))
next.ServeHTTP(w, r)
})
}
// GenerateCSRFToken generates a new CSRF token for a user session
func (cm *CSRFMiddleware) GenerateCSRFToken(userID string) (string, error) {
// Generate random bytes for token
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
cm.logger.Error("Failed to generate CSRF token", zap.Error(err))
return "", err
}
// Create timestamp
timestamp := time.Now().Unix()
// Create token data
tokenData := hex.EncodeToString(tokenBytes)
// Create signing string: userID:timestamp:tokenData
timestampStr := strconv.FormatInt(timestamp, 10)
signingString := userID + ":" + timestampStr + ":" + tokenData
// Sign the token with HMAC
signature := cm.signData(signingString)
// Return encoded token: tokenData.timestamp.signature
token := tokenData + "." + timestampStr + "." + signature
return token, nil
}
// validateCSRFToken validates a CSRF token
func (cm *CSRFMiddleware) validateCSRFToken(token string, r *http.Request) bool {
// Parse token parts
parts := strings.Split(token, ".")
if len(parts) != 3 {
cm.logger.Debug("Invalid CSRF token format")
return false
}
tokenData, timestampStr, signature := parts[0], parts[1], parts[2]
// Get user ID from request context or headers
userID := cm.getUserIDFromRequest(r)
if userID == "" {
cm.logger.Debug("No user ID found for CSRF validation")
return false
}
// Recreate signing string
signingString := userID + ":" + timestampStr + ":" + tokenData
// Verify signature
expectedSignature := cm.signData(signingString)
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
cm.logger.Debug("CSRF token signature verification failed")
return false
}
// Parse timestamp
timestampInt, err := strconv.ParseInt(timestampStr, 10, 64)
if err != nil {
cm.logger.Debug("Invalid timestamp in CSRF token", zap.Error(err))
return false
}
timestamp := time.Unix(timestampInt, 0)
// Check if token is expired (valid for 1 hour by default)
maxAge := cm.config.GetDuration("CSRF_TOKEN_MAX_AGE")
if maxAge <= 0 {
maxAge = 1 * time.Hour
}
if time.Since(timestamp) > maxAge {
cm.logger.Debug("CSRF token expired",
zap.Time("timestamp", timestamp),
zap.Duration("age", time.Since(timestamp)),
zap.Duration("max_age", maxAge))
return false
}
return true
}
// signData signs data with HMAC
func (cm *CSRFMiddleware) signData(data string) string {
// Use the same signing key as for authentication
signingKey := cm.config.GetString("AUTH_SIGNING_KEY")
if signingKey == "" {
cm.logger.Error("AUTH_SIGNING_KEY not configured for CSRF protection")
return ""
}
mac := hmac.New(sha256.New, []byte(signingKey))
mac.Write([]byte(data))
return hex.EncodeToString(mac.Sum(nil))
}
// getUserIDFromRequest extracts user ID from request
func (cm *CSRFMiddleware) getUserIDFromRequest(r *http.Request) string {
// Try to get from X-User-Email header
userEmail := r.Header.Get(cm.config.GetString("AUTH_HEADER_USER_EMAIL"))
if userEmail != "" {
return userEmail
}
// Try to get from context (if set by authentication middleware)
if userID := r.Context().Value("user_id"); userID != nil {
if id, ok := userID.(string); ok {
return id
}
}
return ""
}
// shouldSkipCSRF determines if CSRF protection should be skipped for a request
func (cm *CSRFMiddleware) shouldSkipCSRF(r *http.Request) bool {
// Skip for API endpoints that use API key authentication
if strings.HasPrefix(r.URL.Path, "/api/verify") {
return true
}
// Skip for health check endpoints
if r.URL.Path == "/health" || r.URL.Path == "/ready" {
return true
}
// Skip for webhook endpoints (if any)
if strings.HasPrefix(r.URL.Path, "/webhook/") {
return true
}
return false
}
// SetCSRFCookie sets a secure CSRF token cookie
func (cm *CSRFMiddleware) SetCSRFCookie(w http.ResponseWriter, token string) {
cookie := &http.Cookie{
Name: "csrf_token",
Value: token,
Path: "/",
MaxAge: 3600, // 1 hour
HttpOnly: false, // JavaScript needs to read this for AJAX requests
Secure: true, // HTTPS only
SameSite: http.SameSiteStrictMode,
}
http.SetCookie(w, cookie)
}
// GetCSRFTokenFromCookie gets CSRF token from cookie
func (cm *CSRFMiddleware) GetCSRFTokenFromCookie(r *http.Request) string {
cookie, err := r.Cookie("csrf_token")
if err != nil {
return ""
}
return cookie.Value
}

View File

@ -23,23 +23,25 @@ import (
// SecurityMiddleware provides various security features
type SecurityMiddleware struct {
config config.ConfigProvider
logger *zap.Logger
cacheManager *cache.CacheManager
appRepo repository.ApplicationRepository
rateLimiters map[string]*rate.Limiter
mu sync.RWMutex
config config.ConfigProvider
logger *zap.Logger
cacheManager *cache.CacheManager
appRepo repository.ApplicationRepository
rateLimiters map[string]*rate.Limiter
authRateLimiters map[string]*rate.Limiter
mu sync.RWMutex
}
// NewSecurityMiddleware creates a new security middleware
func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger, appRepo repository.ApplicationRepository) *SecurityMiddleware {
cacheManager := cache.NewCacheManager(config, logger)
return &SecurityMiddleware{
config: config,
logger: logger,
cacheManager: cacheManager,
appRepo: appRepo,
rateLimiters: make(map[string]*rate.Limiter),
config: config,
logger: logger,
cacheManager: cacheManager,
appRepo: appRepo,
rateLimiters: make(map[string]*rate.Limiter),
authRateLimiters: make(map[string]*rate.Limiter),
}
}
@ -76,6 +78,38 @@ func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler
})
}
// AuthRateLimitMiddleware implements stricter rate limiting for authentication endpoints
func (s *SecurityMiddleware) AuthRateLimitMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
next.ServeHTTP(w, r)
return
}
clientIP := s.getClientIP(r)
// Use stricter rate limits for auth endpoints
limiter := s.getAuthRateLimiter(clientIP)
// Check if request is allowed
if !limiter.Allow() {
s.logger.Warn("Auth rate limit exceeded",
zap.String("client_ip", clientIP),
zap.String("path", r.URL.Path))
// Track authentication failures for brute force protection
s.TrackAuthenticationFailure(clientIP, "")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error":"auth_rate_limit_exceeded","message":"Too many authentication attempts"}`))
return
}
next.ServeHTTP(w, r)
})
}
// BruteForceProtectionMiddleware implements brute force protection
func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -231,6 +265,35 @@ func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter {
return limiter
}
func (s *SecurityMiddleware) getAuthRateLimiter(clientIP string) *rate.Limiter {
s.mu.RLock()
limiter, exists := s.authRateLimiters[clientIP]
s.mu.RUnlock()
if exists {
return limiter
}
// Create new auth rate limiter with stricter limits
authRPS := s.config.GetInt("AUTH_RATE_LIMIT_RPS")
if authRPS <= 0 {
authRPS = 5 // Very strict default for auth endpoints
}
authBurst := s.config.GetInt("AUTH_RATE_LIMIT_BURST")
if authBurst <= 0 {
authBurst = 10 // Allow small bursts
}
limiter = rate.NewLimiter(rate.Limit(authRPS), authBurst)
s.mu.Lock()
s.authRateLimiters[clientIP] = limiter
s.mu.Unlock()
return limiter
}
func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) {
ctx := context.Background()
key := cache.CacheKey("rate_limit_violations", clientIP)

View File

@ -168,9 +168,6 @@ type DatabaseProvider interface {
// BeginTx starts a database transaction
BeginTx(ctx context.Context) (TransactionProvider, error)
// Migrate runs database migrations
Migrate(ctx context.Context, migrationPath string) error
}
// TransactionProvider defines the interface for database transaction operations

View File

@ -201,83 +201,118 @@ func (r *ApplicationRepository) List(ctx context.Context, limit, offset int) ([]
// Update updates an existing application
func (r *ApplicationRepository) Update(ctx context.Context, appID string, updates *domain.UpdateApplicationRequest) (*domain.Application, error) {
// Build dynamic update query
// Build secure dynamic update query using a whitelist approach
var setParts []string
var args []interface{}
argIndex := 1
// Whitelist of allowed fields to prevent SQL injection
allowedFields := map[string]string{
"app_link": "app_link",
"type": "type",
"callback_url": "callback_url",
"hmac_key": "hmac_key",
"token_prefix": "token_prefix",
"token_renewal_duration": "token_renewal_duration",
"max_token_duration": "max_token_duration",
"owner_type": "owner_type",
"owner_name": "owner_name",
"owner_owner": "owner_owner",
}
if updates.AppLink != nil {
setParts = append(setParts, fmt.Sprintf("app_link = $%d", argIndex))
args = append(args, *updates.AppLink)
argIndex++
if field, ok := allowedFields["app_link"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, *updates.AppLink)
argIndex++
}
}
if updates.Type != nil {
typeStrings := make([]string, len(*updates.Type))
for i, t := range *updates.Type {
typeStrings[i] = string(t)
if field, ok := allowedFields["type"]; ok {
typeStrings := make([]string, len(*updates.Type))
for i, t := range *updates.Type {
typeStrings[i] = string(t)
}
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, pq.Array(typeStrings))
argIndex++
}
setParts = append(setParts, fmt.Sprintf("type = $%d", argIndex))
args = append(args, pq.Array(typeStrings))
argIndex++
}
if updates.CallbackURL != nil {
setParts = append(setParts, fmt.Sprintf("callback_url = $%d", argIndex))
args = append(args, *updates.CallbackURL)
argIndex++
if field, ok := allowedFields["callback_url"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, *updates.CallbackURL)
argIndex++
}
}
if updates.HMACKey != nil {
setParts = append(setParts, fmt.Sprintf("hmac_key = $%d", argIndex))
args = append(args, *updates.HMACKey)
argIndex++
if field, ok := allowedFields["hmac_key"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, *updates.HMACKey)
argIndex++
}
}
if updates.TokenPrefix != nil {
setParts = append(setParts, fmt.Sprintf("token_prefix = $%d", argIndex))
args = append(args, *updates.TokenPrefix)
argIndex++
if field, ok := allowedFields["token_prefix"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, *updates.TokenPrefix)
argIndex++
}
}
if updates.TokenRenewalDuration != nil {
setParts = append(setParts, fmt.Sprintf("token_renewal_duration = $%d", argIndex))
args = append(args, updates.TokenRenewalDuration.Duration.Nanoseconds())
argIndex++
if field, ok := allowedFields["token_renewal_duration"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, updates.TokenRenewalDuration.Duration.Nanoseconds())
argIndex++
}
}
if updates.MaxTokenDuration != nil {
setParts = append(setParts, fmt.Sprintf("max_token_duration = $%d", argIndex))
args = append(args, updates.MaxTokenDuration.Duration.Nanoseconds())
argIndex++
if field, ok := allowedFields["max_token_duration"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, updates.MaxTokenDuration.Duration.Nanoseconds())
argIndex++
}
}
if updates.Owner != nil {
setParts = append(setParts, fmt.Sprintf("owner_type = $%d", argIndex))
args = append(args, string(updates.Owner.Type))
argIndex++
if field, ok := allowedFields["owner_type"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, string(updates.Owner.Type))
argIndex++
}
setParts = append(setParts, fmt.Sprintf("owner_name = $%d", argIndex))
args = append(args, updates.Owner.Name)
argIndex++
if field, ok := allowedFields["owner_name"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, updates.Owner.Name)
argIndex++
}
setParts = append(setParts, fmt.Sprintf("owner_owner = $%d", argIndex))
args = append(args, updates.Owner.Owner)
argIndex++
if field, ok := allowedFields["owner_owner"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, updates.Owner.Owner)
argIndex++
}
}
if len(setParts) == 0 {
return r.GetByID(ctx, appID) // No updates, return current state
}
// Always update the updated_at field
// Always update the updated_at field - using literal field name for security
setParts = append(setParts, fmt.Sprintf("updated_at = $%d", argIndex))
args = append(args, time.Now())
argIndex++
// Add WHERE clause
// Add WHERE clause parameter
args = append(args, appID)
// Build the final query with properly parameterized placeholders
query := fmt.Sprintf(`
UPDATE applications
SET %s

View File

@ -0,0 +1,375 @@
package validation
import (
"fmt"
"net/url"
"regexp"
"strings"
"unicode"
"go.uber.org/zap"
)
// Validator provides comprehensive input validation
type Validator struct {
logger *zap.Logger
}
// NewValidator creates a new input validator
func NewValidator(logger *zap.Logger) *Validator {
return &Validator{
logger: logger,
}
}
// ValidationError represents a validation error
type ValidationError struct {
Field string `json:"field"`
Message string `json:"message"`
Value string `json:"value,omitempty"`
}
func (e ValidationError) Error() string {
return fmt.Sprintf("validation error for field '%s': %s", e.Field, e.Message)
}
// ValidationResult holds the result of validation
type ValidationResult struct {
Valid bool `json:"valid"`
Errors []ValidationError `json:"errors"`
}
// AddError adds a validation error
func (vr *ValidationResult) AddError(field, message, value string) {
vr.Valid = false
vr.Errors = append(vr.Errors, ValidationError{
Field: field,
Message: message,
Value: value,
})
}
// Regular expressions for validation
var (
emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
appIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$`)
tokenPrefixRegex = regexp.MustCompile(`^[A-Z]{2,4}$`)
permissionRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9._]*[a-zA-Z0-9]$`)
)
// ValidateEmail validates email addresses
func (v *Validator) ValidateEmail(email string) *ValidationResult {
result := &ValidationResult{Valid: true}
if email == "" {
result.AddError("email", "Email is required", "")
return result
}
if len(email) > 254 {
result.AddError("email", "Email too long (max 254 characters)", email)
return result
}
if !emailRegex.MatchString(email) {
result.AddError("email", "Invalid email format", email)
return result
}
// Additional email security checks
if strings.Contains(email, "..") {
result.AddError("email", "Email contains consecutive dots", email)
return result
}
// Check for potentially dangerous characters
dangerousChars := []string{"<", ">", "\"", "'", "&", ";", "|", "`"}
for _, char := range dangerousChars {
if strings.Contains(email, char) {
result.AddError("email", "Email contains invalid characters", email)
return result
}
}
return result
}
// ValidateAppID validates application IDs
func (v *Validator) ValidateAppID(appID string) *ValidationResult {
result := &ValidationResult{Valid: true}
if appID == "" {
result.AddError("app_id", "Application ID is required", "")
return result
}
if len(appID) < 3 || len(appID) > 100 {
result.AddError("app_id", "Application ID must be between 3 and 100 characters", appID)
return result
}
if !appIDRegex.MatchString(appID) {
result.AddError("app_id", "Application ID must start and end with alphanumeric characters and contain only letters, numbers, dots, hyphens, and underscores", appID)
return result
}
// Check for reserved names
reservedNames := []string{"admin", "root", "system", "internal", "api", "www", "mail", "ftp"}
for _, reserved := range reservedNames {
if strings.EqualFold(appID, reserved) {
result.AddError("app_id", "Application ID cannot be a reserved name", appID)
return result
}
}
return result
}
// ValidateURL validates URLs
func (v *Validator) ValidateURL(urlStr, fieldName string) *ValidationResult {
result := &ValidationResult{Valid: true}
if urlStr == "" {
result.AddError(fieldName, "URL is required", "")
return result
}
if len(urlStr) > 2000 {
result.AddError(fieldName, "URL too long (max 2000 characters)", urlStr)
return result
}
parsedURL, err := url.Parse(urlStr)
if err != nil {
result.AddError(fieldName, "Invalid URL format", urlStr)
return result
}
// Validate scheme
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
result.AddError(fieldName, "URL must use http or https scheme", urlStr)
return result
}
// Security: Require HTTPS in production (configurable)
if parsedURL.Scheme != "https" {
v.logger.Warn("Non-HTTPS URL provided", zap.String("url", urlStr))
// In strict mode, this would be an error
// result.AddError(fieldName, "HTTPS is required", urlStr)
}
// Validate host
if parsedURL.Host == "" {
result.AddError(fieldName, "URL must have a valid host", urlStr)
return result
}
// Security: Block localhost and private IPs in production
if v.isPrivateOrLocalhost(parsedURL.Host) {
result.AddError(fieldName, "URLs pointing to private or localhost addresses are not allowed", urlStr)
return result
}
return result
}
// ValidatePermissions validates a list of permissions
func (v *Validator) ValidatePermissions(permissions []string) *ValidationResult {
result := &ValidationResult{Valid: true}
if len(permissions) == 0 {
result.AddError("permissions", "At least one permission is required", "")
return result
}
if len(permissions) > 50 {
result.AddError("permissions", "Too many permissions (max 50)", fmt.Sprintf("%d", len(permissions)))
return result
}
seen := make(map[string]bool)
for i, permission := range permissions {
field := fmt.Sprintf("permissions[%d]", i)
// Check for duplicates
if seen[permission] {
result.AddError(field, "Duplicate permission", permission)
continue
}
seen[permission] = true
// Validate individual permission
if err := v.validateSinglePermission(permission); err != nil {
result.AddError(field, err.Error(), permission)
}
}
return result
}
// ValidateTokenPrefix validates token prefixes
func (v *Validator) ValidateTokenPrefix(prefix string) *ValidationResult {
result := &ValidationResult{Valid: true}
if prefix == "" {
// Empty prefix is allowed - will use default
return result
}
if len(prefix) < 2 || len(prefix) > 4 {
result.AddError("token_prefix", "Token prefix must be between 2 and 4 characters", prefix)
return result
}
if !tokenPrefixRegex.MatchString(prefix) {
result.AddError("token_prefix", "Token prefix must contain only uppercase letters", prefix)
return result
}
return result
}
// ValidateString validates a general string with length and content constraints
func (v *Validator) ValidateString(value, fieldName string, minLen, maxLen int, allowEmpty bool) *ValidationResult {
result := &ValidationResult{Valid: true}
if value == "" && !allowEmpty {
result.AddError(fieldName, fmt.Sprintf("%s is required", fieldName), "")
return result
}
if len(value) < minLen {
result.AddError(fieldName, fmt.Sprintf("%s must be at least %d characters", fieldName, minLen), value)
return result
}
if len(value) > maxLen {
result.AddError(fieldName, fmt.Sprintf("%s must be at most %d characters", fieldName, maxLen), value)
return result
}
// Check for control characters and other potentially dangerous characters
for i, r := range value {
if unicode.IsControl(r) && r != '\n' && r != '\r' && r != '\t' {
result.AddError(fieldName, fmt.Sprintf("%s contains invalid control character at position %d", fieldName, i), value)
return result
}
}
// Check for null bytes
if strings.Contains(value, "\x00") {
result.AddError(fieldName, fmt.Sprintf("%s contains null bytes", fieldName), value)
return result
}
return result
}
// ValidateDuration validates duration strings
func (v *Validator) ValidateDuration(duration, fieldName string) *ValidationResult {
result := &ValidationResult{Valid: true}
if duration == "" {
result.AddError(fieldName, "Duration is required", "")
return result
}
// Basic duration format validation (Go duration format)
durationRegex := regexp.MustCompile(`^(\d+(\.\d+)?(ns|us|µs|ms|s|m|h))+$`)
if !durationRegex.MatchString(duration) {
result.AddError(fieldName, "Invalid duration format (use Go duration format like '1h', '30m', '5s')", duration)
return result
}
return result
}
// Helper methods
func (v *Validator) validateSinglePermission(permission string) error {
if permission == "" {
return fmt.Errorf("permission cannot be empty")
}
if len(permission) > 100 {
return fmt.Errorf("permission too long (max 100 characters)")
}
if !permissionRegex.MatchString(permission) {
return fmt.Errorf("permission must start and end with alphanumeric characters and contain only letters, numbers, dots, and underscores")
}
// Validate permission hierarchy (dots separate levels)
parts := strings.Split(permission, ".")
for i, part := range parts {
if part == "" {
return fmt.Errorf("permission level %d is empty", i+1)
}
if len(part) > 50 {
return fmt.Errorf("permission level %d is too long (max 50 characters)", i+1)
}
}
if len(parts) > 5 {
return fmt.Errorf("permission hierarchy too deep (max 5 levels)")
}
return nil
}
func (v *Validator) isPrivateOrLocalhost(host string) bool {
// Remove port if present
if colonIndex := strings.LastIndex(host, ":"); colonIndex != -1 {
host = host[:colonIndex]
}
// Check for localhost variants
localhosts := []string{"localhost", "127.0.0.1", "::1", "0.0.0.0"}
for _, localhost := range localhosts {
if strings.EqualFold(host, localhost) {
return true
}
}
// Check for private IP ranges (simplified)
privateRanges := []string{
"10.", "192.168.", "172.16.", "172.17.", "172.18.", "172.19.",
"172.20.", "172.21.", "172.22.", "172.23.", "172.24.", "172.25.",
"172.26.", "172.27.", "172.28.", "172.29.", "172.30.", "172.31.",
}
for _, privateRange := range privateRanges {
if strings.HasPrefix(host, privateRange) {
return true
}
}
return false
}
// ValidateApplicationRequest validates create/update application requests
func (v *Validator) ValidateApplicationRequest(appID, appLink, callbackURL string, permissions []string) []ValidationError {
var errors []ValidationError
// Validate app ID
if result := v.ValidateAppID(appID); !result.Valid {
errors = append(errors, result.Errors...)
}
// Validate app link URL
if result := v.ValidateURL(appLink, "app_link"); !result.Valid {
errors = append(errors, result.Errors...)
}
// Validate callback URL
if result := v.ValidateURL(callbackURL, "callback_url"); !result.Valid {
errors = append(errors, result.Errors...)
}
// Validate permissions
if result := v.ValidatePermissions(permissions); !result.Valid {
errors = append(errors, result.Errors...)
}
return errors
}