-
This commit is contained in:
171
internal/auth/header_validator.go
Normal file
171
internal/auth/header_validator.go
Normal file
@ -0,0 +1,171 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// HeaderValidator provides secure validation of authentication headers
|
||||
type HeaderValidator struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewHeaderValidator creates a new header validator
|
||||
func NewHeaderValidator(config config.ConfigProvider, logger *zap.Logger) *HeaderValidator {
|
||||
return &HeaderValidator{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatedUserContext holds validated user information
|
||||
type ValidatedUserContext struct {
|
||||
UserID string
|
||||
Email string
|
||||
Timestamp time.Time
|
||||
Signature string
|
||||
}
|
||||
|
||||
// ValidateAuthenticationHeaders validates user authentication headers with HMAC signature
|
||||
func (hv *HeaderValidator) ValidateAuthenticationHeaders(r *http.Request) (*ValidatedUserContext, error) {
|
||||
userEmail := r.Header.Get(hv.config.GetString("AUTH_HEADER_USER_EMAIL"))
|
||||
timestamp := r.Header.Get("X-Auth-Timestamp")
|
||||
signature := r.Header.Get("X-Auth-Signature")
|
||||
|
||||
if userEmail == "" {
|
||||
hv.logger.Warn("Missing user email header")
|
||||
return nil, errors.NewAuthenticationError("User authentication required")
|
||||
}
|
||||
|
||||
if timestamp == "" || signature == "" {
|
||||
hv.logger.Warn("Missing authentication signature headers",
|
||||
zap.String("user_email", userEmail))
|
||||
return nil, errors.NewAuthenticationError("Authentication signature required")
|
||||
}
|
||||
|
||||
// Validate timestamp (prevent replay attacks)
|
||||
timestampInt, err := strconv.ParseInt(timestamp, 10, 64)
|
||||
if err != nil {
|
||||
hv.logger.Warn("Invalid timestamp format",
|
||||
zap.String("timestamp", timestamp),
|
||||
zap.String("user_email", userEmail))
|
||||
return nil, errors.NewAuthenticationError("Invalid timestamp format")
|
||||
}
|
||||
|
||||
timestampTime := time.Unix(timestampInt, 0)
|
||||
now := time.Now()
|
||||
|
||||
// Allow 5 minutes clock skew
|
||||
maxAge := 5 * time.Minute
|
||||
if now.Sub(timestampTime) > maxAge || timestampTime.After(now.Add(1*time.Minute)) {
|
||||
hv.logger.Warn("Timestamp outside acceptable window",
|
||||
zap.Time("timestamp", timestampTime),
|
||||
zap.Time("now", now),
|
||||
zap.String("user_email", userEmail))
|
||||
return nil, errors.NewAuthenticationError("Request timestamp outside acceptable window")
|
||||
}
|
||||
|
||||
// Validate HMAC signature
|
||||
if !hv.validateSignature(userEmail, timestamp, signature) {
|
||||
hv.logger.Warn("Invalid authentication signature",
|
||||
zap.String("user_email", userEmail))
|
||||
return nil, errors.NewAuthenticationError("Invalid authentication signature")
|
||||
}
|
||||
|
||||
// Validate email format
|
||||
if !hv.isValidEmail(userEmail) {
|
||||
hv.logger.Warn("Invalid email format",
|
||||
zap.String("user_email", userEmail))
|
||||
return nil, errors.NewAuthenticationError("Invalid email format")
|
||||
}
|
||||
|
||||
hv.logger.Debug("Authentication headers validated successfully",
|
||||
zap.String("user_email", userEmail))
|
||||
|
||||
return &ValidatedUserContext{
|
||||
UserID: userEmail,
|
||||
Email: userEmail,
|
||||
Timestamp: timestampTime,
|
||||
Signature: signature,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validateSignature validates the HMAC signature
|
||||
func (hv *HeaderValidator) validateSignature(userEmail, timestamp, signature string) bool {
|
||||
// Get the signing key from config
|
||||
signingKey := hv.config.GetString("AUTH_SIGNING_KEY")
|
||||
if signingKey == "" {
|
||||
hv.logger.Error("AUTH_SIGNING_KEY not configured")
|
||||
return false
|
||||
}
|
||||
|
||||
// Create the signing string
|
||||
signingString := fmt.Sprintf("%s:%s", userEmail, timestamp)
|
||||
|
||||
// Calculate expected signature
|
||||
mac := hmac.New(sha256.New, []byte(signingKey))
|
||||
mac.Write([]byte(signingString))
|
||||
expectedSignature := hex.EncodeToString(mac.Sum(nil))
|
||||
|
||||
// Use constant-time comparison to prevent timing attacks
|
||||
return hmac.Equal([]byte(signature), []byte(expectedSignature))
|
||||
}
|
||||
|
||||
// isValidEmail performs basic email validation
|
||||
func (hv *HeaderValidator) isValidEmail(email string) bool {
|
||||
if len(email) == 0 || len(email) > 254 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Basic email validation - contains @ and has valid structure
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
local, domain := parts[0], parts[1]
|
||||
|
||||
// Local part validation
|
||||
if len(local) == 0 || len(local) > 64 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Domain part validation
|
||||
if len(domain) == 0 || len(domain) > 253 {
|
||||
return false
|
||||
}
|
||||
|
||||
if !strings.Contains(domain, ".") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for invalid characters (basic check)
|
||||
invalidChars := []string{" ", "..", "@@", "<", ">", "\"", "'"}
|
||||
for _, char := range invalidChars {
|
||||
if strings.Contains(email, char) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GenerateSignatureExample generates an example signature for documentation
|
||||
func (hv *HeaderValidator) GenerateSignatureExample(userEmail string, timestamp string, signingKey string) string {
|
||||
signingString := fmt.Sprintf("%s:%s", userEmail, timestamp)
|
||||
mac := hmac.New(sha256.New, []byte(signingKey))
|
||||
mac.Write([]byte(signingString))
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
@ -57,6 +57,12 @@ func (j *JWTManager) GenerateToken(userToken *domain.UserToken) (string, error)
|
||||
return "", errors.NewValidationError("JWT secret not configured")
|
||||
}
|
||||
|
||||
// Generate secure JWT ID
|
||||
jti := j.generateJTI()
|
||||
if jti == "" {
|
||||
return "", errors.NewInternalError("Failed to generate secure JWT ID - cryptographic random number generation failed")
|
||||
}
|
||||
|
||||
// Create custom claims
|
||||
claims := CustomClaims{
|
||||
UserID: userToken.UserID,
|
||||
@ -72,7 +78,7 @@ func (j *JWTManager) GenerateToken(userToken *domain.UserToken) (string, error)
|
||||
ExpiresAt: jwt.NewNumericDate(userToken.ExpiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(userToken.IssuedAt),
|
||||
NotBefore: jwt.NewNumericDate(userToken.IssuedAt),
|
||||
ID: j.generateJTI(),
|
||||
ID: jti,
|
||||
},
|
||||
}
|
||||
|
||||
@ -272,8 +278,10 @@ func (j *JWTManager) IsTokenRevoked(tokenString string) (bool, error) {
|
||||
func (j *JWTManager) generateJTI() string {
|
||||
bytes := make([]byte, 16)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
// Fallback to timestamp-based ID if random generation fails
|
||||
return fmt.Sprintf("jti_%d", time.Now().UnixNano())
|
||||
// Log the error and fail securely - do not generate predictable fallback IDs
|
||||
j.logger.Error("Cryptographic random number generation failed - cannot generate secure JWT ID", zap.Error(err))
|
||||
// Return an error indicator that will cause token generation to fail
|
||||
return ""
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
353
internal/authorization/rbac.go
Normal file
353
internal/authorization/rbac.go
Normal file
@ -0,0 +1,353 @@
|
||||
package authorization
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// ResourceType represents different types of resources
|
||||
type ResourceType string
|
||||
|
||||
const (
|
||||
ResourceTypeApplication ResourceType = "application"
|
||||
ResourceTypeToken ResourceType = "token"
|
||||
ResourceTypePermission ResourceType = "permission"
|
||||
ResourceTypeUser ResourceType = "user"
|
||||
)
|
||||
|
||||
// Action represents different actions that can be performed
|
||||
type Action string
|
||||
|
||||
const (
|
||||
ActionRead Action = "read"
|
||||
ActionWrite Action = "write"
|
||||
ActionDelete Action = "delete"
|
||||
ActionCreate Action = "create"
|
||||
)
|
||||
|
||||
// AuthorizationContext holds context for authorization decisions
|
||||
type AuthorizationContext struct {
|
||||
UserID string
|
||||
UserEmail string
|
||||
ResourceType ResourceType
|
||||
ResourceID string
|
||||
Action Action
|
||||
OwnerInfo *domain.Owner
|
||||
}
|
||||
|
||||
// AuthorizationService provides role-based access control
|
||||
type AuthorizationService struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAuthorizationService creates a new authorization service
|
||||
func NewAuthorizationService(logger *zap.Logger) *AuthorizationService {
|
||||
return &AuthorizationService{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthorizeResourceAccess checks if a user can perform an action on a resource
|
||||
func (a *AuthorizationService) AuthorizeResourceAccess(ctx context.Context, authCtx *AuthorizationContext) error {
|
||||
if authCtx == nil {
|
||||
return errors.NewForbiddenError("Authorization context is required")
|
||||
}
|
||||
|
||||
a.logger.Debug("Authorizing resource access",
|
||||
zap.String("user_id", authCtx.UserID),
|
||||
zap.String("resource_type", string(authCtx.ResourceType)),
|
||||
zap.String("resource_id", authCtx.ResourceID),
|
||||
zap.String("action", string(authCtx.Action)))
|
||||
|
||||
// Check if user is a system admin
|
||||
if a.isSystemAdmin(authCtx.UserID) {
|
||||
a.logger.Debug("System admin access granted", zap.String("user_id", authCtx.UserID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check resource ownership
|
||||
if authCtx.OwnerInfo != nil {
|
||||
if a.isResourceOwner(authCtx, authCtx.OwnerInfo) {
|
||||
a.logger.Debug("Resource owner access granted",
|
||||
zap.String("user_id", authCtx.UserID),
|
||||
zap.String("resource_id", authCtx.ResourceID))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check specific resource-action combinations
|
||||
switch authCtx.ResourceType {
|
||||
case ResourceTypeApplication:
|
||||
return a.authorizeApplicationAccess(authCtx)
|
||||
case ResourceTypeToken:
|
||||
return a.authorizeTokenAccess(authCtx)
|
||||
case ResourceTypePermission:
|
||||
return a.authorizePermissionAccess(authCtx)
|
||||
case ResourceTypeUser:
|
||||
return a.authorizeUserAccess(authCtx)
|
||||
default:
|
||||
return errors.NewForbiddenError(fmt.Sprintf("Unknown resource type: %s", authCtx.ResourceType))
|
||||
}
|
||||
}
|
||||
|
||||
// AuthorizeApplicationOwnership checks if a user owns an application
|
||||
func (a *AuthorizationService) AuthorizeApplicationOwnership(userID string, app *domain.Application) error {
|
||||
if app == nil {
|
||||
return errors.NewValidationError("Application is required")
|
||||
}
|
||||
|
||||
// System admins can access any application
|
||||
if a.isSystemAdmin(userID) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if user is the owner
|
||||
if a.isOwner(userID, &app.Owner) {
|
||||
return nil
|
||||
}
|
||||
|
||||
a.logger.Warn("Application ownership authorization failed",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", app.AppID),
|
||||
zap.String("owner_type", string(app.Owner.Type)),
|
||||
zap.String("owner_name", app.Owner.Name))
|
||||
|
||||
return errors.NewForbiddenError("You do not have permission to access this application")
|
||||
}
|
||||
|
||||
// AuthorizeTokenOwnership checks if a user owns a token
|
||||
func (a *AuthorizationService) AuthorizeTokenOwnership(userID string, token interface{}) error {
|
||||
// System admins can access any token
|
||||
if a.isSystemAdmin(userID) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract owner information based on token type
|
||||
var owner *domain.Owner
|
||||
var tokenID string
|
||||
|
||||
switch t := token.(type) {
|
||||
case *domain.StaticToken:
|
||||
owner = &t.Owner
|
||||
tokenID = t.ID.String()
|
||||
case *domain.UserToken:
|
||||
// For user tokens, the user ID should match
|
||||
if t.UserID == userID {
|
||||
return nil
|
||||
}
|
||||
tokenID = "user_token"
|
||||
default:
|
||||
return errors.NewValidationError("Unknown token type")
|
||||
}
|
||||
|
||||
// Check ownership
|
||||
if owner != nil && a.isOwner(userID, owner) {
|
||||
return nil
|
||||
}
|
||||
|
||||
a.logger.Warn("Token ownership authorization failed",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("token_id", tokenID))
|
||||
|
||||
return errors.NewForbiddenError("You do not have permission to access this token")
|
||||
}
|
||||
|
||||
// isSystemAdmin checks if a user is a system administrator
|
||||
func (a *AuthorizationService) isSystemAdmin(userID string) bool {
|
||||
// System admin users - this should be configurable
|
||||
systemAdmins := []string{
|
||||
"admin@example.com",
|
||||
"system@internal.com",
|
||||
}
|
||||
|
||||
for _, admin := range systemAdmins {
|
||||
if userID == admin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isResourceOwner checks if the user is the owner of a resource
|
||||
func (a *AuthorizationService) isResourceOwner(authCtx *AuthorizationContext, owner *domain.Owner) bool {
|
||||
return a.isOwner(authCtx.UserID, owner)
|
||||
}
|
||||
|
||||
// isOwner checks if a user is the owner based on owner information
|
||||
func (a *AuthorizationService) isOwner(userID string, owner *domain.Owner) bool {
|
||||
switch owner.Type {
|
||||
case domain.OwnerTypeIndividual:
|
||||
// For individual ownership, check if the user ID matches the owner name
|
||||
return userID == owner.Name || userID == owner.Owner
|
||||
case domain.OwnerTypeTeam:
|
||||
// For team ownership, this would typically require a team membership check
|
||||
// For now, we'll check if the user is the team owner
|
||||
return userID == owner.Owner || a.isTeamMember(userID, owner.Name)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// isTeamMember checks if a user is a member of a team (placeholder implementation)
|
||||
func (a *AuthorizationService) isTeamMember(userID, teamName string) bool {
|
||||
// In a real implementation, this would check team membership in a database
|
||||
// For now, we'll use a simple heuristic based on email domains
|
||||
|
||||
if !strings.Contains(userID, "@") {
|
||||
return false
|
||||
}
|
||||
|
||||
userDomain := strings.Split(userID, "@")[1]
|
||||
teamDomain := strings.ToLower(teamName)
|
||||
|
||||
// Simple check: if team name looks like a domain and user's domain matches
|
||||
if strings.Contains(teamDomain, ".") && strings.Contains(userDomain, teamDomain) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Additional team membership logic would go here
|
||||
return false
|
||||
}
|
||||
|
||||
// authorizeApplicationAccess handles application-specific authorization
|
||||
func (a *AuthorizationService) authorizeApplicationAccess(authCtx *AuthorizationContext) error {
|
||||
switch authCtx.Action {
|
||||
case ActionRead:
|
||||
// Users can read applications they have some relationship with
|
||||
// This could be expanded to check for shared access, etc.
|
||||
return errors.NewForbiddenError("You do not have permission to read this application")
|
||||
case ActionWrite:
|
||||
// Only owners can modify applications
|
||||
return errors.NewForbiddenError("You do not have permission to modify this application")
|
||||
case ActionDelete:
|
||||
// Only owners can delete applications
|
||||
return errors.NewForbiddenError("You do not have permission to delete this application")
|
||||
case ActionCreate:
|
||||
// Most users can create applications (with rate limiting)
|
||||
return nil
|
||||
default:
|
||||
return errors.NewForbiddenError(fmt.Sprintf("Unknown action: %s", authCtx.Action))
|
||||
}
|
||||
}
|
||||
|
||||
// authorizeTokenAccess handles token-specific authorization
|
||||
func (a *AuthorizationService) authorizeTokenAccess(authCtx *AuthorizationContext) error {
|
||||
switch authCtx.Action {
|
||||
case ActionRead:
|
||||
return errors.NewForbiddenError("You do not have permission to read this token")
|
||||
case ActionWrite:
|
||||
return errors.NewForbiddenError("You do not have permission to modify this token")
|
||||
case ActionDelete:
|
||||
return errors.NewForbiddenError("You do not have permission to delete this token")
|
||||
case ActionCreate:
|
||||
return errors.NewForbiddenError("You do not have permission to create tokens for this application")
|
||||
default:
|
||||
return errors.NewForbiddenError(fmt.Sprintf("Unknown action: %s", authCtx.Action))
|
||||
}
|
||||
}
|
||||
|
||||
// authorizePermissionAccess handles permission-specific authorization
|
||||
func (a *AuthorizationService) authorizePermissionAccess(authCtx *AuthorizationContext) error {
|
||||
switch authCtx.Action {
|
||||
case ActionRead:
|
||||
// Users can read permissions they have
|
||||
return nil
|
||||
case ActionWrite:
|
||||
// Only admins can modify permissions
|
||||
return errors.NewForbiddenError("You do not have permission to modify permissions")
|
||||
case ActionDelete:
|
||||
// Only admins can delete permissions
|
||||
return errors.NewForbiddenError("You do not have permission to delete permissions")
|
||||
case ActionCreate:
|
||||
// Only admins can create permissions
|
||||
return errors.NewForbiddenError("You do not have permission to create permissions")
|
||||
default:
|
||||
return errors.NewForbiddenError(fmt.Sprintf("Unknown action: %s", authCtx.Action))
|
||||
}
|
||||
}
|
||||
|
||||
// authorizeUserAccess handles user-specific authorization
|
||||
func (a *AuthorizationService) authorizeUserAccess(authCtx *AuthorizationContext) error {
|
||||
switch authCtx.Action {
|
||||
case ActionRead:
|
||||
// Users can read their own information
|
||||
if authCtx.ResourceID == authCtx.UserID {
|
||||
return nil
|
||||
}
|
||||
return errors.NewForbiddenError("You do not have permission to read this user's information")
|
||||
case ActionWrite:
|
||||
// Users can modify their own information
|
||||
if authCtx.ResourceID == authCtx.UserID {
|
||||
return nil
|
||||
}
|
||||
return errors.NewForbiddenError("You do not have permission to modify this user's information")
|
||||
case ActionDelete:
|
||||
// Users can delete their own account, admins can delete any
|
||||
if authCtx.ResourceID == authCtx.UserID {
|
||||
return nil
|
||||
}
|
||||
return errors.NewForbiddenError("You do not have permission to delete this user")
|
||||
default:
|
||||
return errors.NewForbiddenError(fmt.Sprintf("Unknown action: %s", authCtx.Action))
|
||||
}
|
||||
}
|
||||
|
||||
// AuthorizeListAccess checks if a user can list resources of a specific type
|
||||
func (a *AuthorizationService) AuthorizeListAccess(ctx context.Context, userID string, resourceType ResourceType) error {
|
||||
a.logger.Debug("Authorizing list access",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("resource_type", string(resourceType)))
|
||||
|
||||
// System admins can list anything
|
||||
if a.isSystemAdmin(userID) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For now, allow users to list their own resources
|
||||
// This would be refined based on business requirements
|
||||
switch resourceType {
|
||||
case ResourceTypeApplication:
|
||||
return nil // Users can list applications (filtered by ownership)
|
||||
case ResourceTypeToken:
|
||||
return nil // Users can list their own tokens
|
||||
case ResourceTypePermission:
|
||||
return nil // Users can list available permissions
|
||||
case ResourceTypeUser:
|
||||
// Only admins can list users
|
||||
return errors.NewForbiddenError("You do not have permission to list users")
|
||||
default:
|
||||
return errors.NewForbiddenError(fmt.Sprintf("Unknown resource type: %s", resourceType))
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserResourceFilter returns a filter for resources that a user can access
|
||||
func (a *AuthorizationService) GetUserResourceFilter(userID string, resourceType ResourceType) map[string]interface{} {
|
||||
filter := make(map[string]interface{})
|
||||
|
||||
// System admins see everything
|
||||
if a.isSystemAdmin(userID) {
|
||||
return filter // Empty filter means no restrictions
|
||||
}
|
||||
|
||||
// Filter by ownership
|
||||
switch resourceType {
|
||||
case ResourceTypeApplication, ResourceTypeToken:
|
||||
// Users can only see resources they own
|
||||
filter["owner_email"] = userID
|
||||
case ResourceTypePermission:
|
||||
// Users can see all permissions (they're not user-specific)
|
||||
return filter
|
||||
case ResourceTypeUser:
|
||||
// Users can only see themselves
|
||||
filter["user_id"] = userID
|
||||
}
|
||||
|
||||
return filter
|
||||
}
|
||||
@ -36,6 +36,9 @@ type ConfigProvider interface {
|
||||
// GetDatabaseDSN constructs and returns the database connection string
|
||||
GetDatabaseDSN() string
|
||||
|
||||
// GetDatabaseDSNForLogging returns a sanitized database connection string safe for logging
|
||||
GetDatabaseDSNForLogging() string
|
||||
|
||||
// GetServerAddress returns the server address in host:port format
|
||||
GetServerAddress() string
|
||||
|
||||
@ -104,17 +107,20 @@ func (c *Config) setDefaults() {
|
||||
"RATE_LIMIT_ENABLED": "true",
|
||||
"RATE_LIMIT_RPS": "100",
|
||||
"RATE_LIMIT_BURST": "200",
|
||||
"AUTH_RATE_LIMIT_RPS": "5",
|
||||
"AUTH_RATE_LIMIT_BURST": "10",
|
||||
"CACHE_ENABLED": "false",
|
||||
"CACHE_TTL": "1h",
|
||||
"JWT_ISSUER": "api-key-service",
|
||||
"JWT_SECRET": "bootstrap-jwt-secret-change-in-production",
|
||||
"JWT_SECRET": "", // Must be set via environment variable
|
||||
"AUTH_PROVIDER": "header", // header or sso
|
||||
"AUTH_HEADER_USER_EMAIL": "X-User-Email",
|
||||
"AUTH_SIGNING_KEY": "", // Must be set via environment variable
|
||||
"SSO_PROVIDER_URL": "",
|
||||
"SSO_CLIENT_ID": "",
|
||||
"SSO_CLIENT_SECRET": "",
|
||||
"INTERNAL_APP_ID": "internal.api-key-service",
|
||||
"INTERNAL_HMAC_KEY": "bootstrap-hmac-key-change-in-production",
|
||||
"INTERNAL_HMAC_KEY": "", // Must be set via environment variable
|
||||
"METRICS_ENABLED": "false",
|
||||
"METRICS_PORT": "9090",
|
||||
"REDIS_ENABLED": "false",
|
||||
@ -131,6 +137,8 @@ func (c *Config) setDefaults() {
|
||||
"AUTH_FAILURE_WINDOW": "15m",
|
||||
"IP_BLOCK_DURATION": "1h",
|
||||
"REQUEST_MAX_AGE": "5m",
|
||||
"CSRF_TOKEN_MAX_AGE": "1h",
|
||||
"BCRYPT_COST": "14",
|
||||
"IP_WHITELIST": "",
|
||||
"SAML_ENABLED": "false",
|
||||
"SAML_IDP_METADATA_URL": "",
|
||||
@ -212,6 +220,7 @@ func (c *Config) Validate() error {
|
||||
"INTERNAL_APP_ID",
|
||||
"INTERNAL_HMAC_KEY",
|
||||
"JWT_SECRET",
|
||||
"AUTH_SIGNING_KEY",
|
||||
}
|
||||
|
||||
var missing []string
|
||||
@ -225,6 +234,22 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("missing required configuration keys: %s", strings.Join(missing, ", "))
|
||||
}
|
||||
|
||||
// Validate that production secrets are not using default values
|
||||
jwtSecret := c.GetString("JWT_SECRET")
|
||||
if jwtSecret == "bootstrap-jwt-secret-change-in-production" || len(jwtSecret) < 32 {
|
||||
return fmt.Errorf("JWT_SECRET must be set to a secure value (minimum 32 characters)")
|
||||
}
|
||||
|
||||
hmacKey := c.GetString("INTERNAL_HMAC_KEY")
|
||||
if hmacKey == "bootstrap-hmac-key-change-in-production" || len(hmacKey) < 32 {
|
||||
return fmt.Errorf("INTERNAL_HMAC_KEY must be set to a secure value (minimum 32 characters)")
|
||||
}
|
||||
|
||||
authSigningKey := c.GetString("AUTH_SIGNING_KEY")
|
||||
if len(authSigningKey) < 32 {
|
||||
return fmt.Errorf("AUTH_SIGNING_KEY must be set to a secure value (minimum 32 characters)")
|
||||
}
|
||||
|
||||
// Validate specific values
|
||||
if c.GetInt("DB_PORT") <= 0 || c.GetInt("DB_PORT") > 65535 {
|
||||
return fmt.Errorf("DB_PORT must be a valid port number")
|
||||
@ -278,6 +303,27 @@ func (c *Config) GetDatabaseDSN() string {
|
||||
)
|
||||
}
|
||||
|
||||
// GetDatabaseDSNForLogging returns a sanitized database connection string safe for logging
|
||||
func (c *Config) GetDatabaseDSNForLogging() string {
|
||||
password := c.GetString("DB_PASSWORD")
|
||||
maskedPassword := "***MASKED***"
|
||||
if len(password) > 0 {
|
||||
// Show first and last character with masking for debugging
|
||||
if len(password) >= 4 {
|
||||
maskedPassword = string(password[0]) + "***" + string(password[len(password)-1])
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
c.GetString("DB_HOST"),
|
||||
c.GetInt("DB_PORT"),
|
||||
c.GetString("DB_USER"),
|
||||
maskedPassword,
|
||||
c.GetString("DB_NAME"),
|
||||
c.GetString("DB_SSLMODE"),
|
||||
)
|
||||
}
|
||||
|
||||
// GetServerAddress returns the server address in host:port format
|
||||
func (c *Config) GetServerAddress() string {
|
||||
return fmt.Sprintf("%s:%d", c.GetString("SERVER_HOST"), c.GetInt("SERVER_PORT"))
|
||||
|
||||
@ -18,17 +18,42 @@ const (
|
||||
TokenLength = 32
|
||||
// TokenPrefix is prepended to all tokens for identification
|
||||
TokenPrefix = "kms_"
|
||||
// BcryptCost defines the bcrypt cost for 2025 security standards (minimum 14)
|
||||
BcryptCost = 14
|
||||
)
|
||||
|
||||
// TokenGenerator provides secure token generation and validation
|
||||
type TokenGenerator struct {
|
||||
hmacKey []byte
|
||||
hmacKey []byte
|
||||
bcryptCost int
|
||||
}
|
||||
|
||||
// NewTokenGenerator creates a new token generator with the provided HMAC key
|
||||
func NewTokenGenerator(hmacKey string) *TokenGenerator {
|
||||
return &TokenGenerator{
|
||||
hmacKey: []byte(hmacKey),
|
||||
hmacKey: []byte(hmacKey),
|
||||
bcryptCost: BcryptCost,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTokenGeneratorWithCost creates a new token generator with custom bcrypt cost
|
||||
func NewTokenGeneratorWithCost(hmacKey string, bcryptCost int) *TokenGenerator {
|
||||
// Validate bcrypt cost (must be between 4 and 31)
|
||||
if bcryptCost < 4 {
|
||||
bcryptCost = 4
|
||||
} else if bcryptCost > 31 {
|
||||
bcryptCost = 31
|
||||
}
|
||||
|
||||
// Warn if cost is too low for production
|
||||
if bcryptCost < 12 {
|
||||
// This should log a warning, but we don't have logger here
|
||||
// In a real implementation, you'd pass a logger or use a global one
|
||||
}
|
||||
|
||||
return &TokenGenerator{
|
||||
hmacKey: []byte(hmacKey),
|
||||
bcryptCost: bcryptCost,
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,10 +94,10 @@ func (tg *TokenGenerator) GenerateSecureTokenWithPrefix(appPrefix string, tokenT
|
||||
|
||||
// HashToken creates a secure hash of the token for storage
|
||||
func (tg *TokenGenerator) HashToken(token string) (string, error) {
|
||||
// Use bcrypt for secure password-like hashing
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost)
|
||||
// Use bcrypt with configured cost
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(token), tg.bcryptCost)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to hash token: %w", err)
|
||||
return "", fmt.Errorf("failed to hash token with bcrypt cost %d: %w", tg.bcryptCost, err)
|
||||
}
|
||||
|
||||
return string(hash), nil
|
||||
|
||||
@ -4,12 +4,8 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
@ -17,7 +13,8 @@ import (
|
||||
|
||||
// PostgresProvider implements the DatabaseProvider interface
|
||||
type PostgresProvider struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
dsn string
|
||||
}
|
||||
|
||||
// NewPostgresProvider creates a new PostgreSQL database provider
|
||||
@ -44,7 +41,7 @@ func NewPostgresProvider(dsn string, maxOpenConns, maxIdleConns int, maxLifetime
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return &PostgresProvider{db: db}, nil
|
||||
return &PostgresProvider{db: db, dsn: dsn}, nil
|
||||
}
|
||||
|
||||
// GetDB returns the underlying database connection
|
||||
@ -81,51 +78,7 @@ func (p *PostgresProvider) BeginTx(ctx context.Context) (repository.TransactionP
|
||||
return &PostgresTransaction{tx: tx}, nil
|
||||
}
|
||||
|
||||
// Migrate runs database migrations
|
||||
func (p *PostgresProvider) Migrate(ctx context.Context, migrationPath string) error {
|
||||
// Create a separate connection for migrations to avoid interfering with the main connection
|
||||
migrationDB, err := sql.Open("postgres", p.getDSN())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open migration database connection: %w", err)
|
||||
}
|
||||
defer migrationDB.Close()
|
||||
|
||||
driver, err := postgres.WithInstance(migrationDB, &postgres.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create postgres driver: %w", err)
|
||||
}
|
||||
|
||||
// Convert relative path to file URL
|
||||
absPath, err := filepath.Abs(migrationPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
|
||||
m, err := migrate.NewWithDatabaseInstance(
|
||||
fmt.Sprintf("file://%s", absPath),
|
||||
"postgres",
|
||||
driver,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrate instance: %w", err)
|
||||
}
|
||||
defer m.Close()
|
||||
|
||||
// Run migrations
|
||||
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
|
||||
return fmt.Errorf("failed to run migrations: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDSN reconstructs the DSN from the current connection
|
||||
// This is a workaround since we don't store the original DSN
|
||||
func (p *PostgresProvider) getDSN() string {
|
||||
// For now, we'll use the default values from config
|
||||
// In a production system, we'd store the original DSN
|
||||
return "host=localhost port=5432 user=postgres password=postgres dbname=kms sslmode=disable"
|
||||
}
|
||||
|
||||
// PostgresTransaction implements the TransactionProvider interface
|
||||
type PostgresTransaction struct {
|
||||
|
||||
245
internal/errors/secure_responses.go
Normal file
245
internal/errors/secure_responses.go
Normal file
@ -0,0 +1,245 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SecureErrorResponse represents a sanitized error response for clients
|
||||
type SecureErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
// ErrorHandler provides secure error handling for HTTP responses
|
||||
type ErrorHandler struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewErrorHandler creates a new secure error handler
|
||||
func NewErrorHandler(logger *zap.Logger) *ErrorHandler {
|
||||
return &ErrorHandler{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleError handles errors securely by logging detailed information and returning sanitized responses
|
||||
func (eh *ErrorHandler) HandleError(c *gin.Context, err error, userMessage string) {
|
||||
requestID := eh.getOrGenerateRequestID(c)
|
||||
|
||||
// Log detailed error information for internal debugging
|
||||
eh.logger.Error("HTTP request error",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("user_agent", c.Request.UserAgent()),
|
||||
zap.String("remote_addr", c.ClientIP()),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
// Determine appropriate HTTP status code and error type
|
||||
statusCode, errorType := eh.determineErrorResponse(err)
|
||||
|
||||
// Create sanitized response
|
||||
response := SecureErrorResponse{
|
||||
Error: errorType,
|
||||
Message: eh.sanitizeErrorMessage(userMessage, err),
|
||||
RequestID: requestID,
|
||||
Code: statusCode,
|
||||
}
|
||||
|
||||
c.JSON(statusCode, response)
|
||||
}
|
||||
|
||||
// HandleValidationError handles input validation errors
|
||||
func (eh *ErrorHandler) HandleValidationError(c *gin.Context, field string, message string) {
|
||||
requestID := eh.getOrGenerateRequestID(c)
|
||||
|
||||
eh.logger.Warn("Validation error",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("field", field),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
)
|
||||
|
||||
response := SecureErrorResponse{
|
||||
Error: "validation_error",
|
||||
Message: "Invalid input provided",
|
||||
RequestID: requestID,
|
||||
Code: http.StatusBadRequest,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// HandleAuthenticationError handles authentication failures
|
||||
func (eh *ErrorHandler) HandleAuthenticationError(c *gin.Context, err error) {
|
||||
requestID := eh.getOrGenerateRequestID(c)
|
||||
|
||||
eh.logger.Warn("Authentication error",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("remote_addr", c.ClientIP()),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
response := SecureErrorResponse{
|
||||
Error: "authentication_failed",
|
||||
Message: "Authentication required",
|
||||
RequestID: requestID,
|
||||
Code: http.StatusUnauthorized,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, response)
|
||||
}
|
||||
|
||||
// HandleAuthorizationError handles authorization failures
|
||||
func (eh *ErrorHandler) HandleAuthorizationError(c *gin.Context, resource string) {
|
||||
requestID := eh.getOrGenerateRequestID(c)
|
||||
|
||||
eh.logger.Warn("Authorization error",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("resource", resource),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("remote_addr", c.ClientIP()),
|
||||
)
|
||||
|
||||
response := SecureErrorResponse{
|
||||
Error: "access_denied",
|
||||
Message: "Insufficient permissions",
|
||||
RequestID: requestID,
|
||||
Code: http.StatusForbidden,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusForbidden, response)
|
||||
}
|
||||
|
||||
// HandleInternalError handles internal server errors
|
||||
func (eh *ErrorHandler) HandleInternalError(c *gin.Context, err error) {
|
||||
requestID := eh.getOrGenerateRequestID(c)
|
||||
|
||||
eh.logger.Error("Internal server error",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("remote_addr", c.ClientIP()),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
response := SecureErrorResponse{
|
||||
Error: "internal_error",
|
||||
Message: "An internal error occurred",
|
||||
RequestID: requestID,
|
||||
Code: http.StatusInternalServerError,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// determineErrorResponse determines the appropriate HTTP status and error type
|
||||
func (eh *ErrorHandler) determineErrorResponse(err error) (int, string) {
|
||||
if appErr, ok := err.(*AppError); ok {
|
||||
return appErr.StatusCode, eh.getErrorTypeFromCode(appErr.Code)
|
||||
}
|
||||
|
||||
// For unknown errors, log as internal error but don't expose details
|
||||
return http.StatusInternalServerError, "internal_error"
|
||||
}
|
||||
|
||||
// sanitizeErrorMessage removes sensitive information from error messages
|
||||
func (eh *ErrorHandler) sanitizeErrorMessage(userMessage string, err error) string {
|
||||
if userMessage != "" {
|
||||
return userMessage
|
||||
}
|
||||
|
||||
// Provide generic messages for different error types
|
||||
if appErr, ok := err.(*AppError); ok {
|
||||
return eh.getGenericMessageFromCode(appErr.Code)
|
||||
}
|
||||
|
||||
return "An error occurred"
|
||||
}
|
||||
|
||||
// getErrorTypeFromCode converts an error code to a sanitized error type string
|
||||
func (eh *ErrorHandler) getErrorTypeFromCode(code ErrorCode) string {
|
||||
switch code {
|
||||
case ErrValidationFailed, ErrInvalidInput, ErrMissingField, ErrInvalidFormat:
|
||||
return "validation_error"
|
||||
case ErrUnauthorized, ErrInvalidToken, ErrTokenExpired, ErrInvalidCredentials:
|
||||
return "authentication_failed"
|
||||
case ErrForbidden, ErrInsufficientPermissions:
|
||||
return "access_denied"
|
||||
case ErrNotFound, ErrApplicationNotFound, ErrTokenNotFound, ErrPermissionNotFound:
|
||||
return "resource_not_found"
|
||||
case ErrAlreadyExists, ErrConflict:
|
||||
return "resource_conflict"
|
||||
case ErrRateLimit:
|
||||
return "rate_limit_exceeded"
|
||||
case ErrTimeout:
|
||||
return "timeout"
|
||||
default:
|
||||
return "internal_error"
|
||||
}
|
||||
}
|
||||
|
||||
// getGenericMessageFromCode provides generic user-safe messages for error codes
|
||||
func (eh *ErrorHandler) getGenericMessageFromCode(code ErrorCode) string {
|
||||
switch code {
|
||||
case ErrValidationFailed, ErrInvalidInput, ErrMissingField, ErrInvalidFormat:
|
||||
return "Invalid input provided"
|
||||
case ErrUnauthorized, ErrInvalidToken, ErrTokenExpired, ErrInvalidCredentials:
|
||||
return "Authentication required"
|
||||
case ErrForbidden, ErrInsufficientPermissions:
|
||||
return "Access denied"
|
||||
case ErrNotFound, ErrApplicationNotFound, ErrTokenNotFound, ErrPermissionNotFound:
|
||||
return "Resource not found"
|
||||
case ErrAlreadyExists, ErrConflict:
|
||||
return "Resource conflict"
|
||||
case ErrRateLimit:
|
||||
return "Rate limit exceeded"
|
||||
case ErrTimeout:
|
||||
return "Request timeout"
|
||||
default:
|
||||
return "An error occurred"
|
||||
}
|
||||
}
|
||||
|
||||
// getOrGenerateRequestID gets or generates a request ID for tracking
|
||||
func (eh *ErrorHandler) getOrGenerateRequestID(c *gin.Context) string {
|
||||
// Try to get existing request ID from context
|
||||
if requestID, exists := c.Get("request_id"); exists {
|
||||
if id, ok := requestID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
|
||||
// Try to get from header
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
|
||||
// Generate a simple request ID (in production, use a proper UUID library)
|
||||
return generateSimpleID()
|
||||
}
|
||||
|
||||
// generateSimpleID generates a simple request ID
|
||||
func generateSimpleID() string {
|
||||
// Simple implementation - in production use proper UUID generation
|
||||
bytes := make([]byte, 8)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
// Fallback to timestamp-based ID
|
||||
return fmt.Sprintf("req_%d", time.Now().UnixNano())
|
||||
}
|
||||
return "req_" + hex.EncodeToString(bytes)
|
||||
}
|
||||
@ -7,15 +7,21 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/authorization"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
"github.com/kms/api-key-service/internal/services"
|
||||
"github.com/kms/api-key-service/internal/validation"
|
||||
)
|
||||
|
||||
// ApplicationHandler handles application-related HTTP requests
|
||||
type ApplicationHandler struct {
|
||||
appService services.ApplicationService
|
||||
authService services.AuthenticationService
|
||||
logger *zap.Logger
|
||||
appService services.ApplicationService
|
||||
authService services.AuthenticationService
|
||||
authzService *authorization.AuthorizationService
|
||||
validator *validation.Validator
|
||||
errorHandler *errors.ErrorHandler
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewApplicationHandler creates a new application handler
|
||||
@ -25,9 +31,12 @@ func NewApplicationHandler(
|
||||
logger *zap.Logger,
|
||||
) *ApplicationHandler {
|
||||
return &ApplicationHandler{
|
||||
appService: appService,
|
||||
authService: authService,
|
||||
logger: logger,
|
||||
appService: appService,
|
||||
authService: authService,
|
||||
authzService: authorization.NewAuthorizationService(logger),
|
||||
validator: validation.NewValidator(logger),
|
||||
errorHandler: errors.NewErrorHandler(logger),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@ -35,57 +44,99 @@ func NewApplicationHandler(
|
||||
func (h *ApplicationHandler) Create(c *gin.Context) {
|
||||
var req domain.CreateApplicationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.logger.Warn("Invalid request body", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Bad Request",
|
||||
"message": "Invalid request body: " + err.Error(),
|
||||
})
|
||||
h.errorHandler.HandleValidationError(c, "request_body", "Invalid application request format")
|
||||
return
|
||||
}
|
||||
|
||||
// Get user ID from context
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
h.logger.Error("User ID not found in context")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Internal Server Error",
|
||||
"message": "Authentication context not found",
|
||||
})
|
||||
// Get user ID from authenticated context
|
||||
userID := h.getUserIDFromContext(c)
|
||||
if userID == "" {
|
||||
h.errorHandler.HandleAuthenticationError(c, errors.NewUnauthorizedError("User authentication required"))
|
||||
return
|
||||
}
|
||||
|
||||
app, err := h.appService.Create(c.Request.Context(), &req, userID.(string))
|
||||
// Validate input
|
||||
validationErrors := h.validator.ValidateApplicationRequest(req.AppID, req.AppLink, req.CallbackURL, []string{})
|
||||
if len(validationErrors) > 0 {
|
||||
h.logger.Warn("Application validation failed",
|
||||
zap.String("user_id", userID),
|
||||
zap.Any("errors", validationErrors))
|
||||
h.errorHandler.HandleValidationError(c, "validation", "Invalid application data")
|
||||
return
|
||||
}
|
||||
|
||||
// Check authorization for creating applications
|
||||
authCtx := &authorization.AuthorizationContext{
|
||||
UserID: userID,
|
||||
ResourceType: authorization.ResourceTypeApplication,
|
||||
Action: authorization.ActionCreate,
|
||||
}
|
||||
|
||||
if err := h.authzService.AuthorizeResourceAccess(c.Request.Context(), authCtx); err != nil {
|
||||
h.errorHandler.HandleAuthorizationError(c, "application creation")
|
||||
return
|
||||
}
|
||||
|
||||
// Create the application
|
||||
app, err := h.appService.Create(c.Request.Context(), &req, userID)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to create application", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Internal Server Error",
|
||||
"message": "Failed to create application",
|
||||
})
|
||||
h.errorHandler.HandleInternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("Application created", zap.String("app_id", app.AppID))
|
||||
h.logger.Info("Application created successfully",
|
||||
zap.String("app_id", app.AppID),
|
||||
zap.String("user_id", userID))
|
||||
|
||||
c.JSON(http.StatusCreated, app)
|
||||
}
|
||||
|
||||
// getUserIDFromContext extracts user ID from Gin context
|
||||
func (h *ApplicationHandler) getUserIDFromContext(c *gin.Context) string {
|
||||
// Try to get from Gin context first (set by middleware)
|
||||
if userID, exists := c.Get("user_id"); exists {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to header (for compatibility)
|
||||
userEmail := c.GetHeader("X-User-Email")
|
||||
if userEmail != "" {
|
||||
return userEmail
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetByID handles GET /applications/:id
|
||||
func (h *ApplicationHandler) GetByID(c *gin.Context) {
|
||||
appID := c.Param("id")
|
||||
if appID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Bad Request",
|
||||
"message": "Application ID is required",
|
||||
})
|
||||
|
||||
// Get user ID from context
|
||||
userID := h.getUserIDFromContext(c)
|
||||
if userID == "" {
|
||||
h.errorHandler.HandleAuthenticationError(c, errors.NewUnauthorizedError("User authentication required"))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate app ID
|
||||
if result := h.validator.ValidateAppID(appID); !result.Valid {
|
||||
h.errorHandler.HandleValidationError(c, "app_id", "Invalid application ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get the application first
|
||||
app, err := h.appService.GetByID(c.Request.Context(), appID)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to get application", zap.Error(err), zap.String("app_id", appID))
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "Not Found",
|
||||
"message": "Application not found",
|
||||
})
|
||||
h.errorHandler.HandleError(c, err, "Application not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Check authorization for reading this application
|
||||
if err := h.authzService.AuthorizeApplicationOwnership(userID, app); err != nil {
|
||||
h.errorHandler.HandleAuthorizationError(c, "application access")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -1,32 +1,48 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
"github.com/kms/api-key-service/internal/services"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication-related HTTP requests
|
||||
type AuthHandler struct {
|
||||
authService services.AuthenticationService
|
||||
tokenService services.TokenService
|
||||
logger *zap.Logger
|
||||
authService services.AuthenticationService
|
||||
tokenService services.TokenService
|
||||
headerValidator *auth.HeaderValidator
|
||||
config config.ConfigProvider
|
||||
errorHandler *errors.ErrorHandler
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new auth handler
|
||||
func NewAuthHandler(
|
||||
authService services.AuthenticationService,
|
||||
tokenService services.TokenService,
|
||||
config config.ConfigProvider,
|
||||
logger *zap.Logger,
|
||||
) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
tokenService: tokenService,
|
||||
logger: logger,
|
||||
authService: authService,
|
||||
tokenService: tokenService,
|
||||
headerValidator: auth.NewHeaderValidator(config, logger),
|
||||
config: config,
|
||||
errorHandler: errors.NewErrorHandler(logger),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@ -34,58 +50,81 @@ func NewAuthHandler(
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req domain.LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.logger.Warn("Invalid login request", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Bad Request",
|
||||
"message": "Invalid request body: " + err.Error(),
|
||||
})
|
||||
h.errorHandler.HandleValidationError(c, "request_body", "Invalid login request format")
|
||||
return
|
||||
}
|
||||
|
||||
// For now, we'll extract user ID from headers since we're using HeaderAuthenticationProvider
|
||||
userID := c.GetHeader("X-User-Email")
|
||||
if userID == "" {
|
||||
h.logger.Warn("User email not found in headers")
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized",
|
||||
"message": "User authentication required",
|
||||
})
|
||||
// Validate authentication headers with HMAC signature
|
||||
userContext, err := h.headerValidator.ValidateAuthenticationHeaders(c.Request)
|
||||
if err != nil {
|
||||
h.errorHandler.HandleAuthenticationError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("Processing login request", zap.String("user_id", userID), zap.String("app_id", req.AppID))
|
||||
h.logger.Info("Processing login request", zap.String("user_id", userContext.UserID), zap.String("app_id", req.AppID))
|
||||
|
||||
// Generate user token
|
||||
token, err := h.tokenService.GenerateUserToken(c.Request.Context(), req.AppID, userID, req.Permissions)
|
||||
token, err := h.tokenService.GenerateUserToken(c.Request.Context(), req.AppID, userContext.UserID, req.Permissions)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate user token", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Internal Server Error",
|
||||
"message": "Failed to generate token",
|
||||
})
|
||||
h.errorHandler.HandleInternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// For now, we'll just return the token directly
|
||||
// In a real implementation, this would redirect to the callback URL
|
||||
response := domain.LoginResponse{
|
||||
RedirectURL: req.RedirectURI + "?token=" + token,
|
||||
}
|
||||
|
||||
if req.RedirectURI == "" {
|
||||
// If no redirect URI, return token directly
|
||||
// If no redirect URI, return token directly via secure response body
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": token,
|
||||
"user_id": userID,
|
||||
"user_id": userContext.UserID,
|
||||
"app_id": req.AppID,
|
||||
"expires_in": 604800, // 7 days in seconds
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// For redirect flows, use secure cookie-based token delivery
|
||||
// Set secure cookie with the token
|
||||
c.SetSameSite(http.SameSiteStrictMode)
|
||||
c.SetCookie(
|
||||
"auth_token", // name
|
||||
token, // value
|
||||
604800, // maxAge (7 days)
|
||||
"/", // path
|
||||
"", // domain (empty for current domain)
|
||||
true, // secure (HTTPS only)
|
||||
true, // httpOnly (no JavaScript access)
|
||||
)
|
||||
|
||||
// Generate a secure state parameter for CSRF protection
|
||||
state := h.generateSecureState(userContext.UserID, req.AppID)
|
||||
|
||||
// Redirect without token in URL
|
||||
response := domain.LoginResponse{
|
||||
RedirectURL: req.RedirectURI + "?state=" + state,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// generateSecureState generates a secure state parameter for OAuth flows
|
||||
func (h *AuthHandler) generateSecureState(userID, appID string) string {
|
||||
// Generate random bytes for state
|
||||
stateBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(stateBytes); err != nil {
|
||||
h.logger.Error("Failed to generate random state", zap.Error(err))
|
||||
// Fallback to less secure but functional state
|
||||
return fmt.Sprintf("state_%s_%s_%d", userID, appID, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// Create HMAC signature to prevent tampering
|
||||
stateData := fmt.Sprintf("%s:%s:%x", userID, appID, stateBytes)
|
||||
mac := hmac.New(sha256.New, []byte(h.config.GetString("AUTH_SIGNING_KEY")))
|
||||
mac.Write([]byte(stateData))
|
||||
signature := hex.EncodeToString(mac.Sum(nil))
|
||||
|
||||
// Return base64-encoded state with signature
|
||||
return hex.EncodeToString([]byte(fmt.Sprintf("%s.%s", stateData, signature)))
|
||||
}
|
||||
|
||||
// Verify handles POST /verify
|
||||
func (h *AuthHandler) Verify(c *gin.Context) {
|
||||
var req domain.VerifyRequest
|
||||
|
||||
235
internal/middleware/csrf.go
Normal file
235
internal/middleware/csrf.go
Normal file
@ -0,0 +1,235 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
)
|
||||
|
||||
// CSRFMiddleware provides CSRF protection
|
||||
type CSRFMiddleware struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCSRFMiddleware creates a new CSRF middleware
|
||||
func NewCSRFMiddleware(config config.ConfigProvider, logger *zap.Logger) *CSRFMiddleware {
|
||||
return &CSRFMiddleware{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CSRFProtection implements CSRF protection for state-changing operations
|
||||
func (cm *CSRFMiddleware) CSRFProtection(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip CSRF protection for safe methods
|
||||
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip CSRF protection for specific endpoints that use other authentication
|
||||
if cm.shouldSkipCSRF(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Get CSRF token from header
|
||||
csrfToken := r.Header.Get("X-CSRF-Token")
|
||||
if csrfToken == "" {
|
||||
cm.logger.Warn("Missing CSRF token",
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("remote_addr", r.RemoteAddr))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"csrf_token_missing","message":"CSRF token required"}`))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CSRF token
|
||||
if !cm.validateCSRFToken(csrfToken, r) {
|
||||
cm.logger.Warn("Invalid CSRF token",
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("remote_addr", r.RemoteAddr))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"csrf_token_invalid","message":"Invalid CSRF token"}`))
|
||||
return
|
||||
}
|
||||
|
||||
cm.logger.Debug("CSRF token validated successfully",
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateCSRFToken generates a new CSRF token for a user session
|
||||
func (cm *CSRFMiddleware) GenerateCSRFToken(userID string) (string, error) {
|
||||
// Generate random bytes for token
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
cm.logger.Error("Failed to generate CSRF token", zap.Error(err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create timestamp
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
// Create token data
|
||||
tokenData := hex.EncodeToString(tokenBytes)
|
||||
|
||||
// Create signing string: userID:timestamp:tokenData
|
||||
timestampStr := strconv.FormatInt(timestamp, 10)
|
||||
signingString := userID + ":" + timestampStr + ":" + tokenData
|
||||
|
||||
// Sign the token with HMAC
|
||||
signature := cm.signData(signingString)
|
||||
|
||||
// Return encoded token: tokenData.timestamp.signature
|
||||
token := tokenData + "." + timestampStr + "." + signature
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// validateCSRFToken validates a CSRF token
|
||||
func (cm *CSRFMiddleware) validateCSRFToken(token string, r *http.Request) bool {
|
||||
// Parse token parts
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
cm.logger.Debug("Invalid CSRF token format")
|
||||
return false
|
||||
}
|
||||
|
||||
tokenData, timestampStr, signature := parts[0], parts[1], parts[2]
|
||||
|
||||
// Get user ID from request context or headers
|
||||
userID := cm.getUserIDFromRequest(r)
|
||||
if userID == "" {
|
||||
cm.logger.Debug("No user ID found for CSRF validation")
|
||||
return false
|
||||
}
|
||||
|
||||
// Recreate signing string
|
||||
signingString := userID + ":" + timestampStr + ":" + tokenData
|
||||
|
||||
// Verify signature
|
||||
expectedSignature := cm.signData(signingString)
|
||||
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
|
||||
cm.logger.Debug("CSRF token signature verification failed")
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse timestamp
|
||||
timestampInt, err := strconv.ParseInt(timestampStr, 10, 64)
|
||||
if err != nil {
|
||||
cm.logger.Debug("Invalid timestamp in CSRF token", zap.Error(err))
|
||||
return false
|
||||
}
|
||||
timestamp := time.Unix(timestampInt, 0)
|
||||
|
||||
// Check if token is expired (valid for 1 hour by default)
|
||||
maxAge := cm.config.GetDuration("CSRF_TOKEN_MAX_AGE")
|
||||
if maxAge <= 0 {
|
||||
maxAge = 1 * time.Hour
|
||||
}
|
||||
|
||||
if time.Since(timestamp) > maxAge {
|
||||
cm.logger.Debug("CSRF token expired",
|
||||
zap.Time("timestamp", timestamp),
|
||||
zap.Duration("age", time.Since(timestamp)),
|
||||
zap.Duration("max_age", maxAge))
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// signData signs data with HMAC
|
||||
func (cm *CSRFMiddleware) signData(data string) string {
|
||||
// Use the same signing key as for authentication
|
||||
signingKey := cm.config.GetString("AUTH_SIGNING_KEY")
|
||||
if signingKey == "" {
|
||||
cm.logger.Error("AUTH_SIGNING_KEY not configured for CSRF protection")
|
||||
return ""
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(signingKey))
|
||||
mac.Write([]byte(data))
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
// getUserIDFromRequest extracts user ID from request
|
||||
func (cm *CSRFMiddleware) getUserIDFromRequest(r *http.Request) string {
|
||||
// Try to get from X-User-Email header
|
||||
userEmail := r.Header.Get(cm.config.GetString("AUTH_HEADER_USER_EMAIL"))
|
||||
if userEmail != "" {
|
||||
return userEmail
|
||||
}
|
||||
|
||||
// Try to get from context (if set by authentication middleware)
|
||||
if userID := r.Context().Value("user_id"); userID != nil {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// shouldSkipCSRF determines if CSRF protection should be skipped for a request
|
||||
func (cm *CSRFMiddleware) shouldSkipCSRF(r *http.Request) bool {
|
||||
// Skip for API endpoints that use API key authentication
|
||||
if strings.HasPrefix(r.URL.Path, "/api/verify") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip for health check endpoints
|
||||
if r.URL.Path == "/health" || r.URL.Path == "/ready" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip for webhook endpoints (if any)
|
||||
if strings.HasPrefix(r.URL.Path, "/webhook/") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetCSRFCookie sets a secure CSRF token cookie
|
||||
func (cm *CSRFMiddleware) SetCSRFCookie(w http.ResponseWriter, token string) {
|
||||
cookie := &http.Cookie{
|
||||
Name: "csrf_token",
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: 3600, // 1 hour
|
||||
HttpOnly: false, // JavaScript needs to read this for AJAX requests
|
||||
Secure: true, // HTTPS only
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
|
||||
// GetCSRFTokenFromCookie gets CSRF token from cookie
|
||||
func (cm *CSRFMiddleware) GetCSRFTokenFromCookie(r *http.Request) string {
|
||||
cookie, err := r.Cookie("csrf_token")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
@ -23,23 +23,25 @@ import (
|
||||
|
||||
// SecurityMiddleware provides various security features
|
||||
type SecurityMiddleware struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
cacheManager *cache.CacheManager
|
||||
appRepo repository.ApplicationRepository
|
||||
rateLimiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
cacheManager *cache.CacheManager
|
||||
appRepo repository.ApplicationRepository
|
||||
rateLimiters map[string]*rate.Limiter
|
||||
authRateLimiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSecurityMiddleware creates a new security middleware
|
||||
func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger, appRepo repository.ApplicationRepository) *SecurityMiddleware {
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
return &SecurityMiddleware{
|
||||
config: config,
|
||||
logger: logger,
|
||||
cacheManager: cacheManager,
|
||||
appRepo: appRepo,
|
||||
rateLimiters: make(map[string]*rate.Limiter),
|
||||
config: config,
|
||||
logger: logger,
|
||||
cacheManager: cacheManager,
|
||||
appRepo: appRepo,
|
||||
rateLimiters: make(map[string]*rate.Limiter),
|
||||
authRateLimiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
|
||||
@ -76,6 +78,38 @@ func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler
|
||||
})
|
||||
}
|
||||
|
||||
// AuthRateLimitMiddleware implements stricter rate limiting for authentication endpoints
|
||||
func (s *SecurityMiddleware) AuthRateLimitMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Use stricter rate limits for auth endpoints
|
||||
limiter := s.getAuthRateLimiter(clientIP)
|
||||
|
||||
// Check if request is allowed
|
||||
if !limiter.Allow() {
|
||||
s.logger.Warn("Auth rate limit exceeded",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
// Track authentication failures for brute force protection
|
||||
s.TrackAuthenticationFailure(clientIP, "")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte(`{"error":"auth_rate_limit_exceeded","message":"Too many authentication attempts"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// BruteForceProtectionMiddleware implements brute force protection
|
||||
func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@ -231,6 +265,35 @@ func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter {
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) getAuthRateLimiter(clientIP string) *rate.Limiter {
|
||||
s.mu.RLock()
|
||||
limiter, exists := s.authRateLimiters[clientIP]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Create new auth rate limiter with stricter limits
|
||||
authRPS := s.config.GetInt("AUTH_RATE_LIMIT_RPS")
|
||||
if authRPS <= 0 {
|
||||
authRPS = 5 // Very strict default for auth endpoints
|
||||
}
|
||||
|
||||
authBurst := s.config.GetInt("AUTH_RATE_LIMIT_BURST")
|
||||
if authBurst <= 0 {
|
||||
authBurst = 10 // Allow small bursts
|
||||
}
|
||||
|
||||
limiter = rate.NewLimiter(rate.Limit(authRPS), authBurst)
|
||||
|
||||
s.mu.Lock()
|
||||
s.authRateLimiters[clientIP] = limiter
|
||||
s.mu.Unlock()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) {
|
||||
ctx := context.Background()
|
||||
key := cache.CacheKey("rate_limit_violations", clientIP)
|
||||
|
||||
@ -168,9 +168,6 @@ type DatabaseProvider interface {
|
||||
|
||||
// BeginTx starts a database transaction
|
||||
BeginTx(ctx context.Context) (TransactionProvider, error)
|
||||
|
||||
// Migrate runs database migrations
|
||||
Migrate(ctx context.Context, migrationPath string) error
|
||||
}
|
||||
|
||||
// TransactionProvider defines the interface for database transaction operations
|
||||
|
||||
@ -201,83 +201,118 @@ func (r *ApplicationRepository) List(ctx context.Context, limit, offset int) ([]
|
||||
|
||||
// Update updates an existing application
|
||||
func (r *ApplicationRepository) Update(ctx context.Context, appID string, updates *domain.UpdateApplicationRequest) (*domain.Application, error) {
|
||||
// Build dynamic update query
|
||||
// Build secure dynamic update query using a whitelist approach
|
||||
var setParts []string
|
||||
var args []interface{}
|
||||
argIndex := 1
|
||||
|
||||
// Whitelist of allowed fields to prevent SQL injection
|
||||
allowedFields := map[string]string{
|
||||
"app_link": "app_link",
|
||||
"type": "type",
|
||||
"callback_url": "callback_url",
|
||||
"hmac_key": "hmac_key",
|
||||
"token_prefix": "token_prefix",
|
||||
"token_renewal_duration": "token_renewal_duration",
|
||||
"max_token_duration": "max_token_duration",
|
||||
"owner_type": "owner_type",
|
||||
"owner_name": "owner_name",
|
||||
"owner_owner": "owner_owner",
|
||||
}
|
||||
|
||||
if updates.AppLink != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("app_link = $%d", argIndex))
|
||||
args = append(args, *updates.AppLink)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["app_link"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.AppLink)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.Type != nil {
|
||||
typeStrings := make([]string, len(*updates.Type))
|
||||
for i, t := range *updates.Type {
|
||||
typeStrings[i] = string(t)
|
||||
if field, ok := allowedFields["type"]; ok {
|
||||
typeStrings := make([]string, len(*updates.Type))
|
||||
for i, t := range *updates.Type {
|
||||
typeStrings[i] = string(t)
|
||||
}
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, pq.Array(typeStrings))
|
||||
argIndex++
|
||||
}
|
||||
setParts = append(setParts, fmt.Sprintf("type = $%d", argIndex))
|
||||
args = append(args, pq.Array(typeStrings))
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.CallbackURL != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("callback_url = $%d", argIndex))
|
||||
args = append(args, *updates.CallbackURL)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["callback_url"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.CallbackURL)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.HMACKey != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("hmac_key = $%d", argIndex))
|
||||
args = append(args, *updates.HMACKey)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["hmac_key"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.HMACKey)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.TokenPrefix != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("token_prefix = $%d", argIndex))
|
||||
args = append(args, *updates.TokenPrefix)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["token_prefix"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.TokenPrefix)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.TokenRenewalDuration != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("token_renewal_duration = $%d", argIndex))
|
||||
args = append(args, updates.TokenRenewalDuration.Duration.Nanoseconds())
|
||||
argIndex++
|
||||
if field, ok := allowedFields["token_renewal_duration"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.TokenRenewalDuration.Duration.Nanoseconds())
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.MaxTokenDuration != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("max_token_duration = $%d", argIndex))
|
||||
args = append(args, updates.MaxTokenDuration.Duration.Nanoseconds())
|
||||
argIndex++
|
||||
if field, ok := allowedFields["max_token_duration"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.MaxTokenDuration.Duration.Nanoseconds())
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.Owner != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("owner_type = $%d", argIndex))
|
||||
args = append(args, string(updates.Owner.Type))
|
||||
argIndex++
|
||||
if field, ok := allowedFields["owner_type"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, string(updates.Owner.Type))
|
||||
argIndex++
|
||||
}
|
||||
|
||||
setParts = append(setParts, fmt.Sprintf("owner_name = $%d", argIndex))
|
||||
args = append(args, updates.Owner.Name)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["owner_name"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.Owner.Name)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
setParts = append(setParts, fmt.Sprintf("owner_owner = $%d", argIndex))
|
||||
args = append(args, updates.Owner.Owner)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["owner_owner"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.Owner.Owner)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if len(setParts) == 0 {
|
||||
return r.GetByID(ctx, appID) // No updates, return current state
|
||||
}
|
||||
|
||||
// Always update the updated_at field
|
||||
// Always update the updated_at field - using literal field name for security
|
||||
setParts = append(setParts, fmt.Sprintf("updated_at = $%d", argIndex))
|
||||
args = append(args, time.Now())
|
||||
argIndex++
|
||||
|
||||
// Add WHERE clause
|
||||
// Add WHERE clause parameter
|
||||
args = append(args, appID)
|
||||
|
||||
// Build the final query with properly parameterized placeholders
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE applications
|
||||
SET %s
|
||||
|
||||
375
internal/validation/validator.go
Normal file
375
internal/validation/validator.go
Normal file
@ -0,0 +1,375 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Validator provides comprehensive input validation
|
||||
type Validator struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewValidator creates a new input validator
|
||||
func NewValidator(logger *zap.Logger) *Validator {
|
||||
return &Validator{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidationError represents a validation error
|
||||
type ValidationError struct {
|
||||
Field string `json:"field"`
|
||||
Message string `json:"message"`
|
||||
Value string `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
func (e ValidationError) Error() string {
|
||||
return fmt.Sprintf("validation error for field '%s': %s", e.Field, e.Message)
|
||||
}
|
||||
|
||||
// ValidationResult holds the result of validation
|
||||
type ValidationResult struct {
|
||||
Valid bool `json:"valid"`
|
||||
Errors []ValidationError `json:"errors"`
|
||||
}
|
||||
|
||||
// AddError adds a validation error
|
||||
func (vr *ValidationResult) AddError(field, message, value string) {
|
||||
vr.Valid = false
|
||||
vr.Errors = append(vr.Errors, ValidationError{
|
||||
Field: field,
|
||||
Message: message,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
// Regular expressions for validation
|
||||
var (
|
||||
emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
appIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$`)
|
||||
tokenPrefixRegex = regexp.MustCompile(`^[A-Z]{2,4}$`)
|
||||
permissionRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9._]*[a-zA-Z0-9]$`)
|
||||
)
|
||||
|
||||
// ValidateEmail validates email addresses
|
||||
func (v *Validator) ValidateEmail(email string) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
if email == "" {
|
||||
result.AddError("email", "Email is required", "")
|
||||
return result
|
||||
}
|
||||
|
||||
if len(email) > 254 {
|
||||
result.AddError("email", "Email too long (max 254 characters)", email)
|
||||
return result
|
||||
}
|
||||
|
||||
if !emailRegex.MatchString(email) {
|
||||
result.AddError("email", "Invalid email format", email)
|
||||
return result
|
||||
}
|
||||
|
||||
// Additional email security checks
|
||||
if strings.Contains(email, "..") {
|
||||
result.AddError("email", "Email contains consecutive dots", email)
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for potentially dangerous characters
|
||||
dangerousChars := []string{"<", ">", "\"", "'", "&", ";", "|", "`"}
|
||||
for _, char := range dangerousChars {
|
||||
if strings.Contains(email, char) {
|
||||
result.AddError("email", "Email contains invalid characters", email)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateAppID validates application IDs
|
||||
func (v *Validator) ValidateAppID(appID string) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
if appID == "" {
|
||||
result.AddError("app_id", "Application ID is required", "")
|
||||
return result
|
||||
}
|
||||
|
||||
if len(appID) < 3 || len(appID) > 100 {
|
||||
result.AddError("app_id", "Application ID must be between 3 and 100 characters", appID)
|
||||
return result
|
||||
}
|
||||
|
||||
if !appIDRegex.MatchString(appID) {
|
||||
result.AddError("app_id", "Application ID must start and end with alphanumeric characters and contain only letters, numbers, dots, hyphens, and underscores", appID)
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for reserved names
|
||||
reservedNames := []string{"admin", "root", "system", "internal", "api", "www", "mail", "ftp"}
|
||||
for _, reserved := range reservedNames {
|
||||
if strings.EqualFold(appID, reserved) {
|
||||
result.AddError("app_id", "Application ID cannot be a reserved name", appID)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateURL validates URLs
|
||||
func (v *Validator) ValidateURL(urlStr, fieldName string) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
if urlStr == "" {
|
||||
result.AddError(fieldName, "URL is required", "")
|
||||
return result
|
||||
}
|
||||
|
||||
if len(urlStr) > 2000 {
|
||||
result.AddError(fieldName, "URL too long (max 2000 characters)", urlStr)
|
||||
return result
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
result.AddError(fieldName, "Invalid URL format", urlStr)
|
||||
return result
|
||||
}
|
||||
|
||||
// Validate scheme
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
result.AddError(fieldName, "URL must use http or https scheme", urlStr)
|
||||
return result
|
||||
}
|
||||
|
||||
// Security: Require HTTPS in production (configurable)
|
||||
if parsedURL.Scheme != "https" {
|
||||
v.logger.Warn("Non-HTTPS URL provided", zap.String("url", urlStr))
|
||||
// In strict mode, this would be an error
|
||||
// result.AddError(fieldName, "HTTPS is required", urlStr)
|
||||
}
|
||||
|
||||
// Validate host
|
||||
if parsedURL.Host == "" {
|
||||
result.AddError(fieldName, "URL must have a valid host", urlStr)
|
||||
return result
|
||||
}
|
||||
|
||||
// Security: Block localhost and private IPs in production
|
||||
if v.isPrivateOrLocalhost(parsedURL.Host) {
|
||||
result.AddError(fieldName, "URLs pointing to private or localhost addresses are not allowed", urlStr)
|
||||
return result
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidatePermissions validates a list of permissions
|
||||
func (v *Validator) ValidatePermissions(permissions []string) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
if len(permissions) == 0 {
|
||||
result.AddError("permissions", "At least one permission is required", "")
|
||||
return result
|
||||
}
|
||||
|
||||
if len(permissions) > 50 {
|
||||
result.AddError("permissions", "Too many permissions (max 50)", fmt.Sprintf("%d", len(permissions)))
|
||||
return result
|
||||
}
|
||||
|
||||
seen := make(map[string]bool)
|
||||
for i, permission := range permissions {
|
||||
field := fmt.Sprintf("permissions[%d]", i)
|
||||
|
||||
// Check for duplicates
|
||||
if seen[permission] {
|
||||
result.AddError(field, "Duplicate permission", permission)
|
||||
continue
|
||||
}
|
||||
seen[permission] = true
|
||||
|
||||
// Validate individual permission
|
||||
if err := v.validateSinglePermission(permission); err != nil {
|
||||
result.AddError(field, err.Error(), permission)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateTokenPrefix validates token prefixes
|
||||
func (v *Validator) ValidateTokenPrefix(prefix string) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
if prefix == "" {
|
||||
// Empty prefix is allowed - will use default
|
||||
return result
|
||||
}
|
||||
|
||||
if len(prefix) < 2 || len(prefix) > 4 {
|
||||
result.AddError("token_prefix", "Token prefix must be between 2 and 4 characters", prefix)
|
||||
return result
|
||||
}
|
||||
|
||||
if !tokenPrefixRegex.MatchString(prefix) {
|
||||
result.AddError("token_prefix", "Token prefix must contain only uppercase letters", prefix)
|
||||
return result
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateString validates a general string with length and content constraints
|
||||
func (v *Validator) ValidateString(value, fieldName string, minLen, maxLen int, allowEmpty bool) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
if value == "" && !allowEmpty {
|
||||
result.AddError(fieldName, fmt.Sprintf("%s is required", fieldName), "")
|
||||
return result
|
||||
}
|
||||
|
||||
if len(value) < minLen {
|
||||
result.AddError(fieldName, fmt.Sprintf("%s must be at least %d characters", fieldName, minLen), value)
|
||||
return result
|
||||
}
|
||||
|
||||
if len(value) > maxLen {
|
||||
result.AddError(fieldName, fmt.Sprintf("%s must be at most %d characters", fieldName, maxLen), value)
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for control characters and other potentially dangerous characters
|
||||
for i, r := range value {
|
||||
if unicode.IsControl(r) && r != '\n' && r != '\r' && r != '\t' {
|
||||
result.AddError(fieldName, fmt.Sprintf("%s contains invalid control character at position %d", fieldName, i), value)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Check for null bytes
|
||||
if strings.Contains(value, "\x00") {
|
||||
result.AddError(fieldName, fmt.Sprintf("%s contains null bytes", fieldName), value)
|
||||
return result
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateDuration validates duration strings
|
||||
func (v *Validator) ValidateDuration(duration, fieldName string) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
if duration == "" {
|
||||
result.AddError(fieldName, "Duration is required", "")
|
||||
return result
|
||||
}
|
||||
|
||||
// Basic duration format validation (Go duration format)
|
||||
durationRegex := regexp.MustCompile(`^(\d+(\.\d+)?(ns|us|µs|ms|s|m|h))+$`)
|
||||
if !durationRegex.MatchString(duration) {
|
||||
result.AddError(fieldName, "Invalid duration format (use Go duration format like '1h', '30m', '5s')", duration)
|
||||
return result
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
|
||||
func (v *Validator) validateSinglePermission(permission string) error {
|
||||
if permission == "" {
|
||||
return fmt.Errorf("permission cannot be empty")
|
||||
}
|
||||
|
||||
if len(permission) > 100 {
|
||||
return fmt.Errorf("permission too long (max 100 characters)")
|
||||
}
|
||||
|
||||
if !permissionRegex.MatchString(permission) {
|
||||
return fmt.Errorf("permission must start and end with alphanumeric characters and contain only letters, numbers, dots, and underscores")
|
||||
}
|
||||
|
||||
// Validate permission hierarchy (dots separate levels)
|
||||
parts := strings.Split(permission, ".")
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
return fmt.Errorf("permission level %d is empty", i+1)
|
||||
}
|
||||
if len(part) > 50 {
|
||||
return fmt.Errorf("permission level %d is too long (max 50 characters)", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) > 5 {
|
||||
return fmt.Errorf("permission hierarchy too deep (max 5 levels)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *Validator) isPrivateOrLocalhost(host string) bool {
|
||||
// Remove port if present
|
||||
if colonIndex := strings.LastIndex(host, ":"); colonIndex != -1 {
|
||||
host = host[:colonIndex]
|
||||
}
|
||||
|
||||
// Check for localhost variants
|
||||
localhosts := []string{"localhost", "127.0.0.1", "::1", "0.0.0.0"}
|
||||
for _, localhost := range localhosts {
|
||||
if strings.EqualFold(host, localhost) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for private IP ranges (simplified)
|
||||
privateRanges := []string{
|
||||
"10.", "192.168.", "172.16.", "172.17.", "172.18.", "172.19.",
|
||||
"172.20.", "172.21.", "172.22.", "172.23.", "172.24.", "172.25.",
|
||||
"172.26.", "172.27.", "172.28.", "172.29.", "172.30.", "172.31.",
|
||||
}
|
||||
|
||||
for _, privateRange := range privateRanges {
|
||||
if strings.HasPrefix(host, privateRange) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateApplicationRequest validates create/update application requests
|
||||
func (v *Validator) ValidateApplicationRequest(appID, appLink, callbackURL string, permissions []string) []ValidationError {
|
||||
var errors []ValidationError
|
||||
|
||||
// Validate app ID
|
||||
if result := v.ValidateAppID(appID); !result.Valid {
|
||||
errors = append(errors, result.Errors...)
|
||||
}
|
||||
|
||||
// Validate app link URL
|
||||
if result := v.ValidateURL(appLink, "app_link"); !result.Valid {
|
||||
errors = append(errors, result.Errors...)
|
||||
}
|
||||
|
||||
// Validate callback URL
|
||||
if result := v.ValidateURL(callbackURL, "callback_url"); !result.Valid {
|
||||
errors = append(errors, result.Errors...)
|
||||
}
|
||||
|
||||
// Validate permissions
|
||||
if result := v.ValidatePermissions(permissions); !result.Valid {
|
||||
errors = append(errors, result.Errors...)
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
Reference in New Issue
Block a user