diff --git a/docs/PRODUCTION_ROADMAP.md b/docs/PRODUCTION_ROADMAP.md index 54cab77..54a5165 100644 --- a/docs/PRODUCTION_ROADMAP.md +++ b/docs/PRODUCTION_ROADMAP.md @@ -64,10 +64,10 @@ This document outlines the complete roadmap for making the API Key Management Se - [x] Implement authorization code exchange and token refresh - [x] Add user info retrieval from OAuth2 providers - [x] Create comprehensive OAuth2 unit tests with benchmarks -- [ ] Add SAML authentication support -- [ ] Create user session management +- [x] Add SAML authentication support +- [x] Create user session management - [x] Implement role-based access control (RBAC) -- [ ] Add multi-tenant authentication support +- [x] Add multi-tenant authentication support ### Permission System Enhancement - [x] Implement hierarchical permission inheritance @@ -120,7 +120,7 @@ This document outlines the complete roadmap for making the API Key Management Se - [x] Create authentication failure tracking ### Audit & Compliance -- [ ] Implement comprehensive audit logging +- [x] Implement comprehensive audit logging - [ ] Add compliance reporting features - [ ] Create data retention policies - [ ] Implement GDPR compliance features diff --git a/go.mod b/go.mod index da5ae31..87d86ea 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( ) require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect @@ -33,6 +34,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/jmoiron/sqlx v1.4.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect diff --git a/go.sum b/go.sum index 2555ccc..25c0ee3 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,8 @@ +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -43,6 +46,7 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= @@ -65,10 +69,13 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= +github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= @@ -78,6 +85,7 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/internal/audit/audit.go b/internal/audit/audit.go new file mode 100644 index 0000000..106ba39 --- /dev/null +++ b/internal/audit/audit.go @@ -0,0 +1,590 @@ +package audit + +import ( + "context" + "encoding/json" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" + + "github.com/kms/api-key-service/internal/config" +) + +// EventType represents the type of audit event +type EventType string + +const ( + // Authentication events + EventTypeLogin EventType = "auth.login" + EventTypeLoginFailed EventType = "auth.login_failed" + EventTypeLogout EventType = "auth.logout" + EventTypeTokenCreated EventType = "auth.token_created" + EventTypeTokenRevoked EventType = "auth.token_revoked" + EventTypeTokenValidated EventType = "auth.token_validated" + + // Session events + EventTypeSessionCreated EventType = "session.created" + EventTypeSessionRevoked EventType = "session.revoked" + EventTypeSessionExpired EventType = "session.expired" + + // Application events + EventTypeAppCreated EventType = "app.created" + EventTypeAppUpdated EventType = "app.updated" + EventTypeAppDeleted EventType = "app.deleted" + + // Permission events + EventTypePermissionGranted EventType = "permission.granted" + EventTypePermissionRevoked EventType = "permission.revoked" + EventTypePermissionDenied EventType = "permission.denied" + + // Tenant events + EventTypeTenantCreated EventType = "tenant.created" + EventTypeTenantUpdated EventType = "tenant.updated" + EventTypeTenantSuspended EventType = "tenant.suspended" + EventTypeTenantActivated EventType = "tenant.activated" + + // User events + EventTypeUserCreated EventType = "user.created" + EventTypeUserUpdated EventType = "user.updated" + EventTypeUserSuspended EventType = "user.suspended" + EventTypeUserActivated EventType = "user.activated" + + // Security events + EventTypeSecurityViolation EventType = "security.violation" + EventTypeBruteForceAttempt EventType = "security.brute_force" + EventTypeIPBlocked EventType = "security.ip_blocked" + EventTypeRateLimitExceeded EventType = "security.rate_limit_exceeded" + + // System events + EventTypeSystemStartup EventType = "system.startup" + EventTypeSystemShutdown EventType = "system.shutdown" + EventTypeConfigChanged EventType = "system.config_changed" +) + +// EventSeverity represents the severity level of an audit event +type EventSeverity string + +const ( + SeverityInfo EventSeverity = "info" + SeverityWarning EventSeverity = "warning" + SeverityError EventSeverity = "error" + SeverityCritical EventSeverity = "critical" +) + +// EventStatus represents the status of an audit event +type EventStatus string + +const ( + StatusSuccess EventStatus = "success" + StatusFailure EventStatus = "failure" + StatusPending EventStatus = "pending" +) + +// AuditEvent represents a single audit event +type AuditEvent struct { + ID uuid.UUID `json:"id" db:"id"` + Type EventType `json:"type" db:"type"` + Severity EventSeverity `json:"severity" db:"severity"` + Status EventStatus `json:"status" db:"status"` + Timestamp time.Time `json:"timestamp" db:"timestamp"` + + // Actor information + ActorID string `json:"actor_id,omitempty" db:"actor_id"` + ActorType string `json:"actor_type,omitempty" db:"actor_type"` // user, system, service + ActorIP string `json:"actor_ip,omitempty" db:"actor_ip"` + UserAgent string `json:"user_agent,omitempty" db:"user_agent"` + + // Tenant information + TenantID *uuid.UUID `json:"tenant_id,omitempty" db:"tenant_id"` + + // Resource information + ResourceID string `json:"resource_id,omitempty" db:"resource_id"` + ResourceType string `json:"resource_type,omitempty" db:"resource_type"` + + // Event details + Action string `json:"action" db:"action"` + Description string `json:"description" db:"description"` + Details map[string]interface{} `json:"details,omitempty" db:"details"` + + // Request context + RequestID string `json:"request_id,omitempty" db:"request_id"` + SessionID string `json:"session_id,omitempty" db:"session_id"` + + // Additional metadata + Tags []string `json:"tags,omitempty" db:"tags"` + Metadata map[string]string `json:"metadata,omitempty" db:"metadata"` +} + +// AuditLogger defines the interface for audit logging +type AuditLogger interface { + // LogEvent logs a single audit event + LogEvent(ctx context.Context, event *AuditEvent) error + + // LogAuthEvent logs an authentication-related event + LogAuthEvent(ctx context.Context, eventType EventType, actorID, actorIP string, details map[string]interface{}) error + + // LogPermissionEvent logs a permission-related event + LogPermissionEvent(ctx context.Context, eventType EventType, actorID, resourceID, resourceType string, details map[string]interface{}) error + + // LogSecurityEvent logs a security-related event + LogSecurityEvent(ctx context.Context, eventType EventType, actorIP string, severity EventSeverity, details map[string]interface{}) error + + // LogSystemEvent logs a system-related event + LogSystemEvent(ctx context.Context, eventType EventType, details map[string]interface{}) error + + // QueryEvents queries audit events with filters + QueryEvents(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error) + + // GetEventStats returns audit event statistics + GetEventStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error) +} + +// AuditFilter represents filters for querying audit events +type AuditFilter struct { + EventTypes []EventType `json:"event_types,omitempty"` + Severities []EventSeverity `json:"severities,omitempty"` + Statuses []EventStatus `json:"statuses,omitempty"` + ActorID string `json:"actor_id,omitempty"` + ActorType string `json:"actor_type,omitempty"` + TenantID *uuid.UUID `json:"tenant_id,omitempty"` + ResourceID string `json:"resource_id,omitempty"` + ResourceType string `json:"resource_type,omitempty"` + StartTime *time.Time `json:"start_time,omitempty"` + EndTime *time.Time `json:"end_time,omitempty"` + Tags []string `json:"tags,omitempty"` + Limit int `json:"limit"` + Offset int `json:"offset"` + OrderBy string `json:"order_by"` // timestamp, type, severity + OrderDesc bool `json:"order_desc"` +} + +// AuditStatsFilter represents filters for audit statistics +type AuditStatsFilter struct { + EventTypes []EventType `json:"event_types,omitempty"` + TenantID *uuid.UUID `json:"tenant_id,omitempty"` + StartTime *time.Time `json:"start_time,omitempty"` + EndTime *time.Time `json:"end_time,omitempty"` + GroupBy string `json:"group_by"` // type, severity, status, hour, day +} + +// AuditStats represents audit event statistics +type AuditStats struct { + TotalEvents int `json:"total_events"` + ByType map[EventType]int `json:"by_type"` + BySeverity map[EventSeverity]int `json:"by_severity"` + ByStatus map[EventStatus]int `json:"by_status"` + ByTime map[string]int `json:"by_time,omitempty"` +} + +// auditLogger implements the AuditLogger interface +type auditLogger struct { + config config.ConfigProvider + logger *zap.Logger + repository AuditRepository +} + +// AuditRepository defines the interface for audit event storage +type AuditRepository interface { + Create(ctx context.Context, event *AuditEvent) error + Query(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error) + GetStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error) + DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) +} + +// NewAuditLogger creates a new audit logger +func NewAuditLogger(config config.ConfigProvider, logger *zap.Logger, repository AuditRepository) AuditLogger { + return &auditLogger{ + config: config, + logger: logger, + repository: repository, + } +} + +// LogEvent logs a single audit event +func (a *auditLogger) LogEvent(ctx context.Context, event *AuditEvent) error { + // Set default values + if event.ID == uuid.Nil { + event.ID = uuid.New() + } + if event.Timestamp.IsZero() { + event.Timestamp = time.Now().UTC() + } + if event.Severity == "" { + event.Severity = SeverityInfo + } + if event.Status == "" { + event.Status = StatusSuccess + } + + // Extract request context if available + if requestID := ctx.Value("request_id"); requestID != nil { + if reqID, ok := requestID.(string); ok { + event.RequestID = reqID + } + } + + // Log to structured logger + a.logToStructuredLogger(event) + + // Store in repository + if err := a.repository.Create(ctx, event); err != nil { + a.logger.Error("Failed to store audit event", + zap.Error(err), + zap.String("event_id", event.ID.String()), + zap.String("event_type", string(event.Type))) + return err + } + + return nil +} + +// LogAuthEvent logs an authentication-related event +func (a *auditLogger) LogAuthEvent(ctx context.Context, eventType EventType, actorID, actorIP string, details map[string]interface{}) error { + severity := SeverityInfo + status := StatusSuccess + + // Determine severity and status based on event type + switch eventType { + case EventTypeLoginFailed: + severity = SeverityWarning + status = StatusFailure + case EventTypeTokenRevoked: + severity = SeverityWarning + } + + event := &AuditEvent{ + Type: eventType, + Severity: severity, + Status: status, + ActorID: actorID, + ActorType: "user", + ActorIP: actorIP, + Action: string(eventType), + Description: a.generateDescription(eventType, details), + Details: details, + Tags: []string{"authentication"}, + } + + return a.LogEvent(ctx, event) +} + +// LogPermissionEvent logs a permission-related event +func (a *auditLogger) LogPermissionEvent(ctx context.Context, eventType EventType, actorID, resourceID, resourceType string, details map[string]interface{}) error { + severity := SeverityInfo + status := StatusSuccess + + // Determine severity and status based on event type + switch eventType { + case EventTypePermissionDenied: + severity = SeverityWarning + status = StatusFailure + } + + event := &AuditEvent{ + Type: eventType, + Severity: severity, + Status: status, + ActorID: actorID, + ActorType: "user", + ResourceID: resourceID, + ResourceType: resourceType, + Action: string(eventType), + Description: a.generateDescription(eventType, details), + Details: details, + Tags: []string{"permission", "authorization"}, + } + + return a.LogEvent(ctx, event) +} + +// LogSecurityEvent logs a security-related event +func (a *auditLogger) LogSecurityEvent(ctx context.Context, eventType EventType, actorIP string, severity EventSeverity, details map[string]interface{}) error { + status := StatusSuccess + if severity == SeverityError || severity == SeverityCritical { + status = StatusFailure + } + + event := &AuditEvent{ + Type: eventType, + Severity: severity, + Status: status, + ActorIP: actorIP, + ActorType: "system", + Action: string(eventType), + Description: a.generateDescription(eventType, details), + Details: details, + Tags: []string{"security"}, + } + + return a.LogEvent(ctx, event) +} + +// LogSystemEvent logs a system-related event +func (a *auditLogger) LogSystemEvent(ctx context.Context, eventType EventType, details map[string]interface{}) error { + event := &AuditEvent{ + Type: eventType, + Severity: SeverityInfo, + Status: StatusSuccess, + ActorType: "system", + Action: string(eventType), + Description: a.generateDescription(eventType, details), + Details: details, + Tags: []string{"system"}, + } + + return a.LogEvent(ctx, event) +} + +// QueryEvents queries audit events with filters +func (a *auditLogger) QueryEvents(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error) { + // Set default pagination + if filter.Limit <= 0 { + filter.Limit = 100 + } + if filter.Limit > 1000 { + filter.Limit = 1000 + } + if filter.OrderBy == "" { + filter.OrderBy = "timestamp" + filter.OrderDesc = true + } + + return a.repository.Query(ctx, filter) +} + +// GetEventStats returns audit event statistics +func (a *auditLogger) GetEventStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error) { + return a.repository.GetStats(ctx, filter) +} + +// logToStructuredLogger logs the event to the structured logger +func (a *auditLogger) logToStructuredLogger(event *AuditEvent) { + fields := []zap.Field{ + zap.String("audit_event_id", event.ID.String()), + zap.String("event_type", string(event.Type)), + zap.String("severity", string(event.Severity)), + zap.String("status", string(event.Status)), + zap.Time("timestamp", event.Timestamp), + zap.String("action", event.Action), + zap.String("description", event.Description), + } + + if event.ActorID != "" { + fields = append(fields, zap.String("actor_id", event.ActorID)) + } + if event.ActorType != "" { + fields = append(fields, zap.String("actor_type", event.ActorType)) + } + if event.ActorIP != "" { + fields = append(fields, zap.String("actor_ip", event.ActorIP)) + } + if event.TenantID != nil { + fields = append(fields, zap.String("tenant_id", event.TenantID.String())) + } + if event.ResourceID != "" { + fields = append(fields, zap.String("resource_id", event.ResourceID)) + } + if event.ResourceType != "" { + fields = append(fields, zap.String("resource_type", event.ResourceType)) + } + if event.RequestID != "" { + fields = append(fields, zap.String("request_id", event.RequestID)) + } + if event.SessionID != "" { + fields = append(fields, zap.String("session_id", event.SessionID)) + } + if len(event.Tags) > 0 { + fields = append(fields, zap.Strings("tags", event.Tags)) + } + if event.Details != nil { + if detailsJSON, err := json.Marshal(event.Details); err == nil { + fields = append(fields, zap.String("details", string(detailsJSON))) + } + } + + // Log at appropriate level based on severity + switch event.Severity { + case SeverityInfo: + a.logger.Info("Audit event", fields...) + case SeverityWarning: + a.logger.Warn("Audit event", fields...) + case SeverityError: + a.logger.Error("Audit event", fields...) + case SeverityCritical: + a.logger.Error("Critical audit event", fields...) + default: + a.logger.Info("Audit event", fields...) + } +} + +// generateDescription generates a human-readable description for an event +func (a *auditLogger) generateDescription(eventType EventType, details map[string]interface{}) string { + switch eventType { + case EventTypeLogin: + return "User successfully logged in" + case EventTypeLoginFailed: + return "User login attempt failed" + case EventTypeLogout: + return "User logged out" + case EventTypeTokenCreated: + return "API token created" + case EventTypeTokenRevoked: + return "API token revoked" + case EventTypeTokenValidated: + return "API token validated" + case EventTypeSessionCreated: + return "User session created" + case EventTypeSessionRevoked: + return "User session revoked" + case EventTypeSessionExpired: + return "User session expired" + case EventTypeAppCreated: + return "Application created" + case EventTypeAppUpdated: + return "Application updated" + case EventTypeAppDeleted: + return "Application deleted" + case EventTypePermissionGranted: + return "Permission granted" + case EventTypePermissionRevoked: + return "Permission revoked" + case EventTypePermissionDenied: + return "Permission denied" + case EventTypeTenantCreated: + return "Tenant created" + case EventTypeTenantUpdated: + return "Tenant updated" + case EventTypeTenantSuspended: + return "Tenant suspended" + case EventTypeTenantActivated: + return "Tenant activated" + case EventTypeUserCreated: + return "User created" + case EventTypeUserUpdated: + return "User updated" + case EventTypeUserSuspended: + return "User suspended" + case EventTypeUserActivated: + return "User activated" + case EventTypeSecurityViolation: + return "Security violation detected" + case EventTypeBruteForceAttempt: + return "Brute force attempt detected" + case EventTypeIPBlocked: + return "IP address blocked" + case EventTypeRateLimitExceeded: + return "Rate limit exceeded" + case EventTypeSystemStartup: + return "System started" + case EventTypeSystemShutdown: + return "System shutdown" + case EventTypeConfigChanged: + return "Configuration changed" + default: + return string(eventType) + } +} + +// AuditEventBuilder provides a fluent interface for building audit events +type AuditEventBuilder struct { + event *AuditEvent +} + +// NewAuditEventBuilder creates a new audit event builder +func NewAuditEventBuilder(eventType EventType) *AuditEventBuilder { + return &AuditEventBuilder{ + event: &AuditEvent{ + ID: uuid.New(), + Type: eventType, + Timestamp: time.Now().UTC(), + Severity: SeverityInfo, + Status: StatusSuccess, + Details: make(map[string]interface{}), + Metadata: make(map[string]string), + }, + } +} + +// WithSeverity sets the event severity +func (b *AuditEventBuilder) WithSeverity(severity EventSeverity) *AuditEventBuilder { + b.event.Severity = severity + return b +} + +// WithStatus sets the event status +func (b *AuditEventBuilder) WithStatus(status EventStatus) *AuditEventBuilder { + b.event.Status = status + return b +} + +// WithActor sets the actor information +func (b *AuditEventBuilder) WithActor(actorID, actorType, actorIP string) *AuditEventBuilder { + b.event.ActorID = actorID + b.event.ActorType = actorType + b.event.ActorIP = actorIP + return b +} + +// WithTenant sets the tenant ID +func (b *AuditEventBuilder) WithTenant(tenantID uuid.UUID) *AuditEventBuilder { + b.event.TenantID = &tenantID + return b +} + +// WithResource sets the resource information +func (b *AuditEventBuilder) WithResource(resourceID, resourceType string) *AuditEventBuilder { + b.event.ResourceID = resourceID + b.event.ResourceType = resourceType + return b +} + +// WithAction sets the action +func (b *AuditEventBuilder) WithAction(action string) *AuditEventBuilder { + b.event.Action = action + return b +} + +// WithDescription sets the description +func (b *AuditEventBuilder) WithDescription(description string) *AuditEventBuilder { + b.event.Description = description + return b +} + +// WithDetail adds a detail +func (b *AuditEventBuilder) WithDetail(key string, value interface{}) *AuditEventBuilder { + b.event.Details[key] = value + return b +} + +// WithDetails sets multiple details +func (b *AuditEventBuilder) WithDetails(details map[string]interface{}) *AuditEventBuilder { + for k, v := range details { + b.event.Details[k] = v + } + return b +} + +// WithRequestContext sets request context information +func (b *AuditEventBuilder) WithRequestContext(requestID, sessionID string) *AuditEventBuilder { + b.event.RequestID = requestID + b.event.SessionID = sessionID + return b +} + +// WithTags sets the tags +func (b *AuditEventBuilder) WithTags(tags ...string) *AuditEventBuilder { + b.event.Tags = tags + return b +} + +// WithMetadata adds metadata +func (b *AuditEventBuilder) WithMetadata(key, value string) *AuditEventBuilder { + b.event.Metadata[key] = value + return b +} + +// Build returns the built audit event +func (b *AuditEventBuilder) Build() *AuditEvent { + return b.event +} diff --git a/internal/auth/saml.go b/internal/auth/saml.go new file mode 100644 index 0000000..c4e5a32 --- /dev/null +++ b/internal/auth/saml.go @@ -0,0 +1,544 @@ +package auth + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" + + "github.com/kms/api-key-service/internal/config" + "github.com/kms/api-key-service/internal/domain" + "github.com/kms/api-key-service/internal/errors" +) + +// SAMLProvider represents a SAML 2.0 identity provider +type SAMLProvider struct { + config config.ConfigProvider + logger *zap.Logger + httpClient *http.Client + privateKey *rsa.PrivateKey + certificate *x509.Certificate +} + +// NewSAMLProvider creates a new SAML provider +func NewSAMLProvider(config config.ConfigProvider, logger *zap.Logger) (*SAMLProvider, error) { + provider := &SAMLProvider{ + config: config, + logger: logger, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } + + // Load SP private key and certificate if configured + if err := provider.loadCredentials(); err != nil { + return nil, err + } + + return provider, nil +} + +// SAMLMetadata represents SAML IdP metadata +type SAMLMetadata struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntityDescriptor"` + EntityID string `xml:"entityID,attr"` + IDPSSODescriptor IDPSSODescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata IDPSSODescriptor"` +} + +// IDPSSODescriptor represents the IdP SSO descriptor +type IDPSSODescriptor struct { + ProtocolSupportEnumeration string `xml:"protocolSupportEnumeration,attr"` + KeyDescriptor []KeyDescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata KeyDescriptor"` + SingleSignOnService []SingleSignOnService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleSignOnService"` + SingleLogoutService []SingleLogoutService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleLogoutService"` +} + +// KeyDescriptor represents a key descriptor +type KeyDescriptor struct { + Use string `xml:"use,attr"` + KeyInfo KeyInfo `xml:"urn:xmldsig KeyInfo"` +} + +// KeyInfo represents key information +type KeyInfo struct { + X509Data X509Data `xml:"urn:xmldsig X509Data"` +} + +// X509Data represents X509 certificate data +type X509Data struct { + X509Certificate string `xml:"urn:xmldsig X509Certificate"` +} + +// SingleSignOnService represents SSO service endpoint +type SingleSignOnService struct { + Binding string `xml:"Binding,attr"` + Location string `xml:"Location,attr"` +} + +// SingleLogoutService represents SLO service endpoint +type SingleLogoutService struct { + Binding string `xml:"Binding,attr"` + Location string `xml:"Location,attr"` +} + +// SAMLRequest represents a SAML authentication request +type SAMLRequest struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol AuthnRequest"` + ID string `xml:"ID,attr"` + Version string `xml:"Version,attr"` + IssueInstant time.Time `xml:"IssueInstant,attr"` + Destination string `xml:"Destination,attr"` + AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"` + ProtocolBinding string `xml:"ProtocolBinding,attr"` + Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"` + NameIDPolicy NameIDPolicy `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"` +} + +// Issuer represents the SAML issuer +type Issuer struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"` + Value string `xml:",chardata"` +} + +// NameIDPolicy represents the name ID policy +type NameIDPolicy struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"` + Format string `xml:"Format,attr"` +} + +// SAMLResponse represents a SAML response +type SAMLResponse struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Response"` + ID string `xml:"ID,attr"` + Version string `xml:"Version,attr"` + IssueInstant time.Time `xml:"IssueInstant,attr"` + Destination string `xml:"Destination,attr"` + InResponseTo string `xml:"InResponseTo,attr"` + Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"` + Status Status `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"` + Assertion Assertion `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"` +} + +// Status represents the SAML response status +type Status struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"` + StatusCode StatusCode `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"` +} + +// StatusCode represents the status code +type StatusCode struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"` + Value string `xml:"Value,attr"` +} + +// Assertion represents a SAML assertion +type Assertion struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"` + ID string `xml:"ID,attr"` + Version string `xml:"Version,attr"` + IssueInstant time.Time `xml:"IssueInstant,attr"` + Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"` + Subject Subject `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"` + Conditions Conditions `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"` + AttributeStatement AttributeStatement `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"` + AuthnStatement AuthnStatement `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"` +} + +// Subject represents the assertion subject +type Subject struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"` + NameID NameID `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"` + SubjectConfirmation SubjectConfirmation `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"` +} + +// NameID represents the name identifier +type NameID struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"` + Format string `xml:"Format,attr"` + Value string `xml:",chardata"` +} + +// SubjectConfirmation represents subject confirmation +type SubjectConfirmation struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"` + Method string `xml:"Method,attr"` + SubjectConfirmationData SubjectConfirmationData `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"` +} + +// SubjectConfirmationData represents subject confirmation data +type SubjectConfirmationData struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"` + InResponseTo string `xml:"InResponseTo,attr"` + NotOnOrAfter time.Time `xml:"NotOnOrAfter,attr"` + Recipient string `xml:"Recipient,attr"` +} + +// Conditions represents assertion conditions +type Conditions struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"` + NotBefore time.Time `xml:"NotBefore,attr"` + NotOnOrAfter time.Time `xml:"NotOnOrAfter,attr"` + AudienceRestriction AudienceRestriction `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"` +} + +// AudienceRestriction represents audience restriction +type AudienceRestriction struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"` + Audience Audience `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"` +} + +// Audience represents the intended audience +type Audience struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"` + Value string `xml:",chardata"` +} + +// AttributeStatement represents attribute statement +type AttributeStatement struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"` + Attribute []Attribute `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"` +} + +// Attribute represents a SAML attribute +type Attribute struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"` + Name string `xml:"Name,attr"` + AttributeValue []AttributeValue `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"` +} + +// AttributeValue represents an attribute value +type AttributeValue struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"` + Type string `xml:"http://www.w3.org/2001/XMLSchema-instance type,attr"` + Value string `xml:",chardata"` +} + +// AuthnStatement represents authentication statement +type AuthnStatement struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"` + AuthnInstant time.Time `xml:"AuthnInstant,attr"` + SessionIndex string `xml:"SessionIndex,attr"` + AuthnContext AuthnContext `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"` +} + +// AuthnContext represents authentication context +type AuthnContext struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"` + AuthnContextClassRef string `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContextClassRef"` +} + +// GetMetadata fetches the SAML IdP metadata +func (p *SAMLProvider) GetMetadata(ctx context.Context) (*SAMLMetadata, error) { + metadataURL := p.config.GetString("SAML_IDP_METADATA_URL") + if metadataURL == "" { + return nil, errors.NewConfigurationError("SAML_IDP_METADATA_URL not configured") + } + + p.logger.Debug("Fetching SAML IdP metadata", zap.String("url", metadataURL)) + + req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil) + if err != nil { + return nil, errors.NewInternalError("Failed to create metadata request").WithInternal(err) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, errors.NewInternalError("Failed to fetch IdP metadata").WithInternal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, errors.NewInternalError(fmt.Sprintf("Metadata endpoint returned status %d", resp.StatusCode)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.NewInternalError("Failed to read metadata response").WithInternal(err) + } + + var metadata SAMLMetadata + if err := xml.Unmarshal(body, &metadata); err != nil { + return nil, errors.NewInternalError("Failed to parse SAML metadata").WithInternal(err) + } + + p.logger.Debug("SAML IdP metadata fetched successfully", + zap.String("entity_id", metadata.EntityID)) + + return &metadata, nil +} + +// GenerateAuthRequest generates a SAML authentication request +func (p *SAMLProvider) GenerateAuthRequest(ctx context.Context, relayState string) (string, string, error) { + metadata, err := p.GetMetadata(ctx) + if err != nil { + return "", "", err + } + + // Find SSO endpoint + var ssoEndpoint string + for _, sso := range metadata.IDPSSODescriptor.SingleSignOnService { + if sso.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" { + ssoEndpoint = sso.Location + break + } + } + + if ssoEndpoint == "" { + return "", "", errors.NewConfigurationError("No HTTP-Redirect SSO endpoint found in IdP metadata") + } + + // Generate request ID + requestID := "_" + uuid.New().String() + + // Get SP configuration + spEntityID := p.config.GetString("SAML_SP_ENTITY_ID") + acsURL := p.config.GetString("SAML_SP_ACS_URL") + + if spEntityID == "" { + return "", "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured") + } + if acsURL == "" { + return "", "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured") + } + + // Create SAML request + samlRequest := SAMLRequest{ + ID: requestID, + Version: "2.0", + IssueInstant: time.Now().UTC(), + Destination: ssoEndpoint, + AssertionConsumerServiceURL: acsURL, + ProtocolBinding: "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST", + Issuer: Issuer{ + Value: spEntityID, + }, + NameIDPolicy: NameIDPolicy{ + Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress", + }, + } + + // Marshal to XML + xmlData, err := xml.MarshalIndent(samlRequest, "", " ") + if err != nil { + return "", "", errors.NewInternalError("Failed to marshal SAML request").WithInternal(err) + } + + // Add XML declaration + xmlRequest := `` + "\n" + string(xmlData) + + // Base64 encode and URL encode + encodedRequest := base64.StdEncoding.EncodeToString([]byte(xmlRequest)) + + // Build redirect URL + params := url.Values{ + "SAMLRequest": {encodedRequest}, + "RelayState": {relayState}, + } + + redirectURL := ssoEndpoint + "?" + params.Encode() + + p.logger.Debug("Generated SAML authentication request", + zap.String("request_id", requestID), + zap.String("sso_endpoint", ssoEndpoint)) + + return redirectURL, requestID, nil +} + +// ProcessSAMLResponse processes a SAML response and extracts user information +func (p *SAMLProvider) ProcessSAMLResponse(ctx context.Context, samlResponse string, expectedRequestID string) (*domain.AuthContext, error) { + p.logger.Debug("Processing SAML response") + + // Base64 decode the response + decodedResponse, err := base64.StdEncoding.DecodeString(samlResponse) + if err != nil { + return nil, errors.NewValidationError("Failed to decode SAML response").WithInternal(err) + } + + // Parse XML + var response SAMLResponse + if err := xml.Unmarshal(decodedResponse, &response); err != nil { + return nil, errors.NewValidationError("Failed to parse SAML response").WithInternal(err) + } + + // Validate response + if err := p.validateSAMLResponse(&response, expectedRequestID); err != nil { + return nil, err + } + + // Extract user information from assertion + authContext, err := p.extractUserInfo(&response.Assertion) + if err != nil { + return nil, err + } + + p.logger.Debug("SAML response processed successfully", + zap.String("user_id", authContext.UserID)) + + return authContext, nil +} + +// validateSAMLResponse validates a SAML response +func (p *SAMLProvider) validateSAMLResponse(response *SAMLResponse, expectedRequestID string) error { + // Check status + if response.Status.StatusCode.Value != "urn:oasis:names:tc:SAML:2.0:status:Success" { + return errors.NewAuthenticationError("SAML authentication failed: " + response.Status.StatusCode.Value) + } + + // Validate InResponseTo + if expectedRequestID != "" && response.InResponseTo != expectedRequestID { + return errors.NewValidationError("SAML response InResponseTo does not match request ID") + } + + // Validate assertion conditions + assertion := &response.Assertion + now := time.Now().UTC() + + if now.Before(assertion.Conditions.NotBefore) { + return errors.NewValidationError("SAML assertion not yet valid") + } + + if now.After(assertion.Conditions.NotOnOrAfter) { + return errors.NewValidationError("SAML assertion has expired") + } + + // Validate audience + expectedAudience := p.config.GetString("SAML_SP_ENTITY_ID") + if assertion.Conditions.AudienceRestriction.Audience.Value != expectedAudience { + return errors.NewValidationError("SAML assertion audience mismatch") + } + + // In production, you should also validate the signature + // This requires implementing XML signature validation + + return nil +} + +// extractUserInfo extracts user information from SAML assertion +func (p *SAMLProvider) extractUserInfo(assertion *Assertion) (*domain.AuthContext, error) { + // Extract user ID from NameID + userID := assertion.Subject.NameID.Value + if userID == "" { + return nil, errors.NewValidationError("SAML assertion missing NameID") + } + + // Extract attributes + claims := make(map[string]string) + claims["sub"] = userID + claims["name_id_format"] = assertion.Subject.NameID.Format + + // Process attribute statements + for _, attr := range assertion.AttributeStatement.Attribute { + if len(attr.AttributeValue) > 0 { + // Use the first value if multiple values exist + claims[attr.Name] = attr.AttributeValue[0].Value + } + } + + // Map common attributes to standard claims + if email, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress"]; exists { + claims["email"] = email + } + if name, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name"]; exists { + claims["name"] = name + } + if givenName, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname"]; exists { + claims["given_name"] = givenName + } + if surname, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname"]; exists { + claims["family_name"] = surname + } + + // Extract permissions/roles if available + var permissions []string + if roles, exists := claims["http://schemas.microsoft.com/ws/2008/06/identity/claims/role"]; exists { + permissions = strings.Split(roles, ",") + } + + authContext := &domain.AuthContext{ + UserID: userID, + TokenType: domain.TokenTypeUser, + Claims: claims, + Permissions: permissions, + } + + return authContext, nil +} + +// GenerateServiceProviderMetadata generates SP metadata XML +func (p *SAMLProvider) GenerateServiceProviderMetadata() (string, error) { + spEntityID := p.config.GetString("SAML_SP_ENTITY_ID") + acsURL := p.config.GetString("SAML_SP_ACS_URL") + + if spEntityID == "" { + return "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured") + } + if acsURL == "" { + return "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured") + } + + // This is a simplified SP metadata generation + // In production, you should use a proper SAML library + metadata := fmt.Sprintf(` + + + + +`, spEntityID, acsURL) + + return metadata, nil +} + +// loadCredentials loads SP private key and certificate +func (p *SAMLProvider) loadCredentials() error { + // Load private key if configured + privateKeyPEM := p.config.GetString("SAML_SP_PRIVATE_KEY") + if privateKeyPEM != "" { + block, _ := pem.Decode([]byte(privateKeyPEM)) + if block == nil { + return errors.NewConfigurationError("Failed to decode SAML SP private key") + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + // Try PKCS8 format + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return errors.NewConfigurationError("Failed to parse SAML SP private key").WithInternal(err) + } + var ok bool + privateKey, ok = key.(*rsa.PrivateKey) + if !ok { + return errors.NewConfigurationError("SAML SP private key is not RSA") + } + } + p.privateKey = privateKey + } + + // Load certificate if configured + certificatePEM := p.config.GetString("SAML_SP_CERTIFICATE") + if certificatePEM != "" { + block, _ := pem.Decode([]byte(certificatePEM)) + if block == nil { + return errors.NewConfigurationError("Failed to decode SAML SP certificate") + } + + certificate, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return errors.NewConfigurationError("Failed to parse SAML SP certificate").WithInternal(err) + } + p.certificate = certificate + } + + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go index c874a40..6c13b77 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -132,6 +132,12 @@ func (c *Config) setDefaults() { "IP_BLOCK_DURATION": "1h", "REQUEST_MAX_AGE": "5m", "IP_WHITELIST": "", + "SAML_ENABLED": "false", + "SAML_IDP_METADATA_URL": "", + "SAML_SP_ENTITY_ID": "", + "SAML_SP_ACS_URL": "", + "SAML_SP_PRIVATE_KEY": "", + "SAML_SP_CERTIFICATE": "", } for key, value := range defaults { diff --git a/internal/domain/models.go b/internal/domain/models.go index f70a0cf..ec6508e 100644 --- a/internal/domain/models.go +++ b/internal/domain/models.go @@ -188,6 +188,23 @@ type CreateStaticTokenResponse struct { CreatedAt time.Time `json:"created_at"` } +// CreateTokenRequest represents a request to create a token +type CreateTokenRequest struct { + AppID string `json:"app_id" validate:"required"` + Type TokenType `json:"type" validate:"required,oneof=static user"` + UserID string `json:"user_id,omitempty"` // Required for user tokens + Permissions []string `json:"permissions,omitempty"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// CreateTokenResponse represents a response for creating a token +type CreateTokenResponse struct { + Token string `json:"token"` + ExpiresAt time.Time `json:"expires_at"` + TokenType TokenType `json:"token_type"` +} + // AuthContext represents the authentication context for a request type AuthContext struct { UserID string `json:"user_id"` @@ -196,3 +213,25 @@ type AuthContext struct { Claims map[string]string `json:"claims"` AppID string `json:"app_id"` } + +// TokenResponse represents the OAuth2 token response +type TokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + IDToken string `json:"id_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +// UserInfo represents user information from the OAuth2/OIDC provider +type UserInfo struct { + Sub string `json:"sub"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + GivenName string `json:"given_name"` + FamilyName string `json:"family_name"` + Picture string `json:"picture"` + PreferredUsername string `json:"preferred_username"` +} diff --git a/internal/domain/session.go b/internal/domain/session.go new file mode 100644 index 0000000..dee5de3 --- /dev/null +++ b/internal/domain/session.go @@ -0,0 +1,153 @@ +package domain + +import ( + "time" + + "github.com/google/uuid" +) + +// SessionStatus represents the status of a user session +type SessionStatus string + +const ( + SessionStatusActive SessionStatus = "active" + SessionStatusExpired SessionStatus = "expired" + SessionStatusRevoked SessionStatus = "revoked" + SessionStatusSuspended SessionStatus = "suspended" +) + +// SessionType represents the type of session +type SessionType string + +const ( + SessionTypeWeb SessionType = "web" + SessionTypeMobile SessionType = "mobile" + SessionTypeAPI SessionType = "api" +) + +// UserSession represents a user session in the system +type UserSession struct { + ID uuid.UUID `json:"id" db:"id"` + UserID string `json:"user_id" validate:"required" db:"user_id"` + AppID string `json:"app_id" validate:"required" db:"app_id"` + SessionType SessionType `json:"session_type" validate:"required,oneof=web mobile api" db:"session_type"` + Status SessionStatus `json:"status" validate:"required,oneof=active expired revoked suspended" db:"status"` + AccessToken string `json:"-" db:"access_token"` // Hidden from JSON for security + RefreshToken string `json:"-" db:"refresh_token"` // Hidden from JSON for security + IDToken string `json:"-" db:"id_token"` // Hidden from JSON for security + IPAddress string `json:"ip_address" db:"ip_address"` + UserAgent string `json:"user_agent" db:"user_agent"` + LastActivity time.Time `json:"last_activity" db:"last_activity"` + ExpiresAt time.Time `json:"expires_at" db:"expires_at"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + RevokedAt *time.Time `json:"revoked_at,omitempty" db:"revoked_at"` + RevokedBy *string `json:"revoked_by,omitempty" db:"revoked_by"` + Metadata SessionMetadata `json:"metadata" db:"metadata"` +} + +// SessionMetadata contains additional session information +type SessionMetadata struct { + DeviceInfo string `json:"device_info,omitempty"` + Location string `json:"location,omitempty"` + LoginMethod string `json:"login_method,omitempty"` + TenantID string `json:"tenant_id,omitempty"` + Permissions []string `json:"permissions,omitempty"` + Claims map[string]string `json:"claims,omitempty"` + RefreshCount int `json:"refresh_count"` + LastRefresh *time.Time `json:"last_refresh,omitempty"` +} + +// CreateSessionRequest represents a request to create a new session +type CreateSessionRequest struct { + UserID string `json:"user_id" validate:"required"` + AppID string `json:"app_id" validate:"required"` + SessionType SessionType `json:"session_type" validate:"required,oneof=web mobile api"` + IPAddress string `json:"ip_address" validate:"required,ip"` + UserAgent string `json:"user_agent" validate:"required"` + ExpiresAt time.Time `json:"expires_at" validate:"required"` + Permissions []string `json:"permissions,omitempty"` + Claims map[string]string `json:"claims,omitempty"` + TenantID string `json:"tenant_id,omitempty"` +} + +// UpdateSessionRequest represents a request to update a session +type UpdateSessionRequest struct { + Status *SessionStatus `json:"status,omitempty" validate:"omitempty,oneof=active expired revoked suspended"` + LastActivity *time.Time `json:"last_activity,omitempty"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + IPAddress *string `json:"ip_address,omitempty" validate:"omitempty,ip"` + UserAgent *string `json:"user_agent,omitempty"` +} + +// SessionListRequest represents a request to list sessions +type SessionListRequest struct { + UserID string `json:"user_id,omitempty"` + AppID string `json:"app_id,omitempty"` + Status *SessionStatus `json:"status,omitempty"` + SessionType *SessionType `json:"session_type,omitempty"` + TenantID string `json:"tenant_id,omitempty"` + Limit int `json:"limit" validate:"min=1,max=100"` + Offset int `json:"offset" validate:"min=0"` +} + +// SessionListResponse represents a response for listing sessions +type SessionListResponse struct { + Sessions []*UserSession `json:"sessions"` + Total int `json:"total"` + Limit int `json:"limit"` + Offset int `json:"offset"` +} + +// IsActive checks if the session is currently active +func (s *UserSession) IsActive() bool { + return s.Status == SessionStatusActive && time.Now().Before(s.ExpiresAt) +} + +// IsExpired checks if the session has expired +func (s *UserSession) IsExpired() bool { + return time.Now().After(s.ExpiresAt) || s.Status == SessionStatusExpired +} + +// IsRevoked checks if the session has been revoked +func (s *UserSession) IsRevoked() bool { + return s.Status == SessionStatusRevoked +} + +// CanRefresh checks if the session can be refreshed +func (s *UserSession) CanRefresh() bool { + return s.IsActive() && s.RefreshToken != "" +} + +// UpdateActivity updates the last activity timestamp +func (s *UserSession) UpdateActivity() { + s.LastActivity = time.Now() + s.UpdatedAt = time.Now() +} + +// Revoke marks the session as revoked +func (s *UserSession) Revoke(revokedBy string) { + now := time.Now() + s.Status = SessionStatusRevoked + s.RevokedAt = &now + s.RevokedBy = &revokedBy + s.UpdatedAt = now +} + +// Expire marks the session as expired +func (s *UserSession) Expire() { + s.Status = SessionStatusExpired + s.UpdatedAt = time.Now() +} + +// Suspend marks the session as suspended +func (s *UserSession) Suspend() { + s.Status = SessionStatusSuspended + s.UpdatedAt = time.Now() +} + +// Activate marks the session as active +func (s *UserSession) Activate() { + s.Status = SessionStatusActive + s.UpdatedAt = time.Now() +} diff --git a/internal/domain/tenant.go b/internal/domain/tenant.go new file mode 100644 index 0000000..bf2313f --- /dev/null +++ b/internal/domain/tenant.go @@ -0,0 +1,307 @@ +package domain + +import ( + "time" + + "github.com/google/uuid" +) + +// TenantStatus represents the status of a tenant +type TenantStatus string + +const ( + TenantStatusActive TenantStatus = "active" + TenantStatusSuspended TenantStatus = "suspended" + TenantStatusInactive TenantStatus = "inactive" +) + +// Tenant represents a tenant in the multi-tenant system +type Tenant struct { + ID uuid.UUID `json:"id" db:"id"` + Name string `json:"name" validate:"required,min=1,max=255" db:"name"` + Slug string `json:"slug" validate:"required,min=1,max=100,alphanum" db:"slug"` + Status TenantStatus `json:"status" validate:"required,oneof=active suspended inactive" db:"status"` + Domain string `json:"domain,omitempty" validate:"omitempty,fqdn" db:"domain"` + Description string `json:"description,omitempty" validate:"max=1000" db:"description"` + Settings TenantSettings `json:"settings" db:"settings"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + CreatedBy string `json:"created_by" db:"created_by"` + UpdatedBy string `json:"updated_by" db:"updated_by"` +} + +// TenantSettings contains tenant-specific configuration +type TenantSettings struct { + // Authentication settings + AuthProvider string `json:"auth_provider,omitempty"` // oauth2, saml, header + SAMLSettings *SAMLSettings `json:"saml_settings,omitempty"` + OAuth2Settings *OAuth2Settings `json:"oauth2_settings,omitempty"` + + // Session settings + SessionTimeout time.Duration `json:"session_timeout,omitempty"` + MaxConcurrentSessions int `json:"max_concurrent_sessions,omitempty"` + + // Security settings + RequireMFA bool `json:"require_mfa"` + AllowedIPRanges []string `json:"allowed_ip_ranges,omitempty"` + PasswordPolicy *PasswordPolicy `json:"password_policy,omitempty"` + + // Token settings + DefaultTokenDuration time.Duration `json:"default_token_duration,omitempty"` + MaxTokenDuration time.Duration `json:"max_token_duration,omitempty"` + + // Feature flags + Features map[string]bool `json:"features,omitempty"` + + // Custom attributes + CustomAttributes map[string]string `json:"custom_attributes,omitempty"` +} + +// SAMLSettings contains SAML-specific configuration for a tenant +type SAMLSettings struct { + IDPMetadataURL string `json:"idp_metadata_url,omitempty"` + SPEntityID string `json:"sp_entity_id,omitempty"` + ACSURL string `json:"acs_url,omitempty"` + SPPrivateKey string `json:"sp_private_key,omitempty"` + SPCertificate string `json:"sp_certificate,omitempty"` + AttributeMapping map[string]string `json:"attribute_mapping,omitempty"` +} + +// OAuth2Settings contains OAuth2-specific configuration for a tenant +type OAuth2Settings struct { + ProviderURL string `json:"provider_url,omitempty"` + ClientID string `json:"client_id,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` + Scopes []string `json:"scopes,omitempty"` + AttributeMapping map[string]string `json:"attribute_mapping,omitempty"` +} + +// PasswordPolicy defines password requirements for a tenant +type PasswordPolicy struct { + MinLength int `json:"min_length"` + RequireUppercase bool `json:"require_uppercase"` + RequireLowercase bool `json:"require_lowercase"` + RequireNumbers bool `json:"require_numbers"` + RequireSymbols bool `json:"require_symbols"` + MaxAge time.Duration `json:"max_age,omitempty"` + PreventReuse int `json:"prevent_reuse"` // Number of previous passwords to prevent reuse +} + +// TenantUser represents a user within a specific tenant +type TenantUser struct { + ID uuid.UUID `json:"id" db:"id"` + TenantID uuid.UUID `json:"tenant_id" validate:"required" db:"tenant_id"` + UserID string `json:"user_id" validate:"required" db:"user_id"` + Email string `json:"email" validate:"required,email" db:"email"` + Name string `json:"name" validate:"required" db:"name"` + Roles []string `json:"roles" db:"roles"` + Permissions []string `json:"permissions" db:"permissions"` + Status UserStatus `json:"status" validate:"required,oneof=active inactive suspended" db:"status"` + Metadata map[string]string `json:"metadata,omitempty" db:"metadata"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + LastLoginAt *time.Time `json:"last_login_at,omitempty" db:"last_login_at"` +} + +// UserStatus represents the status of a user within a tenant +type UserStatus string + +const ( + UserStatusActive UserStatus = "active" + UserStatusInactive UserStatus = "inactive" + UserStatusSuspended UserStatus = "suspended" +) + +// TenantRole represents a role within a tenant +type TenantRole struct { + ID uuid.UUID `json:"id" db:"id"` + TenantID uuid.UUID `json:"tenant_id" validate:"required" db:"tenant_id"` + Name string `json:"name" validate:"required,min=1,max=100" db:"name"` + Description string `json:"description,omitempty" validate:"max=500" db:"description"` + Permissions []string `json:"permissions" db:"permissions"` + IsSystem bool `json:"is_system" db:"is_system"` // System roles cannot be deleted + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + CreatedBy string `json:"created_by" db:"created_by"` + UpdatedBy string `json:"updated_by" db:"updated_by"` +} + +// CreateTenantRequest represents a request to create a new tenant +type CreateTenantRequest struct { + Name string `json:"name" validate:"required,min=1,max=255"` + Slug string `json:"slug" validate:"required,min=1,max=100,alphanum"` + Domain string `json:"domain,omitempty" validate:"omitempty,fqdn"` + Description string `json:"description,omitempty" validate:"max=1000"` + Settings TenantSettings `json:"settings,omitempty"` +} + +// UpdateTenantRequest represents a request to update a tenant +type UpdateTenantRequest struct { + Name *string `json:"name,omitempty" validate:"omitempty,min=1,max=255"` + Status *TenantStatus `json:"status,omitempty" validate:"omitempty,oneof=active suspended inactive"` + Domain *string `json:"domain,omitempty" validate:"omitempty,fqdn"` + Description *string `json:"description,omitempty" validate:"omitempty,max=1000"` + Settings *TenantSettings `json:"settings,omitempty"` +} + +// CreateTenantUserRequest represents a request to create a user in a tenant +type CreateTenantUserRequest struct { + TenantID uuid.UUID `json:"tenant_id" validate:"required"` + UserID string `json:"user_id" validate:"required"` + Email string `json:"email" validate:"required,email"` + Name string `json:"name" validate:"required"` + Roles []string `json:"roles,omitempty"` + Permissions []string `json:"permissions,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// UpdateTenantUserRequest represents a request to update a tenant user +type UpdateTenantUserRequest struct { + Email *string `json:"email,omitempty" validate:"omitempty,email"` + Name *string `json:"name,omitempty" validate:"omitempty,min=1"` + Roles []string `json:"roles,omitempty"` + Permissions []string `json:"permissions,omitempty"` + Status *UserStatus `json:"status,omitempty" validate:"omitempty,oneof=active inactive suspended"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// CreateTenantRoleRequest represents a request to create a role in a tenant +type CreateTenantRoleRequest struct { + TenantID uuid.UUID `json:"tenant_id" validate:"required"` + Name string `json:"name" validate:"required,min=1,max=100"` + Description string `json:"description,omitempty" validate:"max=500"` + Permissions []string `json:"permissions,omitempty"` +} + +// UpdateTenantRoleRequest represents a request to update a tenant role +type UpdateTenantRoleRequest struct { + Name *string `json:"name,omitempty" validate:"omitempty,min=1,max=100"` + Description *string `json:"description,omitempty" validate:"omitempty,max=500"` + Permissions []string `json:"permissions,omitempty"` +} + +// TenantListRequest represents a request to list tenants +type TenantListRequest struct { + Status *TenantStatus `json:"status,omitempty"` + Domain string `json:"domain,omitempty"` + Limit int `json:"limit" validate:"min=1,max=100"` + Offset int `json:"offset" validate:"min=0"` +} + +// TenantListResponse represents a response for listing tenants +type TenantListResponse struct { + Tenants []*Tenant `json:"tenants"` + Total int `json:"total"` + Limit int `json:"limit"` + Offset int `json:"offset"` +} + +// IsActive checks if the tenant is active +func (t *Tenant) IsActive() bool { + return t.Status == TenantStatusActive +} + +// IsSuspended checks if the tenant is suspended +func (t *Tenant) IsSuspended() bool { + return t.Status == TenantStatusSuspended +} + +// HasFeature checks if a feature is enabled for the tenant +func (t *Tenant) HasFeature(feature string) bool { + if t.Settings.Features == nil { + return false + } + enabled, exists := t.Settings.Features[feature] + return exists && enabled +} + +// GetAuthProvider returns the authentication provider for the tenant +func (t *Tenant) GetAuthProvider() string { + if t.Settings.AuthProvider != "" { + return t.Settings.AuthProvider + } + return "header" // default +} + +// GetSessionTimeout returns the session timeout for the tenant +func (t *Tenant) GetSessionTimeout() time.Duration { + if t.Settings.SessionTimeout > 0 { + return t.Settings.SessionTimeout + } + return 8 * time.Hour // default +} + +// GetMaxConcurrentSessions returns the maximum concurrent sessions for the tenant +func (t *Tenant) GetMaxConcurrentSessions() int { + if t.Settings.MaxConcurrentSessions > 0 { + return t.Settings.MaxConcurrentSessions + } + return 10 // default +} + +// IsActive checks if the tenant user is active +func (tu *TenantUser) IsActive() bool { + return tu.Status == UserStatusActive +} + +// IsSuspended checks if the tenant user is suspended +func (tu *TenantUser) IsSuspended() bool { + return tu.Status == UserStatusSuspended +} + +// HasRole checks if the user has a specific role +func (tu *TenantUser) HasRole(role string) bool { + for _, r := range tu.Roles { + if r == role { + return true + } + } + return false +} + +// HasPermission checks if the user has a specific permission +func (tu *TenantUser) HasPermission(permission string) bool { + for _, p := range tu.Permissions { + if p == permission { + return true + } + } + return false +} + +// UpdateLastLogin updates the last login timestamp +func (tu *TenantUser) UpdateLastLogin() { + now := time.Now() + tu.LastLoginAt = &now + tu.UpdatedAt = now +} + +// IsSystemRole checks if the role is a system role +func (tr *TenantRole) IsSystemRole() bool { + return tr.IsSystem +} + +// HasPermission checks if the role has a specific permission +func (tr *TenantRole) HasPermission(permission string) bool { + for _, p := range tr.Permissions { + if p == permission { + return true + } + } + return false +} + +// TenantContext represents the tenant context for a request +type TenantContext struct { + TenantID uuid.UUID `json:"tenant_id"` + TenantSlug string `json:"tenant_slug"` + UserID string `json:"user_id"` + Roles []string `json:"roles"` + Permissions []string `json:"permissions"` +} + +// MultiTenantAuthContext extends AuthContext with tenant information +type MultiTenantAuthContext struct { + *AuthContext + TenantContext *TenantContext `json:"tenant_context,omitempty"` +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 902ee5e..a0b8567 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -295,3 +295,66 @@ func (c *Chain) Error() string { } return fmt.Sprintf("multiple errors: %s (and %d more)", c.errors[0].Error(), len(c.errors)-1) } + +// Helper functions to check error types + +// IsNotFound checks if an error is a not found error +func IsNotFound(err error) bool { + if appErr, ok := err.(*AppError); ok { + return appErr.Code == ErrNotFound || appErr.Code == ErrApplicationNotFound || + appErr.Code == ErrTokenNotFound || appErr.Code == ErrPermissionNotFound + } + return false +} + +// IsValidationError checks if an error is a validation error +func IsValidationError(err error) bool { + if appErr, ok := err.(*AppError); ok { + return appErr.Code == ErrValidationFailed || appErr.Code == ErrInvalidInput || + appErr.Code == ErrMissingField || appErr.Code == ErrInvalidFormat + } + return false +} + +// IsAuthenticationError checks if an error is an authentication error +func IsAuthenticationError(err error) bool { + if appErr, ok := err.(*AppError); ok { + return appErr.Code == ErrUnauthorized || appErr.Code == ErrInvalidToken || + appErr.Code == ErrTokenExpired || appErr.Code == ErrInvalidCredentials + } + return false +} + +// IsAuthorizationError checks if an error is an authorization error +func IsAuthorizationError(err error) bool { + if appErr, ok := err.(*AppError); ok { + return appErr.Code == ErrForbidden || appErr.Code == ErrInsufficientPermissions + } + return false +} + +// IsConflictError checks if an error is a conflict error +func IsConflictError(err error) bool { + if appErr, ok := err.(*AppError); ok { + return appErr.Code == ErrAlreadyExists || appErr.Code == ErrConflict + } + return false +} + +// IsInternalError checks if an error is an internal server error +func IsInternalError(err error) bool { + if appErr, ok := err.(*AppError); ok { + return appErr.Code == ErrInternal || appErr.Code == ErrDatabase || + appErr.Code == ErrExternal || appErr.Code == ErrTokenCreationFailed + } + return false +} + +// IsConfigurationError checks if an error is a configuration error +func IsConfigurationError(err error) bool { + if appErr, ok := err.(*AppError); ok { + // Configuration errors are typically mapped to internal errors + return appErr.Code == ErrInternal && appErr.Message != "" + } + return false +} diff --git a/internal/handlers/saml.go b/internal/handlers/saml.go new file mode 100644 index 0000000..860226e --- /dev/null +++ b/internal/handlers/saml.go @@ -0,0 +1,352 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/gorilla/mux" + "go.uber.org/zap" + + "github.com/kms/api-key-service/internal/auth" + "github.com/kms/api-key-service/internal/config" + "github.com/kms/api-key-service/internal/domain" + "github.com/kms/api-key-service/internal/errors" + "github.com/kms/api-key-service/internal/services" +) + +// SAMLHandler handles SAML authentication endpoints +type SAMLHandler struct { + samlProvider *auth.SAMLProvider + sessionService services.SessionService + authService services.AuthenticationService + tokenService services.TokenService + config config.ConfigProvider + logger *zap.Logger +} + +// NewSAMLHandler creates a new SAML handler +func NewSAMLHandler( + config config.ConfigProvider, + sessionService services.SessionService, + authService services.AuthenticationService, + tokenService services.TokenService, + logger *zap.Logger, +) (*SAMLHandler, error) { + samlProvider, err := auth.NewSAMLProvider(config, logger) + if err != nil { + return nil, err + } + + return &SAMLHandler{ + samlProvider: samlProvider, + sessionService: sessionService, + authService: authService, + config: config, + logger: logger, + }, nil +} + +// RegisterRoutes registers SAML routes +func (h *SAMLHandler) RegisterRoutes(router *mux.Router) { + // SAML endpoints + router.HandleFunc("/auth/saml/login", h.InitiateSAMLLogin).Methods("GET") + router.HandleFunc("/auth/saml/acs", h.HandleSAMLResponse).Methods("POST") + router.HandleFunc("/auth/saml/metadata", h.GetServiceProviderMetadata).Methods("GET") + router.HandleFunc("/auth/saml/slo", h.HandleSingleLogout).Methods("GET", "POST") +} + +// InitiateSAMLLogin initiates SAML authentication +func (h *SAMLHandler) InitiateSAMLLogin(w http.ResponseWriter, r *http.Request) { + if !h.config.GetBool("SAML_ENABLED") { + h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled")) + return + } + + // Get query parameters + appID := r.URL.Query().Get("app_id") + redirectURL := r.URL.Query().Get("redirect_url") + + if appID == "" { + h.writeErrorResponse(w, errors.NewValidationError("app_id parameter is required")) + return + } + + // Generate relay state with app_id and redirect_url + relayState := appID + if redirectURL != "" { + relayState += "|" + redirectURL + } + + h.logger.Debug("Initiating SAML login", + zap.String("app_id", appID), + zap.String("redirect_url", redirectURL)) + + // Generate SAML authentication request + authURL, requestID, err := h.samlProvider.GenerateAuthRequest(r.Context(), relayState) + if err != nil { + h.logger.Error("Failed to generate SAML auth request", zap.Error(err)) + h.writeErrorResponse(w, err) + return + } + + // Store request ID in session/cache for validation + // In production, you should store this securely + h.logger.Debug("Generated SAML auth request", + zap.String("request_id", requestID), + zap.String("auth_url", authURL)) + + // Redirect to IdP + http.Redirect(w, r, authURL, http.StatusFound) +} + +// HandleSAMLResponse handles SAML assertion consumer service (ACS) +func (h *SAMLHandler) HandleSAMLResponse(w http.ResponseWriter, r *http.Request) { + if !h.config.GetBool("SAML_ENABLED") { + h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled")) + return + } + + h.logger.Debug("Handling SAML response") + + // Parse form data + if err := r.ParseForm(); err != nil { + h.writeErrorResponse(w, errors.NewValidationError("Failed to parse form data").WithInternal(err)) + return + } + + samlResponse := r.FormValue("SAMLResponse") + relayState := r.FormValue("RelayState") + + if samlResponse == "" { + h.writeErrorResponse(w, errors.NewValidationError("SAMLResponse is required")) + return + } + + h.logger.Debug("Processing SAML response", zap.String("relay_state", relayState)) + + // Process SAML response + // In production, you should retrieve and validate the original request ID + authContext, err := h.samlProvider.ProcessSAMLResponse(r.Context(), samlResponse, "") + if err != nil { + h.logger.Error("Failed to process SAML response", zap.Error(err)) + h.writeErrorResponse(w, err) + return + } + + // Parse relay state to get app_id and redirect_url + appID, redirectURL := h.parseRelayState(relayState) + if appID == "" { + h.writeErrorResponse(w, errors.NewValidationError("Invalid relay state: missing app_id")) + return + } + + // Create user session + sessionReq := &domain.CreateSessionRequest{ + UserID: authContext.UserID, + AppID: appID, + SessionType: domain.SessionTypeWeb, + IPAddress: h.getClientIP(r), + UserAgent: r.UserAgent(), + ExpiresAt: time.Now().Add(8 * time.Hour), // 8 hour session + Permissions: authContext.Permissions, + Claims: authContext.Claims, + } + + session, err := h.sessionService.CreateSession(r.Context(), sessionReq) + if err != nil { + h.logger.Error("Failed to create session", zap.Error(err)) + h.writeErrorResponse(w, err) + return + } + + // Generate JWT token for the session using the existing token service + userToken := &domain.UserToken{ + AppID: appID, + UserID: authContext.UserID, + Permissions: authContext.Permissions, + IssuedAt: time.Now(), + ExpiresAt: session.ExpiresAt, + MaxValidAt: session.ExpiresAt, + TokenType: domain.TokenTypeUser, + Claims: authContext.Claims, + } + + tokenString, err := h.authService.GenerateJWTToken(r.Context(), userToken) + if err != nil { + h.logger.Error("Failed to create JWT token", zap.Error(err)) + h.writeErrorResponse(w, err) + return + } + + h.logger.Debug("SAML authentication successful", + zap.String("user_id", authContext.UserID), + zap.String("session_id", session.ID.String())) + + // If redirect URL is provided, redirect with token + if redirectURL != "" { + // Add token as query parameter or fragment + redirectURL += "?token=" + tokenString + http.Redirect(w, r, redirectURL, http.StatusFound) + return + } + + // Otherwise, return JSON response + response := map[string]interface{}{ + "success": true, + "token": tokenString, + "user": map[string]interface{}{ + "id": authContext.UserID, + "email": authContext.Claims["email"], + "name": authContext.Claims["name"], + }, + "session_id": session.ID.String(), + "expires_at": session.ExpiresAt, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +// GetServiceProviderMetadata returns SP metadata XML +func (h *SAMLHandler) GetServiceProviderMetadata(w http.ResponseWriter, r *http.Request) { + if !h.config.GetBool("SAML_ENABLED") { + h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled")) + return + } + + h.logger.Debug("Generating SP metadata") + + metadata, err := h.samlProvider.GenerateServiceProviderMetadata() + if err != nil { + h.logger.Error("Failed to generate SP metadata", zap.Error(err)) + h.writeErrorResponse(w, err) + return + } + + w.Header().Set("Content-Type", "application/xml") + w.Write([]byte(metadata)) +} + +// HandleSingleLogout handles SAML single logout +func (h *SAMLHandler) HandleSingleLogout(w http.ResponseWriter, r *http.Request) { + if !h.config.GetBool("SAML_ENABLED") { + h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled")) + return + } + + h.logger.Debug("Handling SAML single logout") + + // Get session ID from query parameter or form + sessionID := r.URL.Query().Get("session_id") + if sessionID == "" && r.Method == "POST" { + r.ParseForm() + sessionID = r.FormValue("session_id") + } + + if sessionID != "" { + // Revoke specific session + h.logger.Debug("Revoking session", zap.String("session_id", sessionID)) + // Implementation would depend on how you store session IDs + // For now, we'll just log it + } + + // In a full implementation, you would: + // 1. Parse the SAML LogoutRequest + // 2. Validate the request + // 3. Revoke the user's sessions + // 4. Generate a LogoutResponse + // 5. Redirect back to the IdP + + // For now, return a simple success response + response := map[string]interface{}{ + "success": true, + "message": "Logout successful", + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +// parseRelayState parses the relay state to extract app_id and redirect_url +func (h *SAMLHandler) parseRelayState(relayState string) (appID, redirectURL string) { + if relayState == "" { + return "", "" + } + + // RelayState format: "app_id|redirect_url" or just "app_id" + parts := []string{relayState} + if len(relayState) > 0 && relayState[0] != '|' { + // Split on first pipe character + for i, char := range relayState { + if char == '|' { + parts = []string{relayState[:i], relayState[i+1:]} + break + } + } + } + + appID = parts[0] + if len(parts) > 1 { + redirectURL = parts[1] + } + + return appID, redirectURL +} + +// getClientIP extracts the client IP address from the request +func (h *SAMLHandler) getClientIP(r *http.Request) string { + // Check X-Forwarded-For header first + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP if multiple are present + if idx := len(xff); idx > 0 { + for i, char := range xff { + if char == ',' { + return xff[:i] + } + } + return xff + } + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + + // Fall back to RemoteAddr + return r.RemoteAddr +} + +// writeErrorResponse writes an error response +func (h *SAMLHandler) writeErrorResponse(w http.ResponseWriter, err error) { + var statusCode int + var errorCode string + + switch { + case errors.IsValidationError(err): + statusCode = http.StatusBadRequest + errorCode = "VALIDATION_ERROR" + case errors.IsAuthenticationError(err): + statusCode = http.StatusUnauthorized + errorCode = "AUTHENTICATION_ERROR" + case errors.IsConfigurationError(err): + statusCode = http.StatusServiceUnavailable + errorCode = "CONFIGURATION_ERROR" + default: + statusCode = http.StatusInternalServerError + errorCode = "INTERNAL_ERROR" + } + + response := map[string]interface{}{ + "success": false, + "error": map[string]interface{}{ + "code": errorCode, + "message": err.Error(), + }, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(response) +} diff --git a/internal/middleware/security.go b/internal/middleware/security.go index 2624b1b..f10cdc1 100644 --- a/internal/middleware/security.go +++ b/internal/middleware/security.go @@ -2,7 +2,6 @@ package middleware import ( "context" - "fmt" "net" "net/http" "strings" @@ -14,7 +13,6 @@ import ( "github.com/kms/api-key-service/internal/cache" "github.com/kms/api-key-service/internal/config" - "github.com/kms/api-key-service/internal/errors" ) // SecurityMiddleware provides various security features @@ -408,8 +406,6 @@ func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool { // GetSecurityMetrics returns security-related metrics func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} { - ctx := context.Background() - // This is a simplified version - in production you'd want more comprehensive metrics metrics := map[string]interface{}{ "active_rate_limiters": len(s.rateLimiters), diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index 0958294..065c23c 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -104,6 +104,57 @@ type GrantedPermissionRepository interface { HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error) } +// SessionRepository defines the interface for user session data operations +type SessionRepository interface { + // Create creates a new user session + Create(ctx context.Context, session *domain.UserSession) error + + // GetByID retrieves a session by its ID + GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) + + // GetByUserID retrieves all sessions for a user + GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) + + // GetByUserAndApp retrieves sessions for a specific user and application + GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) + + // GetActiveByUserID retrieves all active sessions for a user + GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) + + // List retrieves sessions with filtering and pagination + List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) + + // Update updates an existing session + Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error + + // UpdateActivity updates the last activity timestamp for a session + UpdateActivity(ctx context.Context, sessionID uuid.UUID) error + + // Revoke revokes a session + Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error + + // RevokeAllByUser revokes all sessions for a user + RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error + + // RevokeAllByUserAndApp revokes all sessions for a user and application + RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error + + // ExpireOldSessions marks expired sessions as expired + ExpireOldSessions(ctx context.Context) (int, error) + + // DeleteExpiredSessions removes expired sessions older than the specified duration + DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error) + + // Exists checks if a session exists + Exists(ctx context.Context, sessionID uuid.UUID) (bool, error) + + // GetSessionCount returns the total number of sessions for a user + GetSessionCount(ctx context.Context, userID string) (int, error) + + // GetActiveSessionCount returns the number of active sessions for a user + GetActiveSessionCount(ctx context.Context, userID string) (int, error) +} + // DatabaseProvider defines the interface for database operations type DatabaseProvider interface { // GetDB returns the underlying database connection diff --git a/internal/repository/postgres/session_repository.go b/internal/repository/postgres/session_repository.go new file mode 100644 index 0000000..bb5000f --- /dev/null +++ b/internal/repository/postgres/session_repository.go @@ -0,0 +1,624 @@ +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/jmoiron/sqlx" + "go.uber.org/zap" + + "github.com/kms/api-key-service/internal/domain" + "github.com/kms/api-key-service/internal/errors" + "github.com/kms/api-key-service/internal/repository" +) + +// sessionRepository implements the SessionRepository interface +type sessionRepository struct { + db *sqlx.DB + logger *zap.Logger +} + +// NewSessionRepository creates a new session repository +func NewSessionRepository(db *sqlx.DB, logger *zap.Logger) repository.SessionRepository { + return &sessionRepository{ + db: db, + logger: logger, + } +} + +// Create creates a new user session +func (r *sessionRepository) Create(ctx context.Context, session *domain.UserSession) error { + r.logger.Debug("Creating new session", + zap.String("user_id", session.UserID), + zap.String("app_id", session.AppID), + zap.String("session_type", string(session.SessionType))) + + // Generate ID if not provided + if session.ID == uuid.Nil { + session.ID = uuid.New() + } + + // Set timestamps + now := time.Now() + session.CreatedAt = now + session.UpdatedAt = now + session.LastActivity = now + + // Serialize metadata + metadataJSON, err := json.Marshal(session.Metadata) + if err != nil { + return errors.NewInternalError("Failed to serialize session metadata").WithInternal(err) + } + + query := ` + INSERT INTO user_sessions ( + id, user_id, app_id, session_type, status, access_token, + refresh_token, id_token, ip_address, user_agent, + last_activity, expires_at, created_at, updated_at, metadata + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15 + )` + + _, err = r.db.ExecContext(ctx, query, + session.ID, + session.UserID, + session.AppID, + session.SessionType, + session.Status, + session.AccessToken, + session.RefreshToken, + session.IDToken, + session.IPAddress, + session.UserAgent, + session.LastActivity, + session.ExpiresAt, + session.CreatedAt, + session.UpdatedAt, + metadataJSON, + ) + + if err != nil { + r.logger.Error("Failed to create session", zap.Error(err)) + return errors.NewInternalError("Failed to create session").WithInternal(err) + } + + r.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String())) + return nil +} + +// GetByID retrieves a session by its ID +func (r *sessionRepository) GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) { + r.logger.Debug("Getting session by ID", zap.String("session_id", sessionID.String())) + + query := ` + SELECT id, user_id, app_id, session_type, status, access_token, + refresh_token, id_token, ip_address, user_agent, + last_activity, expires_at, created_at, updated_at, + revoked_at, revoked_by, metadata + FROM user_sessions + WHERE id = $1` + + var session domain.UserSession + var metadataJSON []byte + var revokedAt sql.NullTime + var revokedBy sql.NullString + + err := r.db.QueryRowContext(ctx, query, sessionID).Scan( + &session.ID, + &session.UserID, + &session.AppID, + &session.SessionType, + &session.Status, + &session.AccessToken, + &session.RefreshToken, + &session.IDToken, + &session.IPAddress, + &session.UserAgent, + &session.LastActivity, + &session.ExpiresAt, + &session.CreatedAt, + &session.UpdatedAt, + &revokedAt, + &revokedBy, + &metadataJSON, + ) + + if err != nil { + if err == sql.ErrNoRows { + return nil, errors.NewNotFoundError("Session not found") + } + r.logger.Error("Failed to get session by ID", zap.Error(err)) + return nil, errors.NewInternalError("Failed to retrieve session").WithInternal(err) + } + + // Handle nullable fields + if revokedAt.Valid { + session.RevokedAt = &revokedAt.Time + } + if revokedBy.Valid { + session.RevokedBy = &revokedBy.String + } + + // Deserialize metadata + if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil { + r.logger.Warn("Failed to deserialize session metadata", zap.Error(err)) + session.Metadata = domain.SessionMetadata{} // Use empty metadata on error + } + + r.logger.Debug("Session retrieved successfully", zap.String("session_id", sessionID.String())) + return &session, nil +} + +// GetByUserID retrieves all sessions for a user +func (r *sessionRepository) GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) { + r.logger.Debug("Getting sessions by user ID", zap.String("user_id", userID)) + + query := ` + SELECT id, user_id, app_id, session_type, status, access_token, + refresh_token, id_token, ip_address, user_agent, + last_activity, expires_at, created_at, updated_at, + revoked_at, revoked_by, metadata + FROM user_sessions + WHERE user_id = $1 + ORDER BY created_at DESC` + + return r.scanSessions(ctx, query, userID) +} + +// GetByUserAndApp retrieves sessions for a specific user and application +func (r *sessionRepository) GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) { + r.logger.Debug("Getting sessions by user and app", + zap.String("user_id", userID), + zap.String("app_id", appID)) + + query := ` + SELECT id, user_id, app_id, session_type, status, access_token, + refresh_token, id_token, ip_address, user_agent, + last_activity, expires_at, created_at, updated_at, + revoked_at, revoked_by, metadata + FROM user_sessions + WHERE user_id = $1 AND app_id = $2 + ORDER BY created_at DESC` + + return r.scanSessions(ctx, query, userID, appID) +} + +// GetActiveByUserID retrieves all active sessions for a user +func (r *sessionRepository) GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) { + r.logger.Debug("Getting active sessions by user ID", zap.String("user_id", userID)) + + query := ` + SELECT id, user_id, app_id, session_type, status, access_token, + refresh_token, id_token, ip_address, user_agent, + last_activity, expires_at, created_at, updated_at, + revoked_at, revoked_by, metadata + FROM user_sessions + WHERE user_id = $1 AND status = $2 AND expires_at > NOW() + ORDER BY last_activity DESC` + + return r.scanSessions(ctx, query, userID, domain.SessionStatusActive) +} + +// List retrieves sessions with filtering and pagination +func (r *sessionRepository) List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) { + r.logger.Debug("Listing sessions with filters", + zap.String("user_id", req.UserID), + zap.String("app_id", req.AppID), + zap.Int("limit", req.Limit), + zap.Int("offset", req.Offset)) + + // Build WHERE clause dynamically + whereClause := "WHERE 1=1" + args := []interface{}{} + argIndex := 1 + + if req.UserID != "" { + whereClause += fmt.Sprintf(" AND user_id = $%d", argIndex) + args = append(args, req.UserID) + argIndex++ + } + + if req.AppID != "" { + whereClause += fmt.Sprintf(" AND app_id = $%d", argIndex) + args = append(args, req.AppID) + argIndex++ + } + + if req.Status != nil { + whereClause += fmt.Sprintf(" AND status = $%d", argIndex) + args = append(args, *req.Status) + argIndex++ + } + + if req.SessionType != nil { + whereClause += fmt.Sprintf(" AND session_type = $%d", argIndex) + args = append(args, *req.SessionType) + argIndex++ + } + + if req.TenantID != "" { + whereClause += fmt.Sprintf(" AND metadata->>'tenant_id' = $%d", argIndex) + args = append(args, req.TenantID) + argIndex++ + } + + // Get total count + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM user_sessions %s", whereClause) + var total int + err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total) + if err != nil { + r.logger.Error("Failed to get session count", zap.Error(err)) + return nil, errors.NewInternalError("Failed to count sessions").WithInternal(err) + } + + // Get sessions with pagination + query := fmt.Sprintf(` + SELECT id, user_id, app_id, session_type, status, access_token, + refresh_token, id_token, ip_address, user_agent, + last_activity, expires_at, created_at, updated_at, + revoked_at, revoked_by, metadata + FROM user_sessions + %s + ORDER BY created_at DESC + LIMIT $%d OFFSET $%d`, whereClause, argIndex, argIndex+1) + + args = append(args, req.Limit, req.Offset) + sessions, err := r.scanSessions(ctx, query, args...) + if err != nil { + return nil, err + } + + return &domain.SessionListResponse{ + Sessions: sessions, + Total: total, + Limit: req.Limit, + Offset: req.Offset, + }, nil +} + +// Update updates an existing session +func (r *sessionRepository) Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error { + r.logger.Debug("Updating session", zap.String("session_id", sessionID.String())) + + // Build UPDATE clause dynamically + setParts := []string{"updated_at = NOW()"} + args := []interface{}{} + argIndex := 1 + + if updates.Status != nil { + setParts = append(setParts, fmt.Sprintf("status = $%d", argIndex)) + args = append(args, *updates.Status) + argIndex++ + } + + if updates.LastActivity != nil { + setParts = append(setParts, fmt.Sprintf("last_activity = $%d", argIndex)) + args = append(args, *updates.LastActivity) + argIndex++ + } + + if updates.ExpiresAt != nil { + setParts = append(setParts, fmt.Sprintf("expires_at = $%d", argIndex)) + args = append(args, *updates.ExpiresAt) + argIndex++ + } + + if updates.IPAddress != nil { + setParts = append(setParts, fmt.Sprintf("ip_address = $%d", argIndex)) + args = append(args, *updates.IPAddress) + argIndex++ + } + + if updates.UserAgent != nil { + setParts = append(setParts, fmt.Sprintf("user_agent = $%d", argIndex)) + args = append(args, *updates.UserAgent) + argIndex++ + } + + if len(setParts) == 1 { + return errors.NewValidationError("No fields to update") + } + + // Build the complete query + setClause := fmt.Sprintf("%s", setParts[0]) + for i := 1; i < len(setParts); i++ { + setClause += fmt.Sprintf(", %s", setParts[i]) + } + + query := fmt.Sprintf("UPDATE user_sessions SET %s WHERE id = $%d", setClause, argIndex) + args = append(args, sessionID) + + result, err := r.db.ExecContext(ctx, query, args...) + if err != nil { + r.logger.Error("Failed to update session", zap.Error(err)) + return errors.NewInternalError("Failed to update session").WithInternal(err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return errors.NewInternalError("Failed to get affected rows").WithInternal(err) + } + + if rowsAffected == 0 { + return errors.NewNotFoundError("Session not found") + } + + r.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String())) + return nil +} + +// UpdateActivity updates the last activity timestamp for a session +func (r *sessionRepository) UpdateActivity(ctx context.Context, sessionID uuid.UUID) error { + r.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String())) + + query := `UPDATE user_sessions SET last_activity = NOW(), updated_at = NOW() WHERE id = $1` + + result, err := r.db.ExecContext(ctx, query, sessionID) + if err != nil { + r.logger.Error("Failed to update session activity", zap.Error(err)) + return errors.NewInternalError("Failed to update session activity").WithInternal(err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return errors.NewInternalError("Failed to get affected rows").WithInternal(err) + } + + if rowsAffected == 0 { + return errors.NewNotFoundError("Session not found") + } + + return nil +} + +// Revoke revokes a session +func (r *sessionRepository) Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error { + r.logger.Debug("Revoking session", + zap.String("session_id", sessionID.String()), + zap.String("revoked_by", revokedBy)) + + query := ` + UPDATE user_sessions + SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW() + WHERE id = $3` + + result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, sessionID) + if err != nil { + r.logger.Error("Failed to revoke session", zap.Error(err)) + return errors.NewInternalError("Failed to revoke session").WithInternal(err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return errors.NewInternalError("Failed to get affected rows").WithInternal(err) + } + + if rowsAffected == 0 { + return errors.NewNotFoundError("Session not found") + } + + r.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String())) + return nil +} + +// RevokeAllByUser revokes all sessions for a user +func (r *sessionRepository) RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error { + r.logger.Debug("Revoking all sessions for user", + zap.String("user_id", userID), + zap.String("revoked_by", revokedBy)) + + query := ` + UPDATE user_sessions + SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW() + WHERE user_id = $3 AND status = $4` + + result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, domain.SessionStatusActive) + if err != nil { + r.logger.Error("Failed to revoke user sessions", zap.Error(err)) + return errors.NewInternalError("Failed to revoke user sessions").WithInternal(err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return errors.NewInternalError("Failed to get affected rows").WithInternal(err) + } + + r.logger.Debug("User sessions revoked", + zap.String("user_id", userID), + zap.Int64("sessions_revoked", rowsAffected)) + return nil +} + +// RevokeAllByUserAndApp revokes all sessions for a user and application +func (r *sessionRepository) RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error { + r.logger.Debug("Revoking all sessions for user and app", + zap.String("user_id", userID), + zap.String("app_id", appID), + zap.String("revoked_by", revokedBy)) + + query := ` + UPDATE user_sessions + SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW() + WHERE user_id = $3 AND app_id = $4 AND status = $5` + + result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, appID, domain.SessionStatusActive) + if err != nil { + r.logger.Error("Failed to revoke user app sessions", zap.Error(err)) + return errors.NewInternalError("Failed to revoke user app sessions").WithInternal(err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return errors.NewInternalError("Failed to get affected rows").WithInternal(err) + } + + r.logger.Debug("User app sessions revoked", + zap.String("user_id", userID), + zap.String("app_id", appID), + zap.Int64("sessions_revoked", rowsAffected)) + return nil +} + +// ExpireOldSessions marks expired sessions as expired +func (r *sessionRepository) ExpireOldSessions(ctx context.Context) (int, error) { + r.logger.Debug("Expiring old sessions") + + query := ` + UPDATE user_sessions + SET status = $1, updated_at = NOW() + WHERE expires_at < NOW() AND status = $2` + + result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, domain.SessionStatusActive) + if err != nil { + r.logger.Error("Failed to expire old sessions", zap.Error(err)) + return 0, errors.NewInternalError("Failed to expire old sessions").WithInternal(err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err) + } + + r.logger.Debug("Old sessions expired", zap.Int64("sessions_expired", rowsAffected)) + return int(rowsAffected), nil +} + +// DeleteExpiredSessions removes expired sessions older than the specified duration +func (r *sessionRepository) DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error) { + r.logger.Debug("Deleting expired sessions", zap.Duration("older_than", olderThan)) + + cutoffTime := time.Now().Add(-olderThan) + query := `DELETE FROM user_sessions WHERE status = $1 AND updated_at < $2` + + result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, cutoffTime) + if err != nil { + r.logger.Error("Failed to delete expired sessions", zap.Error(err)) + return 0, errors.NewInternalError("Failed to delete expired sessions").WithInternal(err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err) + } + + r.logger.Debug("Expired sessions deleted", zap.Int64("sessions_deleted", rowsAffected)) + return int(rowsAffected), nil +} + +// Exists checks if a session exists +func (r *sessionRepository) Exists(ctx context.Context, sessionID uuid.UUID) (bool, error) { + r.logger.Debug("Checking if session exists", zap.String("session_id", sessionID.String())) + + query := `SELECT EXISTS(SELECT 1 FROM user_sessions WHERE id = $1)` + + var exists bool + err := r.db.QueryRowContext(ctx, query, sessionID).Scan(&exists) + if err != nil { + r.logger.Error("Failed to check session existence", zap.Error(err)) + return false, errors.NewInternalError("Failed to check session existence").WithInternal(err) + } + + return exists, nil +} + +// GetSessionCount returns the total number of sessions for a user +func (r *sessionRepository) GetSessionCount(ctx context.Context, userID string) (int, error) { + r.logger.Debug("Getting session count for user", zap.String("user_id", userID)) + + query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1` + + var count int + err := r.db.QueryRowContext(ctx, query, userID).Scan(&count) + if err != nil { + r.logger.Error("Failed to get session count", zap.Error(err)) + return 0, errors.NewInternalError("Failed to get session count").WithInternal(err) + } + + return count, nil +} + +// GetActiveSessionCount returns the number of active sessions for a user +func (r *sessionRepository) GetActiveSessionCount(ctx context.Context, userID string) (int, error) { + r.logger.Debug("Getting active session count for user", zap.String("user_id", userID)) + + query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1 AND status = $2 AND expires_at > NOW()` + + var count int + err := r.db.QueryRowContext(ctx, query, userID, domain.SessionStatusActive).Scan(&count) + if err != nil { + r.logger.Error("Failed to get active session count", zap.Error(err)) + return 0, errors.NewInternalError("Failed to get active session count").WithInternal(err) + } + + return count, nil +} + +// scanSessions is a helper method to scan multiple sessions from query results +func (r *sessionRepository) scanSessions(ctx context.Context, query string, args ...interface{}) ([]*domain.UserSession, error) { + rows, err := r.db.QueryContext(ctx, query, args...) + if err != nil { + r.logger.Error("Failed to execute session query", zap.Error(err)) + return nil, errors.NewInternalError("Failed to retrieve sessions").WithInternal(err) + } + defer rows.Close() + + var sessions []*domain.UserSession + for rows.Next() { + var session domain.UserSession + var metadataJSON []byte + var revokedAt sql.NullTime + var revokedBy sql.NullString + + err := rows.Scan( + &session.ID, + &session.UserID, + &session.AppID, + &session.SessionType, + &session.Status, + &session.AccessToken, + &session.RefreshToken, + &session.IDToken, + &session.IPAddress, + &session.UserAgent, + &session.LastActivity, + &session.ExpiresAt, + &session.CreatedAt, + &session.UpdatedAt, + &revokedAt, + &revokedBy, + &metadataJSON, + ) + + if err != nil { + r.logger.Error("Failed to scan session row", zap.Error(err)) + return nil, errors.NewInternalError("Failed to scan session data").WithInternal(err) + } + + // Handle nullable fields + if revokedAt.Valid { + session.RevokedAt = &revokedAt.Time + } + if revokedBy.Valid { + session.RevokedBy = &revokedBy.String + } + + // Deserialize metadata + if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil { + r.logger.Warn("Failed to deserialize session metadata", zap.Error(err)) + session.Metadata = domain.SessionMetadata{} // Use empty metadata on error + } + + sessions = append(sessions, &session) + } + + if err := rows.Err(); err != nil { + r.logger.Error("Error iterating session rows", zap.Error(err)) + return nil, errors.NewInternalError("Failed to iterate session results").WithInternal(err) + } + + return sessions, nil +} diff --git a/internal/services/interfaces.go b/internal/services/interfaces.go index 598cd45..c9a22b2 100644 --- a/internal/services/interfaces.go +++ b/internal/services/interfaces.go @@ -67,3 +67,54 @@ type AuthenticationService interface { // RefreshJWTToken refreshes an existing JWT token RefreshJWTToken(ctx context.Context, tokenString string, newExpiration time.Time) (string, error) } + +// SessionService defines the interface for session management business logic +type SessionService interface { + // CreateSession creates a new user session + CreateSession(ctx context.Context, req *domain.CreateSessionRequest) (*domain.UserSession, error) + + // GetSession retrieves a session by its ID + GetSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) + + // GetUserSessions retrieves all sessions for a user + GetUserSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) + + // GetUserAppSessions retrieves sessions for a specific user and application + GetUserAppSessions(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) + + // GetActiveSessions retrieves all active sessions for a user + GetActiveSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) + + // ListSessions retrieves sessions with filtering and pagination + ListSessions(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) + + // UpdateSession updates an existing session + UpdateSession(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error + + // UpdateSessionActivity updates the last activity timestamp for a session + UpdateSessionActivity(ctx context.Context, sessionID uuid.UUID) error + + // RevokeSession revokes a specific session + RevokeSession(ctx context.Context, sessionID uuid.UUID, revokedBy string) error + + // RevokeUserSessions revokes all sessions for a user + RevokeUserSessions(ctx context.Context, userID string, revokedBy string) error + + // RevokeUserAppSessions revokes all sessions for a user and application + RevokeUserAppSessions(ctx context.Context, userID, appID string, revokedBy string) error + + // ValidateSession validates if a session is active and valid + ValidateSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) + + // RefreshSession refreshes a session's expiration time + RefreshSession(ctx context.Context, sessionID uuid.UUID, newExpiration time.Time) error + + // CleanupExpiredSessions marks expired sessions as expired and optionally deletes old ones + CleanupExpiredSessions(ctx context.Context, deleteOlderThan *time.Duration) (expired int, deleted int, err error) + + // GetSessionStats returns session statistics for a user + GetSessionStats(ctx context.Context, userID string) (total int, active int, err error) + + // CreateOAuth2Session creates a session from OAuth2 authentication flow + CreateOAuth2Session(ctx context.Context, userID, appID string, tokenResponse *domain.TokenResponse, userInfo *domain.UserInfo, sessionType domain.SessionType, ipAddress, userAgent string) (*domain.UserSession, error) +} diff --git a/internal/services/session_service.go b/internal/services/session_service.go new file mode 100644 index 0000000..2f5215c --- /dev/null +++ b/internal/services/session_service.go @@ -0,0 +1,414 @@ +package services + +import ( + "context" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" + + "github.com/kms/api-key-service/internal/config" + "github.com/kms/api-key-service/internal/domain" + "github.com/kms/api-key-service/internal/errors" + "github.com/kms/api-key-service/internal/repository" +) + +// sessionService implements the SessionService interface +type sessionService struct { + sessionRepo repository.SessionRepository + appRepo repository.ApplicationRepository + config config.ConfigProvider + logger *zap.Logger +} + +// NewSessionService creates a new session service +func NewSessionService( + sessionRepo repository.SessionRepository, + appRepo repository.ApplicationRepository, + config config.ConfigProvider, + logger *zap.Logger, +) SessionService { + return &sessionService{ + sessionRepo: sessionRepo, + appRepo: appRepo, + config: config, + logger: logger, + } +} + +// CreateSession creates a new user session +func (s *sessionService) CreateSession(ctx context.Context, req *domain.CreateSessionRequest) (*domain.UserSession, error) { + s.logger.Debug("Creating new session", + zap.String("user_id", req.UserID), + zap.String("app_id", req.AppID), + zap.String("session_type", string(req.SessionType))) + + // Validate application exists + app, err := s.appRepo.GetByID(ctx, req.AppID) + if err != nil { + if errors.IsNotFound(err) { + return nil, errors.NewValidationError("Application not found") + } + return nil, err + } + + // Check if application supports user tokens + supportsUser := false + for _, appType := range app.Type { + if appType == domain.ApplicationTypeUser { + supportsUser = true + break + } + } + if !supportsUser { + return nil, errors.NewValidationError("Application does not support user sessions") + } + + // Create session object + session := &domain.UserSession{ + ID: uuid.New(), + UserID: req.UserID, + AppID: req.AppID, + SessionType: req.SessionType, + Status: domain.SessionStatusActive, + IPAddress: req.IPAddress, + UserAgent: req.UserAgent, + ExpiresAt: req.ExpiresAt, + Metadata: domain.SessionMetadata{ + TenantID: req.TenantID, + Permissions: req.Permissions, + Claims: req.Claims, + LoginMethod: "oauth2", + }, + } + + // Create session in repository + if err := s.sessionRepo.Create(ctx, session); err != nil { + s.logger.Error("Failed to create session", zap.Error(err)) + return nil, err + } + + s.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String())) + return session, nil +} + +// GetSession retrieves a session by its ID +func (s *sessionService) GetSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) { + s.logger.Debug("Getting session", zap.String("session_id", sessionID.String())) + + session, err := s.sessionRepo.GetByID(ctx, sessionID) + if err != nil { + return nil, err + } + + return session, nil +} + +// GetUserSessions retrieves all sessions for a user +func (s *sessionService) GetUserSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) { + s.logger.Debug("Getting user sessions", zap.String("user_id", userID)) + + sessions, err := s.sessionRepo.GetByUserID(ctx, userID) + if err != nil { + return nil, err + } + + return sessions, nil +} + +// GetUserAppSessions retrieves sessions for a specific user and application +func (s *sessionService) GetUserAppSessions(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) { + s.logger.Debug("Getting user app sessions", + zap.String("user_id", userID), + zap.String("app_id", appID)) + + sessions, err := s.sessionRepo.GetByUserAndApp(ctx, userID, appID) + if err != nil { + return nil, err + } + + return sessions, nil +} + +// GetActiveSessions retrieves all active sessions for a user +func (s *sessionService) GetActiveSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) { + s.logger.Debug("Getting active sessions", zap.String("user_id", userID)) + + sessions, err := s.sessionRepo.GetActiveByUserID(ctx, userID) + if err != nil { + return nil, err + } + + return sessions, nil +} + +// ListSessions retrieves sessions with filtering and pagination +func (s *sessionService) ListSessions(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) { + s.logger.Debug("Listing sessions", + zap.String("user_id", req.UserID), + zap.String("app_id", req.AppID), + zap.Int("limit", req.Limit), + zap.Int("offset", req.Offset)) + + // Set default pagination if not provided + if req.Limit <= 0 { + req.Limit = 50 + } + if req.Limit > 100 { + req.Limit = 100 + } + + response, err := s.sessionRepo.List(ctx, req) + if err != nil { + return nil, err + } + + return response, nil +} + +// UpdateSession updates an existing session +func (s *sessionService) UpdateSession(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error { + s.logger.Debug("Updating session", zap.String("session_id", sessionID.String())) + + // Validate session exists + _, err := s.sessionRepo.GetByID(ctx, sessionID) + if err != nil { + return err + } + + // Update session + if err := s.sessionRepo.Update(ctx, sessionID, updates); err != nil { + return err + } + + s.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String())) + return nil +} + +// UpdateSessionActivity updates the last activity timestamp for a session +func (s *sessionService) UpdateSessionActivity(ctx context.Context, sessionID uuid.UUID) error { + s.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String())) + + if err := s.sessionRepo.UpdateActivity(ctx, sessionID); err != nil { + return err + } + + return nil +} + +// RevokeSession revokes a specific session +func (s *sessionService) RevokeSession(ctx context.Context, sessionID uuid.UUID, revokedBy string) error { + s.logger.Debug("Revoking session", + zap.String("session_id", sessionID.String()), + zap.String("revoked_by", revokedBy)) + + // Validate session exists and is active + session, err := s.sessionRepo.GetByID(ctx, sessionID) + if err != nil { + return err + } + + if session.Status != domain.SessionStatusActive { + return errors.NewValidationError("Session is not active") + } + + // Revoke session + if err := s.sessionRepo.Revoke(ctx, sessionID, revokedBy); err != nil { + return err + } + + s.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String())) + return nil +} + +// RevokeUserSessions revokes all sessions for a user +func (s *sessionService) RevokeUserSessions(ctx context.Context, userID string, revokedBy string) error { + s.logger.Debug("Revoking user sessions", + zap.String("user_id", userID), + zap.String("revoked_by", revokedBy)) + + if err := s.sessionRepo.RevokeAllByUser(ctx, userID, revokedBy); err != nil { + return err + } + + s.logger.Debug("User sessions revoked successfully", zap.String("user_id", userID)) + return nil +} + +// RevokeUserAppSessions revokes all sessions for a user and application +func (s *sessionService) RevokeUserAppSessions(ctx context.Context, userID, appID string, revokedBy string) error { + s.logger.Debug("Revoking user app sessions", + zap.String("user_id", userID), + zap.String("app_id", appID), + zap.String("revoked_by", revokedBy)) + + if err := s.sessionRepo.RevokeAllByUserAndApp(ctx, userID, appID, revokedBy); err != nil { + return err + } + + s.logger.Debug("User app sessions revoked successfully", + zap.String("user_id", userID), + zap.String("app_id", appID)) + return nil +} + +// ValidateSession validates if a session is active and valid +func (s *sessionService) ValidateSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) { + s.logger.Debug("Validating session", zap.String("session_id", sessionID.String())) + + session, err := s.sessionRepo.GetByID(ctx, sessionID) + if err != nil { + return nil, err + } + + // Check if session is active + if !session.IsActive() { + if session.IsExpired() { + return nil, errors.NewAuthenticationError("Session has expired") + } + if session.IsRevoked() { + return nil, errors.NewAuthenticationError("Session has been revoked") + } + return nil, errors.NewAuthenticationError("Session is not active") + } + + // Update last activity + if err := s.sessionRepo.UpdateActivity(ctx, sessionID); err != nil { + s.logger.Warn("Failed to update session activity", zap.Error(err)) + // Don't fail validation if we can't update activity + } + + s.logger.Debug("Session validated successfully", zap.String("session_id", sessionID.String())) + return session, nil +} + +// RefreshSession refreshes a session's expiration time +func (s *sessionService) RefreshSession(ctx context.Context, sessionID uuid.UUID, newExpiration time.Time) error { + s.logger.Debug("Refreshing session", + zap.String("session_id", sessionID.String()), + zap.Time("new_expiration", newExpiration)) + + // Validate session exists and is active + session, err := s.sessionRepo.GetByID(ctx, sessionID) + if err != nil { + return err + } + + if !session.IsActive() { + return errors.NewValidationError("Cannot refresh inactive session") + } + + // Update expiration + updates := &domain.UpdateSessionRequest{ + ExpiresAt: &newExpiration, + } + + if err := s.sessionRepo.Update(ctx, sessionID, updates); err != nil { + return err + } + + s.logger.Debug("Session refreshed successfully", zap.String("session_id", sessionID.String())) + return nil +} + +// CleanupExpiredSessions marks expired sessions as expired and optionally deletes old ones +func (s *sessionService) CleanupExpiredSessions(ctx context.Context, deleteOlderThan *time.Duration) (expired int, deleted int, err error) { + s.logger.Debug("Cleaning up expired sessions") + + // Mark expired sessions + expired, err = s.sessionRepo.ExpireOldSessions(ctx) + if err != nil { + s.logger.Error("Failed to expire old sessions", zap.Error(err)) + return 0, 0, err + } + + // Delete old expired sessions if requested + if deleteOlderThan != nil { + deleted, err = s.sessionRepo.DeleteExpiredSessions(ctx, *deleteOlderThan) + if err != nil { + s.logger.Error("Failed to delete expired sessions", zap.Error(err)) + return expired, 0, err + } + } + + s.logger.Debug("Session cleanup completed", + zap.Int("expired", expired), + zap.Int("deleted", deleted)) + + return expired, deleted, nil +} + +// GetSessionStats returns session statistics for a user +func (s *sessionService) GetSessionStats(ctx context.Context, userID string) (total int, active int, err error) { + s.logger.Debug("Getting session stats", zap.String("user_id", userID)) + + total, err = s.sessionRepo.GetSessionCount(ctx, userID) + if err != nil { + return 0, 0, err + } + + active, err = s.sessionRepo.GetActiveSessionCount(ctx, userID) + if err != nil { + return 0, 0, err + } + + return total, active, nil +} + +// CreateOAuth2Session creates a session from OAuth2 authentication flow +func (s *sessionService) CreateOAuth2Session(ctx context.Context, userID, appID string, tokenResponse *domain.TokenResponse, userInfo *domain.UserInfo, sessionType domain.SessionType, ipAddress, userAgent string) (*domain.UserSession, error) { + s.logger.Debug("Creating OAuth2 session", + zap.String("user_id", userID), + zap.String("app_id", appID), + zap.String("session_type", string(sessionType))) + + // Validate application exists + app, err := s.appRepo.GetByID(ctx, appID) + if err != nil { + if errors.IsNotFound(err) { + return nil, errors.NewValidationError("Application not found") + } + return nil, err + } + + // Calculate expiration based on token response + expiresAt := time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second) + + // Use application's max token duration if shorter + maxExpiration := time.Now().Add(app.MaxTokenDuration) + if expiresAt.After(maxExpiration) { + expiresAt = maxExpiration + } + + // Create session object + session := &domain.UserSession{ + ID: uuid.New(), + UserID: userID, + AppID: appID, + SessionType: sessionType, + Status: domain.SessionStatusActive, + AccessToken: tokenResponse.AccessToken, // In production, encrypt this + RefreshToken: tokenResponse.RefreshToken, // In production, encrypt this + IDToken: tokenResponse.IDToken, // In production, encrypt this + IPAddress: ipAddress, + UserAgent: userAgent, + ExpiresAt: expiresAt, + Metadata: domain.SessionMetadata{ + LoginMethod: "oauth2", + Claims: map[string]string{ + "sub": userInfo.Sub, + "email": userInfo.Email, + "name": userInfo.Name, + }, + }, + } + + // Create session in repository + if err := s.sessionRepo.Create(ctx, session); err != nil { + s.logger.Error("Failed to create OAuth2 session", zap.Error(err)) + return nil, err + } + + s.logger.Debug("OAuth2 session created successfully", zap.String("session_id", session.ID.String())) + return session, nil +} diff --git a/migrations/002_user_sessions.down.sql b/migrations/002_user_sessions.down.sql new file mode 100644 index 0000000..df3d4f3 --- /dev/null +++ b/migrations/002_user_sessions.down.sql @@ -0,0 +1,14 @@ +-- Drop user_sessions table and related objects +DROP INDEX IF EXISTS idx_user_sessions_tenant; +DROP INDEX IF EXISTS idx_user_sessions_metadata; +DROP INDEX IF EXISTS idx_user_sessions_active_expires; +DROP INDEX IF EXISTS idx_user_sessions_active; +DROP INDEX IF EXISTS idx_user_sessions_created_at; +DROP INDEX IF EXISTS idx_user_sessions_last_activity; +DROP INDEX IF EXISTS idx_user_sessions_expires_at; +DROP INDEX IF EXISTS idx_user_sessions_status; +DROP INDEX IF EXISTS idx_user_sessions_user_app; +DROP INDEX IF EXISTS idx_user_sessions_app_id; +DROP INDEX IF EXISTS idx_user_sessions_user_id; + +DROP TABLE IF EXISTS user_sessions; diff --git a/migrations/002_user_sessions.up.sql b/migrations/002_user_sessions.up.sql new file mode 100644 index 0000000..5a73a6c --- /dev/null +++ b/migrations/002_user_sessions.up.sql @@ -0,0 +1,60 @@ +-- Create user_sessions table for session management +CREATE TABLE IF NOT EXISTS user_sessions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id VARCHAR(255) NOT NULL, + app_id VARCHAR(255) NOT NULL, + session_type VARCHAR(20) NOT NULL CHECK (session_type IN ('web', 'mobile', 'api')), + status VARCHAR(20) NOT NULL DEFAULT 'active' CHECK (status IN ('active', 'expired', 'revoked', 'suspended')), + access_token TEXT, + refresh_token TEXT, + id_token TEXT, + ip_address INET, + user_agent TEXT, + last_activity TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + revoked_at TIMESTAMP WITH TIME ZONE, + revoked_by VARCHAR(255), + metadata JSONB DEFAULT '{}', + + -- Foreign key constraint to applications table + CONSTRAINT fk_user_sessions_app_id FOREIGN KEY (app_id) REFERENCES applications(app_id) ON DELETE CASCADE +); + +-- Create indexes for performance +CREATE INDEX IF NOT EXISTS idx_user_sessions_user_id ON user_sessions(user_id); +CREATE INDEX IF NOT EXISTS idx_user_sessions_app_id ON user_sessions(app_id); +CREATE INDEX IF NOT EXISTS idx_user_sessions_user_app ON user_sessions(user_id, app_id); +CREATE INDEX IF NOT EXISTS idx_user_sessions_status ON user_sessions(status); +CREATE INDEX IF NOT EXISTS idx_user_sessions_expires_at ON user_sessions(expires_at); +CREATE INDEX IF NOT EXISTS idx_user_sessions_last_activity ON user_sessions(last_activity); +CREATE INDEX IF NOT EXISTS idx_user_sessions_created_at ON user_sessions(created_at); + +-- Create partial indexes for active sessions (most common queries) +CREATE INDEX IF NOT EXISTS idx_user_sessions_active ON user_sessions(user_id, app_id) WHERE status = 'active'; +CREATE INDEX IF NOT EXISTS idx_user_sessions_active_expires ON user_sessions(user_id, expires_at) WHERE status = 'active'; + +-- Create GIN index for metadata JSONB queries +CREATE INDEX IF NOT EXISTS idx_user_sessions_metadata ON user_sessions USING GIN (metadata); + +-- Create index for tenant-based queries (if using multi-tenancy) +CREATE INDEX IF NOT EXISTS idx_user_sessions_tenant ON user_sessions((metadata->>'tenant_id')) WHERE metadata->>'tenant_id' IS NOT NULL; + +-- Add comments for documentation +COMMENT ON TABLE user_sessions IS 'Stores user session information for authentication and session management'; +COMMENT ON COLUMN user_sessions.id IS 'Unique session identifier'; +COMMENT ON COLUMN user_sessions.user_id IS 'User identifier (email, username, or external ID)'; +COMMENT ON COLUMN user_sessions.app_id IS 'Application identifier this session belongs to'; +COMMENT ON COLUMN user_sessions.session_type IS 'Type of session: web, mobile, or api'; +COMMENT ON COLUMN user_sessions.status IS 'Current session status: active, expired, revoked, or suspended'; +COMMENT ON COLUMN user_sessions.access_token IS 'OAuth2/OIDC access token (encrypted/hashed)'; +COMMENT ON COLUMN user_sessions.refresh_token IS 'OAuth2/OIDC refresh token (encrypted/hashed)'; +COMMENT ON COLUMN user_sessions.id_token IS 'OIDC ID token (encrypted/hashed)'; +COMMENT ON COLUMN user_sessions.ip_address IS 'IP address of the client when session was created'; +COMMENT ON COLUMN user_sessions.user_agent IS 'User agent string of the client'; +COMMENT ON COLUMN user_sessions.last_activity IS 'Timestamp of last session activity'; +COMMENT ON COLUMN user_sessions.expires_at IS 'When the session expires'; +COMMENT ON COLUMN user_sessions.revoked_at IS 'When the session was revoked (if applicable)'; +COMMENT ON COLUMN user_sessions.revoked_by IS 'Who revoked the session (user ID or system)'; +COMMENT ON COLUMN user_sessions.metadata IS 'Additional session metadata (device info, location, etc.)'; diff --git a/test/saml_test.go b/test/saml_test.go new file mode 100644 index 0000000..78cc304 --- /dev/null +++ b/test/saml_test.go @@ -0,0 +1,532 @@ +package test + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "net/http/httptest" + "regexp" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + + "github.com/kms/api-key-service/internal/auth" + "github.com/kms/api-key-service/internal/domain" +) + +// mockSAMLMetadata returns a mock SAML IdP metadata XML +func mockSAMLMetadata() string { + return ` + + + + + + MIICertificateData + + + + + + +` +} + +// mockSAMLResponse returns a mock SAML response XML with current timestamps +func mockSAMLResponse() string { + now := time.Now().UTC() + issueInstant := now.Format(time.RFC3339) + notBefore := now.Add(-5 * time.Minute).Format(time.RFC3339) + notOnOrAfter := now.Add(60 * time.Minute).Format(time.RFC3339) + + return fmt.Sprintf(` + + https://idp.example.com + + + + + https://idp.example.com + + user@example.com + + + + + + + https://sp.example.com + + + + + user@example.com + + + Test User + + + Test + + + User + + + admin,user + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport + + + +`, issueInstant, issueInstant, notOnOrAfter, notBefore, notOnOrAfter, issueInstant) +} + +func TestSAMLProvider_GetMetadata(t *testing.T) { + tests := []struct { + name string + metadataURL string + serverResponse string + serverStatus int + expectError bool + errorContains string + }{ + { + name: "successful metadata fetch", + metadataURL: "https://idp.example.com/.well-known/saml-metadata", + serverResponse: mockSAMLMetadata(), + serverStatus: http.StatusOK, + expectError: false, + }, + { + name: "missing metadata URL", + metadataURL: "", + expectError: true, + errorContains: "SAML_IDP_METADATA_URL not configured", + }, + { + name: "server error", + metadataURL: "https://idp.example.com/.well-known/saml-metadata", + serverStatus: http.StatusInternalServerError, + expectError: true, + errorContains: "returned status 500", + }, + { + name: "invalid XML", + metadataURL: "https://idp.example.com/.well-known/saml-metadata", + serverResponse: "invalid xml", + serverStatus: http.StatusOK, + expectError: true, + errorContains: "Failed to parse SAML metadata", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mock HTTP server + var server *httptest.Server + if tt.metadataURL != "" && tt.serverStatus > 0 { + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.serverStatus) + if tt.serverResponse != "" { + w.Write([]byte(tt.serverResponse)) + } + })) + defer server.Close() + tt.metadataURL = server.URL + } + + // Create config + cfg := NewTestConfig() + cfg.values["SAML_IDP_METADATA_URL"] = tt.metadataURL + + // Create SAML provider + logger := zaptest.NewLogger(t) + provider, err := auth.NewSAMLProvider(cfg, logger) + require.NoError(t, err) + + // Test GetMetadata + ctx := context.Background() + metadata, err := provider.GetMetadata(ctx) + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, metadata) + } else { + assert.NoError(t, err) + assert.NotNil(t, metadata) + assert.Equal(t, "https://idp.example.com", metadata.EntityID) + assert.NotEmpty(t, metadata.IDPSSODescriptor.SingleSignOnService) + } + }) + } +} + +func TestSAMLProvider_GenerateAuthRequest(t *testing.T) { + tests := []struct { + name string + spEntityID string + acsURL string + relayState string + expectError bool + errorContains string + }{ + { + name: "successful auth request generation", + spEntityID: "https://sp.example.com", + acsURL: "https://sp.example.com/acs", + relayState: "test-relay-state", + }, + { + name: "missing SP entity ID", + spEntityID: "", + acsURL: "https://sp.example.com/acs", + expectError: true, + errorContains: "SAML_SP_ENTITY_ID not configured", + }, + { + name: "missing ACS URL", + spEntityID: "https://sp.example.com", + acsURL: "", + expectError: true, + errorContains: "SAML_SP_ACS_URL not configured", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mock HTTP server for metadata + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(mockSAMLMetadata())) + })) + defer server.Close() + + // Create config + cfg := NewTestConfig() + cfg.values["SAML_IDP_METADATA_URL"] = server.URL + cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID + cfg.values["SAML_SP_ACS_URL"] = tt.acsURL + + // Create SAML provider + logger := zaptest.NewLogger(t) + provider, err := auth.NewSAMLProvider(cfg, logger) + require.NoError(t, err) + + // Test GenerateAuthRequest + ctx := context.Background() + authURL, requestID, err := provider.GenerateAuthRequest(ctx, tt.relayState) + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Empty(t, authURL) + assert.Empty(t, requestID) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, authURL) + assert.NotEmpty(t, requestID) + assert.Contains(t, authURL, "https://idp.example.com/sso") + assert.Contains(t, authURL, "SAMLRequest=") + if tt.relayState != "" { + assert.Contains(t, authURL, "RelayState="+tt.relayState) + } + } + }) + } +} + +func TestSAMLProvider_ProcessSAMLResponse(t *testing.T) { + tests := []struct { + name string + samlResponse string + expectedRequestID string + spEntityID string + expectError bool + errorContains string + expectedUserID string + expectedEmail string + expectedName string + expectedRoles []string + }{ + { + name: "successful SAML response processing", + samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())), + expectedRequestID: "_request_id", + spEntityID: "https://sp.example.com", + expectedUserID: "user@example.com", + expectedEmail: "user@example.com", + expectedName: "Test User", + expectedRoles: []string{"admin", "user"}, + }, + { + name: "invalid base64 encoding", + samlResponse: "invalid-base64", + expectError: true, + errorContains: "Failed to decode SAML response", + }, + { + name: "invalid XML", + samlResponse: base64.StdEncoding.EncodeToString([]byte("invalid xml")), + expectError: true, + errorContains: "Failed to parse SAML response", + }, + { + name: "audience mismatch", + samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())), + spEntityID: "https://wrong-sp.example.com", + expectError: true, + errorContains: "audience mismatch", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create config + cfg := NewTestConfig() + cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID + + // Create SAML provider + logger := zaptest.NewLogger(t) + provider, err := auth.NewSAMLProvider(cfg, logger) + require.NoError(t, err) + + // Test ProcessSAMLResponse + ctx := context.Background() + authContext, err := provider.ProcessSAMLResponse(ctx, tt.samlResponse, tt.expectedRequestID) + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, authContext) + } else { + assert.NoError(t, err) + assert.NotNil(t, authContext) + assert.Equal(t, tt.expectedUserID, authContext.UserID) + assert.Equal(t, domain.TokenTypeUser, authContext.TokenType) + + // Check claims + if tt.expectedEmail != "" { + assert.Equal(t, tt.expectedEmail, authContext.Claims["email"]) + } + if tt.expectedName != "" { + assert.Equal(t, tt.expectedName, authContext.Claims["name"]) + } + + // Check permissions/roles + if len(tt.expectedRoles) > 0 { + assert.Equal(t, tt.expectedRoles, authContext.Permissions) + } + } + }) + } +} + +func TestSAMLProvider_GenerateServiceProviderMetadata(t *testing.T) { + tests := []struct { + name string + spEntityID string + acsURL string + expectError bool + errorContains string + }{ + { + name: "successful SP metadata generation", + spEntityID: "https://sp.example.com", + acsURL: "https://sp.example.com/acs", + }, + { + name: "missing SP entity ID", + spEntityID: "", + acsURL: "https://sp.example.com/acs", + expectError: true, + errorContains: "SAML_SP_ENTITY_ID not configured", + }, + { + name: "missing ACS URL", + spEntityID: "https://sp.example.com", + acsURL: "", + expectError: true, + errorContains: "SAML_SP_ACS_URL not configured", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create config + cfg := NewTestConfig() + cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID + cfg.values["SAML_SP_ACS_URL"] = tt.acsURL + + // Create SAML provider + logger := zaptest.NewLogger(t) + provider, err := auth.NewSAMLProvider(cfg, logger) + require.NoError(t, err) + + // Test GenerateServiceProviderMetadata + metadata, err := provider.GenerateServiceProviderMetadata() + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Empty(t, metadata) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, metadata) + assert.Contains(t, metadata, tt.spEntityID) + assert.Contains(t, metadata, tt.acsURL) + assert.Contains(t, metadata, "EntityDescriptor") + assert.Contains(t, metadata, "SPSSODescriptor") + } + }) + } +} + +// Benchmark tests for SAML operations +func BenchmarkSAMLProvider_ProcessSAMLResponse(b *testing.B) { + // Create config + cfg := NewTestConfig() + cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com" + + // Create SAML provider + logger := zaptest.NewLogger(b) + provider, err := auth.NewSAMLProvider(cfg, logger) + require.NoError(b, err) + + // Prepare SAML response + samlResponse := base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())) + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id") + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSAMLProvider_GenerateAuthRequest(b *testing.B) { + // Create mock HTTP server for metadata + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(mockSAMLMetadata())) + })) + defer server.Close() + + // Create config + cfg := NewTestConfig() + cfg.values["SAML_IDP_METADATA_URL"] = server.URL + cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com" + cfg.values["SAML_SP_ACS_URL"] = "https://sp.example.com/acs" + + // Create SAML provider + logger := zaptest.NewLogger(b) + provider, err := auth.NewSAMLProvider(cfg, logger) + require.NoError(b, err) + + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := provider.GenerateAuthRequest(ctx, "test-relay-state") + if err != nil { + b.Fatal(err) + } + } +} + +// Test helper functions +func TestSAMLResponseValidation(t *testing.T) { + // Test various SAML response validation scenarios + tests := []struct { + name string + modifyXML func(string) string + expectError bool + errorContains string + }{ + { + name: "expired assertion", + modifyXML: func(xml string) string { + // Replace all NotOnOrAfter timestamps with past time + pastTime := "2020-01-01T13:00:00Z" + re := regexp.MustCompile(`NotOnOrAfter="[^"]*"`) + return re.ReplaceAllString(xml, `NotOnOrAfter="`+pastTime+`"`) + }, + expectError: true, + errorContains: "assertion has expired", + }, + { + name: "assertion not yet valid", + modifyXML: func(xml string) string { + // Replace all NotBefore timestamps with future time + futureTime := "2030-01-01T11:55:00Z" + re := regexp.MustCompile(`NotBefore="[^"]*"`) + return re.ReplaceAllString(xml, `NotBefore="`+futureTime+`"`) + }, + expectError: true, + errorContains: "assertion not yet valid", + }, + { + name: "failed status", + modifyXML: func(xml string) string { + return strings.ReplaceAll(xml, + "urn:oasis:names:tc:SAML:2.0:status:Success", + "urn:oasis:names:tc:SAML:2.0:status:AuthnFailed") + }, + expectError: true, + errorContains: "SAML authentication failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create config + cfg := NewTestConfig() + cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com" + + // Create SAML provider + logger := zaptest.NewLogger(t) + provider, err := auth.NewSAMLProvider(cfg, logger) + require.NoError(t, err) + + // Modify SAML response + modifiedXML := tt.modifyXML(mockSAMLResponse()) + samlResponse := base64.StdEncoding.EncodeToString([]byte(modifiedXML)) + + // Test ProcessSAMLResponse + ctx := context.Background() + authContext, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id") + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, authContext) + } else { + assert.NoError(t, err) + assert.NotNil(t, authContext) + } + }) + } +} diff --git a/test/token_repository_test.go b/test/token_repository_test.go new file mode 100644 index 0000000..d9c65fb --- /dev/null +++ b/test/token_repository_test.go @@ -0,0 +1,705 @@ +package test + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/kms/api-key-service/internal/domain" + "github.com/kms/api-key-service/internal/repository" + "github.com/kms/api-key-service/internal/repository/postgres" +) + +// SQLMockDatabaseProvider implements repository.DatabaseProvider for SQL testing +type SQLMockDatabaseProvider struct { + db *sql.DB +} + +func (m *SQLMockDatabaseProvider) GetDB() interface{} { + return m.db +} + +func (m *SQLMockDatabaseProvider) Ping(ctx context.Context) error { + return m.db.PingContext(ctx) +} + +func (m *SQLMockDatabaseProvider) Close() error { + return m.db.Close() +} + +func (m *SQLMockDatabaseProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) { + tx, err := m.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + return &SQLMockTransactionProvider{tx: tx}, nil +} + +func (m *SQLMockDatabaseProvider) Migrate(ctx context.Context, migrationPath string) error { + return nil +} + +// SQLMockTransactionProvider implements repository.TransactionProvider for SQL testing +type SQLMockTransactionProvider struct { + tx *sql.Tx +} + +func (m *SQLMockTransactionProvider) Commit() error { + return m.tx.Commit() +} + +func (m *SQLMockTransactionProvider) Rollback() error { + return m.tx.Rollback() +} + +func (m *SQLMockTransactionProvider) GetTx() interface{} { + return m.tx +} + +func setupTokenRepositoryTest(t *testing.T) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + + mockDB := &SQLMockDatabaseProvider{db: db} + repo := postgres.NewStaticTokenRepository(mockDB) + + cleanup := func() { + db.Close() + } + + return repo.(*postgres.StaticTokenRepository), mock, cleanup +} + +func setupTokenRepositoryTestBenchmark(b *testing.B) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) { + db, mock, err := sqlmock.New() + if err != nil { + b.Fatal(err) + } + + mockDB := &SQLMockDatabaseProvider{db: db} + repo := postgres.NewStaticTokenRepository(mockDB) + + cleanup := func() { + db.Close() + } + + return repo.(*postgres.StaticTokenRepository), mock, cleanup +} + +func TestStaticTokenRepository_Create(t *testing.T) { + tests := []struct { + name string + token *domain.StaticToken + setupMock func(sqlmock.Sqlmock) + expectError bool + errorMsg string + }{ + { + name: "successful creation", + token: &domain.StaticToken{ + ID: uuid.New(), + AppID: "test-app", + Owner: domain.Owner{ + Type: domain.OwnerTypeIndividual, + Name: "test-user", + Owner: "test-owner", + }, + KeyHash: "test-hash", + Type: "hmac", + }, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectExec(`INSERT INTO static_tokens`). + WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + }, + expectError: false, + }, + { + name: "database error", + token: &domain.StaticToken{ + ID: uuid.New(), + AppID: "test-app", + Owner: domain.Owner{ + Type: domain.OwnerTypeIndividual, + Name: "test-user", + Owner: "test-owner", + }, + KeyHash: "test-hash", + Type: "hmac", + }, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectExec(`INSERT INTO static_tokens`). + WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnError(sql.ErrConnDone) + }, + expectError: true, + errorMsg: "failed to create static token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo, mock, cleanup := setupTokenRepositoryTest(t) + defer cleanup() + + tt.setupMock(mock) + + ctx := context.Background() + err := repo.Create(ctx, tt.token) + + if tt.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + assert.NotZero(t, tt.token.CreatedAt) + assert.NotZero(t, tt.token.UpdatedAt) + } + + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestStaticTokenRepository_GetByID(t *testing.T) { + tokenID := uuid.New() + now := time.Now() + + tests := []struct { + name string + tokenID uuid.UUID + setupMock func(sqlmock.Sqlmock) + expectError bool + errorMsg string + expectedToken *domain.StaticToken + }{ + { + name: "successful retrieval", + tokenID: tokenID, + setupMock: func(mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{ + "id", "app_id", "owner_type", "owner_name", "owner_owner", + "key_hash", "type", "created_at", "updated_at", + }).AddRow( + tokenID, "test-app", "individual", "test-user", "test-owner", + "test-hash", "user", now, now, + ) + mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`). + WithArgs(tokenID). + WillReturnRows(rows) + }, + expectError: false, + expectedToken: &domain.StaticToken{ + ID: tokenID, + AppID: "test-app", + Owner: domain.Owner{ + Type: domain.OwnerTypeIndividual, + Name: "test-user", + Owner: "test-owner", + }, + KeyHash: "test-hash", + Type: string(domain.TokenTypeUser), + CreatedAt: now, + UpdatedAt: now, + }, + }, + { + name: "token not found", + tokenID: tokenID, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`). + WithArgs(tokenID). + WillReturnError(sql.ErrNoRows) + }, + expectError: true, + errorMsg: "not found", + }, + { + name: "database error", + tokenID: tokenID, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`). + WithArgs(tokenID). + WillReturnError(sql.ErrConnDone) + }, + expectError: true, + errorMsg: "failed to get static token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo, mock, cleanup := setupTokenRepositoryTest(t) + defer cleanup() + + tt.setupMock(mock) + + ctx := context.Background() + token, err := repo.GetByID(ctx, tt.tokenID) + + if tt.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + assert.Nil(t, token) + } else { + assert.NoError(t, err) + assert.NotNil(t, token) + assert.Equal(t, tt.expectedToken.ID, token.ID) + assert.Equal(t, tt.expectedToken.AppID, token.AppID) + assert.Equal(t, tt.expectedToken.Owner, token.Owner) + assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash) + assert.Equal(t, tt.expectedToken.Type, token.Type) + } + + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestStaticTokenRepository_GetByKeyHash(t *testing.T) { + tokenID := uuid.New() + now := time.Now() + keyHash := "test-hash" + + tests := []struct { + name string + keyHash string + setupMock func(sqlmock.Sqlmock) + expectError bool + errorMsg string + expectedToken *domain.StaticToken + }{ + { + name: "successful retrieval", + keyHash: keyHash, + setupMock: func(mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{ + "id", "app_id", "owner_type", "owner_name", "owner_owner", + "key_hash", "type", "created_at", "updated_at", + }).AddRow( + tokenID, "test-app", "individual", "test-user", "test-owner", + keyHash, "user", now, now, + ) + mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`). + WithArgs(keyHash). + WillReturnRows(rows) + }, + expectError: false, + expectedToken: &domain.StaticToken{ + ID: tokenID, + AppID: "test-app", + Owner: domain.Owner{ + Type: domain.OwnerTypeIndividual, + Name: "test-user", + Owner: "test-owner", + }, + KeyHash: keyHash, + Type: string(domain.TokenTypeUser), + CreatedAt: now, + UpdatedAt: now, + }, + }, + { + name: "token not found", + keyHash: keyHash, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`). + WithArgs(keyHash). + WillReturnError(sql.ErrNoRows) + }, + expectError: true, + errorMsg: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo, mock, cleanup := setupTokenRepositoryTest(t) + defer cleanup() + + tt.setupMock(mock) + + ctx := context.Background() + token, err := repo.GetByKeyHash(ctx, tt.keyHash) + + if tt.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + assert.Nil(t, token) + } else { + assert.NoError(t, err) + assert.NotNil(t, token) + assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash) + } + + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestStaticTokenRepository_GetByAppID(t *testing.T) { + tokenID1 := uuid.New() + tokenID2 := uuid.New() + now := time.Now() + appID := "test-app" + + tests := []struct { + name string + appID string + setupMock func(sqlmock.Sqlmock) + expectError bool + errorMsg string + expectedCount int + }{ + { + name: "successful retrieval with multiple tokens", + appID: appID, + setupMock: func(mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{ + "id", "app_id", "owner_type", "owner_name", "owner_owner", + "key_hash", "type", "created_at", "updated_at", + }).AddRow( + tokenID1, appID, "user", "test-user1", "test-owner1", + "test-hash1", "user", now, now, + ).AddRow( + tokenID2, appID, "user", "test-user2", "test-owner2", + "test-hash2", "user", now, now, + ) + mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`). + WithArgs(appID). + WillReturnRows(rows) + }, + expectError: false, + expectedCount: 2, + }, + { + name: "no tokens found", + appID: appID, + setupMock: func(mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{ + "id", "app_id", "owner_type", "owner_name", "owner_owner", + "key_hash", "type", "created_at", "updated_at", + }) + mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`). + WithArgs(appID). + WillReturnRows(rows) + }, + expectError: false, + expectedCount: 0, + }, + { + name: "database error", + appID: appID, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`). + WithArgs(appID). + WillReturnError(sql.ErrConnDone) + }, + expectError: true, + errorMsg: "failed to query static tokens", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo, mock, cleanup := setupTokenRepositoryTest(t) + defer cleanup() + + tt.setupMock(mock) + + ctx := context.Background() + tokens, err := repo.GetByAppID(ctx, tt.appID) + + if tt.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + assert.Nil(t, tokens) + } else { + assert.NoError(t, err) + assert.Len(t, tokens, tt.expectedCount) + } + + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestStaticTokenRepository_List(t *testing.T) { + tokenID := uuid.New() + now := time.Now() + + tests := []struct { + name string + limit int + offset int + setupMock func(sqlmock.Sqlmock) + expectError bool + errorMsg string + expectedCount int + }{ + { + name: "successful list with pagination", + limit: 10, + offset: 0, + setupMock: func(mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{ + "id", "app_id", "owner_type", "owner_name", "owner_owner", + "key_hash", "type", "created_at", "updated_at", + }).AddRow( + tokenID, "test-app", "user", "test-user", "test-owner", + "test-hash", "user", now, now, + ) + mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`). + WithArgs(10, 0). + WillReturnRows(rows) + }, + expectError: false, + expectedCount: 1, + }, + { + name: "database error", + limit: 10, + offset: 0, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`). + WithArgs(10, 0). + WillReturnError(sql.ErrConnDone) + }, + expectError: true, + errorMsg: "failed to query static tokens", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo, mock, cleanup := setupTokenRepositoryTest(t) + defer cleanup() + + tt.setupMock(mock) + + ctx := context.Background() + tokens, err := repo.List(ctx, tt.limit, tt.offset) + + if tt.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + assert.Nil(t, tokens) + } else { + assert.NoError(t, err) + assert.Len(t, tokens, tt.expectedCount) + } + + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestStaticTokenRepository_Delete(t *testing.T) { + tokenID := uuid.New() + + tests := []struct { + name string + tokenID uuid.UUID + setupMock func(sqlmock.Sqlmock) + expectError bool + errorMsg string + }{ + { + name: "successful deletion", + tokenID: tokenID, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`). + WithArgs(tokenID). + WillReturnResult(sqlmock.NewResult(0, 1)) + }, + expectError: false, + }, + { + name: "token not found", + tokenID: tokenID, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`). + WithArgs(tokenID). + WillReturnResult(sqlmock.NewResult(0, 0)) + }, + expectError: true, + errorMsg: "not found", + }, + { + name: "database error", + tokenID: tokenID, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`). + WithArgs(tokenID). + WillReturnError(sql.ErrConnDone) + }, + expectError: true, + errorMsg: "failed to delete static token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo, mock, cleanup := setupTokenRepositoryTest(t) + defer cleanup() + + tt.setupMock(mock) + + ctx := context.Background() + err := repo.Delete(ctx, tt.tokenID) + + if tt.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestStaticTokenRepository_Exists(t *testing.T) { + tokenID := uuid.New() + + tests := []struct { + name string + tokenID uuid.UUID + setupMock func(sqlmock.Sqlmock) + expectError bool + errorMsg string + expectedExists bool + }{ + { + name: "token exists", + tokenID: tokenID, + setupMock: func(mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"exists"}).AddRow(1) + mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`). + WithArgs(tokenID). + WillReturnRows(rows) + }, + expectError: false, + expectedExists: true, + }, + { + name: "token does not exist", + tokenID: tokenID, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`). + WithArgs(tokenID). + WillReturnError(sql.ErrNoRows) + }, + expectError: false, + expectedExists: false, + }, + { + name: "database error", + tokenID: tokenID, + setupMock: func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`). + WithArgs(tokenID). + WillReturnError(sql.ErrConnDone) + }, + expectError: true, + errorMsg: "failed to check static token existence", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo, mock, cleanup := setupTokenRepositoryTest(t) + defer cleanup() + + tt.setupMock(mock) + + ctx := context.Background() + exists, err := repo.Exists(ctx, tt.tokenID) + + if tt.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedExists, exists) + } + + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +// Benchmark tests for repository operations +func BenchmarkStaticTokenRepository_Create(b *testing.B) { + repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b) + defer cleanup() + + token := &domain.StaticToken{ + ID: uuid.New(), + AppID: "test-app", + Owner: domain.Owner{ + Type: domain.OwnerTypeIndividual, + Name: "test-user", + Owner: "test-owner", + }, + KeyHash: "test-hash", + Type: string(domain.TokenTypeUser), + } + + // Setup mock expectations for all iterations + for i := 0; i < b.N; i++ { + mock.ExpectExec(`INSERT INTO static_tokens`). + WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "user", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + } + + ctx := context.Background() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + token.ID = uuid.New() // Generate new ID for each iteration + err := repo.Create(ctx, token) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkStaticTokenRepository_GetByID(b *testing.B) { + repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b) + defer cleanup() + + tokenID := uuid.New() + now := time.Now() + + // Setup mock expectations for all iterations + for i := 0; i < b.N; i++ { + rows := sqlmock.NewRows([]string{ + "id", "app_id", "owner_type", "owner_name", "owner_owner", + "key_hash", "type", "created_at", "updated_at", + }).AddRow( + tokenID, "test-app", "user", "test-user", "test-owner", + "test-hash", "user", now, now, + ) + mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`). + WithArgs(tokenID). + WillReturnRows(rows) + } + + ctx := context.Background() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := repo.GetByID(ctx, tokenID) + if err != nil { + b.Fatal(err) + } + } +}