diff --git a/docs/PRODUCTION_ROADMAP.md b/docs/PRODUCTION_ROADMAP.md
index 54cab77..54a5165 100644
--- a/docs/PRODUCTION_ROADMAP.md
+++ b/docs/PRODUCTION_ROADMAP.md
@@ -64,10 +64,10 @@ This document outlines the complete roadmap for making the API Key Management Se
- [x] Implement authorization code exchange and token refresh
- [x] Add user info retrieval from OAuth2 providers
- [x] Create comprehensive OAuth2 unit tests with benchmarks
-- [ ] Add SAML authentication support
-- [ ] Create user session management
+- [x] Add SAML authentication support
+- [x] Create user session management
- [x] Implement role-based access control (RBAC)
-- [ ] Add multi-tenant authentication support
+- [x] Add multi-tenant authentication support
### Permission System Enhancement
- [x] Implement hierarchical permission inheritance
@@ -120,7 +120,7 @@ This document outlines the complete roadmap for making the API Key Management Se
- [x] Create authentication failure tracking
### Audit & Compliance
-- [ ] Implement comprehensive audit logging
+- [x] Implement comprehensive audit logging
- [ ] Add compliance reporting features
- [ ] Create data retention policies
- [ ] Implement GDPR compliance features
diff --git a/go.mod b/go.mod
index da5ae31..87d86ea 100644
--- a/go.mod
+++ b/go.mod
@@ -21,6 +21,7 @@ require (
)
require (
+ github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
@@ -33,6 +34,7 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
+ github.com/jmoiron/sqlx v1.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
diff --git a/go.sum b/go.sum
index 2555ccc..25c0ee3 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,8 @@
+filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
+github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
+github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@@ -43,6 +46,7 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
+github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
@@ -65,10 +69,13 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
+github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
+github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg=
github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
@@ -78,6 +85,7 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
diff --git a/internal/audit/audit.go b/internal/audit/audit.go
new file mode 100644
index 0000000..106ba39
--- /dev/null
+++ b/internal/audit/audit.go
@@ -0,0 +1,590 @@
+package audit
+
+import (
+ "context"
+ "encoding/json"
+ "time"
+
+ "github.com/google/uuid"
+ "go.uber.org/zap"
+
+ "github.com/kms/api-key-service/internal/config"
+)
+
+// EventType represents the type of audit event
+type EventType string
+
+const (
+ // Authentication events
+ EventTypeLogin EventType = "auth.login"
+ EventTypeLoginFailed EventType = "auth.login_failed"
+ EventTypeLogout EventType = "auth.logout"
+ EventTypeTokenCreated EventType = "auth.token_created"
+ EventTypeTokenRevoked EventType = "auth.token_revoked"
+ EventTypeTokenValidated EventType = "auth.token_validated"
+
+ // Session events
+ EventTypeSessionCreated EventType = "session.created"
+ EventTypeSessionRevoked EventType = "session.revoked"
+ EventTypeSessionExpired EventType = "session.expired"
+
+ // Application events
+ EventTypeAppCreated EventType = "app.created"
+ EventTypeAppUpdated EventType = "app.updated"
+ EventTypeAppDeleted EventType = "app.deleted"
+
+ // Permission events
+ EventTypePermissionGranted EventType = "permission.granted"
+ EventTypePermissionRevoked EventType = "permission.revoked"
+ EventTypePermissionDenied EventType = "permission.denied"
+
+ // Tenant events
+ EventTypeTenantCreated EventType = "tenant.created"
+ EventTypeTenantUpdated EventType = "tenant.updated"
+ EventTypeTenantSuspended EventType = "tenant.suspended"
+ EventTypeTenantActivated EventType = "tenant.activated"
+
+ // User events
+ EventTypeUserCreated EventType = "user.created"
+ EventTypeUserUpdated EventType = "user.updated"
+ EventTypeUserSuspended EventType = "user.suspended"
+ EventTypeUserActivated EventType = "user.activated"
+
+ // Security events
+ EventTypeSecurityViolation EventType = "security.violation"
+ EventTypeBruteForceAttempt EventType = "security.brute_force"
+ EventTypeIPBlocked EventType = "security.ip_blocked"
+ EventTypeRateLimitExceeded EventType = "security.rate_limit_exceeded"
+
+ // System events
+ EventTypeSystemStartup EventType = "system.startup"
+ EventTypeSystemShutdown EventType = "system.shutdown"
+ EventTypeConfigChanged EventType = "system.config_changed"
+)
+
+// EventSeverity represents the severity level of an audit event
+type EventSeverity string
+
+const (
+ SeverityInfo EventSeverity = "info"
+ SeverityWarning EventSeverity = "warning"
+ SeverityError EventSeverity = "error"
+ SeverityCritical EventSeverity = "critical"
+)
+
+// EventStatus represents the status of an audit event
+type EventStatus string
+
+const (
+ StatusSuccess EventStatus = "success"
+ StatusFailure EventStatus = "failure"
+ StatusPending EventStatus = "pending"
+)
+
+// AuditEvent represents a single audit event
+type AuditEvent struct {
+ ID uuid.UUID `json:"id" db:"id"`
+ Type EventType `json:"type" db:"type"`
+ Severity EventSeverity `json:"severity" db:"severity"`
+ Status EventStatus `json:"status" db:"status"`
+ Timestamp time.Time `json:"timestamp" db:"timestamp"`
+
+ // Actor information
+ ActorID string `json:"actor_id,omitempty" db:"actor_id"`
+ ActorType string `json:"actor_type,omitempty" db:"actor_type"` // user, system, service
+ ActorIP string `json:"actor_ip,omitempty" db:"actor_ip"`
+ UserAgent string `json:"user_agent,omitempty" db:"user_agent"`
+
+ // Tenant information
+ TenantID *uuid.UUID `json:"tenant_id,omitempty" db:"tenant_id"`
+
+ // Resource information
+ ResourceID string `json:"resource_id,omitempty" db:"resource_id"`
+ ResourceType string `json:"resource_type,omitempty" db:"resource_type"`
+
+ // Event details
+ Action string `json:"action" db:"action"`
+ Description string `json:"description" db:"description"`
+ Details map[string]interface{} `json:"details,omitempty" db:"details"`
+
+ // Request context
+ RequestID string `json:"request_id,omitempty" db:"request_id"`
+ SessionID string `json:"session_id,omitempty" db:"session_id"`
+
+ // Additional metadata
+ Tags []string `json:"tags,omitempty" db:"tags"`
+ Metadata map[string]string `json:"metadata,omitempty" db:"metadata"`
+}
+
+// AuditLogger defines the interface for audit logging
+type AuditLogger interface {
+ // LogEvent logs a single audit event
+ LogEvent(ctx context.Context, event *AuditEvent) error
+
+ // LogAuthEvent logs an authentication-related event
+ LogAuthEvent(ctx context.Context, eventType EventType, actorID, actorIP string, details map[string]interface{}) error
+
+ // LogPermissionEvent logs a permission-related event
+ LogPermissionEvent(ctx context.Context, eventType EventType, actorID, resourceID, resourceType string, details map[string]interface{}) error
+
+ // LogSecurityEvent logs a security-related event
+ LogSecurityEvent(ctx context.Context, eventType EventType, actorIP string, severity EventSeverity, details map[string]interface{}) error
+
+ // LogSystemEvent logs a system-related event
+ LogSystemEvent(ctx context.Context, eventType EventType, details map[string]interface{}) error
+
+ // QueryEvents queries audit events with filters
+ QueryEvents(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error)
+
+ // GetEventStats returns audit event statistics
+ GetEventStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error)
+}
+
+// AuditFilter represents filters for querying audit events
+type AuditFilter struct {
+ EventTypes []EventType `json:"event_types,omitempty"`
+ Severities []EventSeverity `json:"severities,omitempty"`
+ Statuses []EventStatus `json:"statuses,omitempty"`
+ ActorID string `json:"actor_id,omitempty"`
+ ActorType string `json:"actor_type,omitempty"`
+ TenantID *uuid.UUID `json:"tenant_id,omitempty"`
+ ResourceID string `json:"resource_id,omitempty"`
+ ResourceType string `json:"resource_type,omitempty"`
+ StartTime *time.Time `json:"start_time,omitempty"`
+ EndTime *time.Time `json:"end_time,omitempty"`
+ Tags []string `json:"tags,omitempty"`
+ Limit int `json:"limit"`
+ Offset int `json:"offset"`
+ OrderBy string `json:"order_by"` // timestamp, type, severity
+ OrderDesc bool `json:"order_desc"`
+}
+
+// AuditStatsFilter represents filters for audit statistics
+type AuditStatsFilter struct {
+ EventTypes []EventType `json:"event_types,omitempty"`
+ TenantID *uuid.UUID `json:"tenant_id,omitempty"`
+ StartTime *time.Time `json:"start_time,omitempty"`
+ EndTime *time.Time `json:"end_time,omitempty"`
+ GroupBy string `json:"group_by"` // type, severity, status, hour, day
+}
+
+// AuditStats represents audit event statistics
+type AuditStats struct {
+ TotalEvents int `json:"total_events"`
+ ByType map[EventType]int `json:"by_type"`
+ BySeverity map[EventSeverity]int `json:"by_severity"`
+ ByStatus map[EventStatus]int `json:"by_status"`
+ ByTime map[string]int `json:"by_time,omitempty"`
+}
+
+// auditLogger implements the AuditLogger interface
+type auditLogger struct {
+ config config.ConfigProvider
+ logger *zap.Logger
+ repository AuditRepository
+}
+
+// AuditRepository defines the interface for audit event storage
+type AuditRepository interface {
+ Create(ctx context.Context, event *AuditEvent) error
+ Query(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error)
+ GetStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error)
+ DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error)
+}
+
+// NewAuditLogger creates a new audit logger
+func NewAuditLogger(config config.ConfigProvider, logger *zap.Logger, repository AuditRepository) AuditLogger {
+ return &auditLogger{
+ config: config,
+ logger: logger,
+ repository: repository,
+ }
+}
+
+// LogEvent logs a single audit event
+func (a *auditLogger) LogEvent(ctx context.Context, event *AuditEvent) error {
+ // Set default values
+ if event.ID == uuid.Nil {
+ event.ID = uuid.New()
+ }
+ if event.Timestamp.IsZero() {
+ event.Timestamp = time.Now().UTC()
+ }
+ if event.Severity == "" {
+ event.Severity = SeverityInfo
+ }
+ if event.Status == "" {
+ event.Status = StatusSuccess
+ }
+
+ // Extract request context if available
+ if requestID := ctx.Value("request_id"); requestID != nil {
+ if reqID, ok := requestID.(string); ok {
+ event.RequestID = reqID
+ }
+ }
+
+ // Log to structured logger
+ a.logToStructuredLogger(event)
+
+ // Store in repository
+ if err := a.repository.Create(ctx, event); err != nil {
+ a.logger.Error("Failed to store audit event",
+ zap.Error(err),
+ zap.String("event_id", event.ID.String()),
+ zap.String("event_type", string(event.Type)))
+ return err
+ }
+
+ return nil
+}
+
+// LogAuthEvent logs an authentication-related event
+func (a *auditLogger) LogAuthEvent(ctx context.Context, eventType EventType, actorID, actorIP string, details map[string]interface{}) error {
+ severity := SeverityInfo
+ status := StatusSuccess
+
+ // Determine severity and status based on event type
+ switch eventType {
+ case EventTypeLoginFailed:
+ severity = SeverityWarning
+ status = StatusFailure
+ case EventTypeTokenRevoked:
+ severity = SeverityWarning
+ }
+
+ event := &AuditEvent{
+ Type: eventType,
+ Severity: severity,
+ Status: status,
+ ActorID: actorID,
+ ActorType: "user",
+ ActorIP: actorIP,
+ Action: string(eventType),
+ Description: a.generateDescription(eventType, details),
+ Details: details,
+ Tags: []string{"authentication"},
+ }
+
+ return a.LogEvent(ctx, event)
+}
+
+// LogPermissionEvent logs a permission-related event
+func (a *auditLogger) LogPermissionEvent(ctx context.Context, eventType EventType, actorID, resourceID, resourceType string, details map[string]interface{}) error {
+ severity := SeverityInfo
+ status := StatusSuccess
+
+ // Determine severity and status based on event type
+ switch eventType {
+ case EventTypePermissionDenied:
+ severity = SeverityWarning
+ status = StatusFailure
+ }
+
+ event := &AuditEvent{
+ Type: eventType,
+ Severity: severity,
+ Status: status,
+ ActorID: actorID,
+ ActorType: "user",
+ ResourceID: resourceID,
+ ResourceType: resourceType,
+ Action: string(eventType),
+ Description: a.generateDescription(eventType, details),
+ Details: details,
+ Tags: []string{"permission", "authorization"},
+ }
+
+ return a.LogEvent(ctx, event)
+}
+
+// LogSecurityEvent logs a security-related event
+func (a *auditLogger) LogSecurityEvent(ctx context.Context, eventType EventType, actorIP string, severity EventSeverity, details map[string]interface{}) error {
+ status := StatusSuccess
+ if severity == SeverityError || severity == SeverityCritical {
+ status = StatusFailure
+ }
+
+ event := &AuditEvent{
+ Type: eventType,
+ Severity: severity,
+ Status: status,
+ ActorIP: actorIP,
+ ActorType: "system",
+ Action: string(eventType),
+ Description: a.generateDescription(eventType, details),
+ Details: details,
+ Tags: []string{"security"},
+ }
+
+ return a.LogEvent(ctx, event)
+}
+
+// LogSystemEvent logs a system-related event
+func (a *auditLogger) LogSystemEvent(ctx context.Context, eventType EventType, details map[string]interface{}) error {
+ event := &AuditEvent{
+ Type: eventType,
+ Severity: SeverityInfo,
+ Status: StatusSuccess,
+ ActorType: "system",
+ Action: string(eventType),
+ Description: a.generateDescription(eventType, details),
+ Details: details,
+ Tags: []string{"system"},
+ }
+
+ return a.LogEvent(ctx, event)
+}
+
+// QueryEvents queries audit events with filters
+func (a *auditLogger) QueryEvents(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error) {
+ // Set default pagination
+ if filter.Limit <= 0 {
+ filter.Limit = 100
+ }
+ if filter.Limit > 1000 {
+ filter.Limit = 1000
+ }
+ if filter.OrderBy == "" {
+ filter.OrderBy = "timestamp"
+ filter.OrderDesc = true
+ }
+
+ return a.repository.Query(ctx, filter)
+}
+
+// GetEventStats returns audit event statistics
+func (a *auditLogger) GetEventStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error) {
+ return a.repository.GetStats(ctx, filter)
+}
+
+// logToStructuredLogger logs the event to the structured logger
+func (a *auditLogger) logToStructuredLogger(event *AuditEvent) {
+ fields := []zap.Field{
+ zap.String("audit_event_id", event.ID.String()),
+ zap.String("event_type", string(event.Type)),
+ zap.String("severity", string(event.Severity)),
+ zap.String("status", string(event.Status)),
+ zap.Time("timestamp", event.Timestamp),
+ zap.String("action", event.Action),
+ zap.String("description", event.Description),
+ }
+
+ if event.ActorID != "" {
+ fields = append(fields, zap.String("actor_id", event.ActorID))
+ }
+ if event.ActorType != "" {
+ fields = append(fields, zap.String("actor_type", event.ActorType))
+ }
+ if event.ActorIP != "" {
+ fields = append(fields, zap.String("actor_ip", event.ActorIP))
+ }
+ if event.TenantID != nil {
+ fields = append(fields, zap.String("tenant_id", event.TenantID.String()))
+ }
+ if event.ResourceID != "" {
+ fields = append(fields, zap.String("resource_id", event.ResourceID))
+ }
+ if event.ResourceType != "" {
+ fields = append(fields, zap.String("resource_type", event.ResourceType))
+ }
+ if event.RequestID != "" {
+ fields = append(fields, zap.String("request_id", event.RequestID))
+ }
+ if event.SessionID != "" {
+ fields = append(fields, zap.String("session_id", event.SessionID))
+ }
+ if len(event.Tags) > 0 {
+ fields = append(fields, zap.Strings("tags", event.Tags))
+ }
+ if event.Details != nil {
+ if detailsJSON, err := json.Marshal(event.Details); err == nil {
+ fields = append(fields, zap.String("details", string(detailsJSON)))
+ }
+ }
+
+ // Log at appropriate level based on severity
+ switch event.Severity {
+ case SeverityInfo:
+ a.logger.Info("Audit event", fields...)
+ case SeverityWarning:
+ a.logger.Warn("Audit event", fields...)
+ case SeverityError:
+ a.logger.Error("Audit event", fields...)
+ case SeverityCritical:
+ a.logger.Error("Critical audit event", fields...)
+ default:
+ a.logger.Info("Audit event", fields...)
+ }
+}
+
+// generateDescription generates a human-readable description for an event
+func (a *auditLogger) generateDescription(eventType EventType, details map[string]interface{}) string {
+ switch eventType {
+ case EventTypeLogin:
+ return "User successfully logged in"
+ case EventTypeLoginFailed:
+ return "User login attempt failed"
+ case EventTypeLogout:
+ return "User logged out"
+ case EventTypeTokenCreated:
+ return "API token created"
+ case EventTypeTokenRevoked:
+ return "API token revoked"
+ case EventTypeTokenValidated:
+ return "API token validated"
+ case EventTypeSessionCreated:
+ return "User session created"
+ case EventTypeSessionRevoked:
+ return "User session revoked"
+ case EventTypeSessionExpired:
+ return "User session expired"
+ case EventTypeAppCreated:
+ return "Application created"
+ case EventTypeAppUpdated:
+ return "Application updated"
+ case EventTypeAppDeleted:
+ return "Application deleted"
+ case EventTypePermissionGranted:
+ return "Permission granted"
+ case EventTypePermissionRevoked:
+ return "Permission revoked"
+ case EventTypePermissionDenied:
+ return "Permission denied"
+ case EventTypeTenantCreated:
+ return "Tenant created"
+ case EventTypeTenantUpdated:
+ return "Tenant updated"
+ case EventTypeTenantSuspended:
+ return "Tenant suspended"
+ case EventTypeTenantActivated:
+ return "Tenant activated"
+ case EventTypeUserCreated:
+ return "User created"
+ case EventTypeUserUpdated:
+ return "User updated"
+ case EventTypeUserSuspended:
+ return "User suspended"
+ case EventTypeUserActivated:
+ return "User activated"
+ case EventTypeSecurityViolation:
+ return "Security violation detected"
+ case EventTypeBruteForceAttempt:
+ return "Brute force attempt detected"
+ case EventTypeIPBlocked:
+ return "IP address blocked"
+ case EventTypeRateLimitExceeded:
+ return "Rate limit exceeded"
+ case EventTypeSystemStartup:
+ return "System started"
+ case EventTypeSystemShutdown:
+ return "System shutdown"
+ case EventTypeConfigChanged:
+ return "Configuration changed"
+ default:
+ return string(eventType)
+ }
+}
+
+// AuditEventBuilder provides a fluent interface for building audit events
+type AuditEventBuilder struct {
+ event *AuditEvent
+}
+
+// NewAuditEventBuilder creates a new audit event builder
+func NewAuditEventBuilder(eventType EventType) *AuditEventBuilder {
+ return &AuditEventBuilder{
+ event: &AuditEvent{
+ ID: uuid.New(),
+ Type: eventType,
+ Timestamp: time.Now().UTC(),
+ Severity: SeverityInfo,
+ Status: StatusSuccess,
+ Details: make(map[string]interface{}),
+ Metadata: make(map[string]string),
+ },
+ }
+}
+
+// WithSeverity sets the event severity
+func (b *AuditEventBuilder) WithSeverity(severity EventSeverity) *AuditEventBuilder {
+ b.event.Severity = severity
+ return b
+}
+
+// WithStatus sets the event status
+func (b *AuditEventBuilder) WithStatus(status EventStatus) *AuditEventBuilder {
+ b.event.Status = status
+ return b
+}
+
+// WithActor sets the actor information
+func (b *AuditEventBuilder) WithActor(actorID, actorType, actorIP string) *AuditEventBuilder {
+ b.event.ActorID = actorID
+ b.event.ActorType = actorType
+ b.event.ActorIP = actorIP
+ return b
+}
+
+// WithTenant sets the tenant ID
+func (b *AuditEventBuilder) WithTenant(tenantID uuid.UUID) *AuditEventBuilder {
+ b.event.TenantID = &tenantID
+ return b
+}
+
+// WithResource sets the resource information
+func (b *AuditEventBuilder) WithResource(resourceID, resourceType string) *AuditEventBuilder {
+ b.event.ResourceID = resourceID
+ b.event.ResourceType = resourceType
+ return b
+}
+
+// WithAction sets the action
+func (b *AuditEventBuilder) WithAction(action string) *AuditEventBuilder {
+ b.event.Action = action
+ return b
+}
+
+// WithDescription sets the description
+func (b *AuditEventBuilder) WithDescription(description string) *AuditEventBuilder {
+ b.event.Description = description
+ return b
+}
+
+// WithDetail adds a detail
+func (b *AuditEventBuilder) WithDetail(key string, value interface{}) *AuditEventBuilder {
+ b.event.Details[key] = value
+ return b
+}
+
+// WithDetails sets multiple details
+func (b *AuditEventBuilder) WithDetails(details map[string]interface{}) *AuditEventBuilder {
+ for k, v := range details {
+ b.event.Details[k] = v
+ }
+ return b
+}
+
+// WithRequestContext sets request context information
+func (b *AuditEventBuilder) WithRequestContext(requestID, sessionID string) *AuditEventBuilder {
+ b.event.RequestID = requestID
+ b.event.SessionID = sessionID
+ return b
+}
+
+// WithTags sets the tags
+func (b *AuditEventBuilder) WithTags(tags ...string) *AuditEventBuilder {
+ b.event.Tags = tags
+ return b
+}
+
+// WithMetadata adds metadata
+func (b *AuditEventBuilder) WithMetadata(key, value string) *AuditEventBuilder {
+ b.event.Metadata[key] = value
+ return b
+}
+
+// Build returns the built audit event
+func (b *AuditEventBuilder) Build() *AuditEvent {
+ return b.event
+}
diff --git a/internal/auth/saml.go b/internal/auth/saml.go
new file mode 100644
index 0000000..c4e5a32
--- /dev/null
+++ b/internal/auth/saml.go
@@ -0,0 +1,544 @@
+package auth
+
+import (
+ "context"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/base64"
+ "encoding/pem"
+ "encoding/xml"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ "go.uber.org/zap"
+
+ "github.com/kms/api-key-service/internal/config"
+ "github.com/kms/api-key-service/internal/domain"
+ "github.com/kms/api-key-service/internal/errors"
+)
+
+// SAMLProvider represents a SAML 2.0 identity provider
+type SAMLProvider struct {
+ config config.ConfigProvider
+ logger *zap.Logger
+ httpClient *http.Client
+ privateKey *rsa.PrivateKey
+ certificate *x509.Certificate
+}
+
+// NewSAMLProvider creates a new SAML provider
+func NewSAMLProvider(config config.ConfigProvider, logger *zap.Logger) (*SAMLProvider, error) {
+ provider := &SAMLProvider{
+ config: config,
+ logger: logger,
+ httpClient: &http.Client{
+ Timeout: 30 * time.Second,
+ },
+ }
+
+ // Load SP private key and certificate if configured
+ if err := provider.loadCredentials(); err != nil {
+ return nil, err
+ }
+
+ return provider, nil
+}
+
+// SAMLMetadata represents SAML IdP metadata
+type SAMLMetadata struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntityDescriptor"`
+ EntityID string `xml:"entityID,attr"`
+ IDPSSODescriptor IDPSSODescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata IDPSSODescriptor"`
+}
+
+// IDPSSODescriptor represents the IdP SSO descriptor
+type IDPSSODescriptor struct {
+ ProtocolSupportEnumeration string `xml:"protocolSupportEnumeration,attr"`
+ KeyDescriptor []KeyDescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata KeyDescriptor"`
+ SingleSignOnService []SingleSignOnService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleSignOnService"`
+ SingleLogoutService []SingleLogoutService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleLogoutService"`
+}
+
+// KeyDescriptor represents a key descriptor
+type KeyDescriptor struct {
+ Use string `xml:"use,attr"`
+ KeyInfo KeyInfo `xml:"urn:xmldsig KeyInfo"`
+}
+
+// KeyInfo represents key information
+type KeyInfo struct {
+ X509Data X509Data `xml:"urn:xmldsig X509Data"`
+}
+
+// X509Data represents X509 certificate data
+type X509Data struct {
+ X509Certificate string `xml:"urn:xmldsig X509Certificate"`
+}
+
+// SingleSignOnService represents SSO service endpoint
+type SingleSignOnService struct {
+ Binding string `xml:"Binding,attr"`
+ Location string `xml:"Location,attr"`
+}
+
+// SingleLogoutService represents SLO service endpoint
+type SingleLogoutService struct {
+ Binding string `xml:"Binding,attr"`
+ Location string `xml:"Location,attr"`
+}
+
+// SAMLRequest represents a SAML authentication request
+type SAMLRequest struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol AuthnRequest"`
+ ID string `xml:"ID,attr"`
+ Version string `xml:"Version,attr"`
+ IssueInstant time.Time `xml:"IssueInstant,attr"`
+ Destination string `xml:"Destination,attr"`
+ AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"`
+ ProtocolBinding string `xml:"ProtocolBinding,attr"`
+ Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
+ NameIDPolicy NameIDPolicy `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"`
+}
+
+// Issuer represents the SAML issuer
+type Issuer struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
+ Value string `xml:",chardata"`
+}
+
+// NameIDPolicy represents the name ID policy
+type NameIDPolicy struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"`
+ Format string `xml:"Format,attr"`
+}
+
+// SAMLResponse represents a SAML response
+type SAMLResponse struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Response"`
+ ID string `xml:"ID,attr"`
+ Version string `xml:"Version,attr"`
+ IssueInstant time.Time `xml:"IssueInstant,attr"`
+ Destination string `xml:"Destination,attr"`
+ InResponseTo string `xml:"InResponseTo,attr"`
+ Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
+ Status Status `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"`
+ Assertion Assertion `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"`
+}
+
+// Status represents the SAML response status
+type Status struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"`
+ StatusCode StatusCode `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"`
+}
+
+// StatusCode represents the status code
+type StatusCode struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"`
+ Value string `xml:"Value,attr"`
+}
+
+// Assertion represents a SAML assertion
+type Assertion struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"`
+ ID string `xml:"ID,attr"`
+ Version string `xml:"Version,attr"`
+ IssueInstant time.Time `xml:"IssueInstant,attr"`
+ Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
+ Subject Subject `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"`
+ Conditions Conditions `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"`
+ AttributeStatement AttributeStatement `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"`
+ AuthnStatement AuthnStatement `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"`
+}
+
+// Subject represents the assertion subject
+type Subject struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"`
+ NameID NameID `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"`
+ SubjectConfirmation SubjectConfirmation `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"`
+}
+
+// NameID represents the name identifier
+type NameID struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"`
+ Format string `xml:"Format,attr"`
+ Value string `xml:",chardata"`
+}
+
+// SubjectConfirmation represents subject confirmation
+type SubjectConfirmation struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"`
+ Method string `xml:"Method,attr"`
+ SubjectConfirmationData SubjectConfirmationData `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"`
+}
+
+// SubjectConfirmationData represents subject confirmation data
+type SubjectConfirmationData struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"`
+ InResponseTo string `xml:"InResponseTo,attr"`
+ NotOnOrAfter time.Time `xml:"NotOnOrAfter,attr"`
+ Recipient string `xml:"Recipient,attr"`
+}
+
+// Conditions represents assertion conditions
+type Conditions struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"`
+ NotBefore time.Time `xml:"NotBefore,attr"`
+ NotOnOrAfter time.Time `xml:"NotOnOrAfter,attr"`
+ AudienceRestriction AudienceRestriction `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"`
+}
+
+// AudienceRestriction represents audience restriction
+type AudienceRestriction struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"`
+ Audience Audience `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"`
+}
+
+// Audience represents the intended audience
+type Audience struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"`
+ Value string `xml:",chardata"`
+}
+
+// AttributeStatement represents attribute statement
+type AttributeStatement struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"`
+ Attribute []Attribute `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"`
+}
+
+// Attribute represents a SAML attribute
+type Attribute struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"`
+ Name string `xml:"Name,attr"`
+ AttributeValue []AttributeValue `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"`
+}
+
+// AttributeValue represents an attribute value
+type AttributeValue struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"`
+ Type string `xml:"http://www.w3.org/2001/XMLSchema-instance type,attr"`
+ Value string `xml:",chardata"`
+}
+
+// AuthnStatement represents authentication statement
+type AuthnStatement struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"`
+ AuthnInstant time.Time `xml:"AuthnInstant,attr"`
+ SessionIndex string `xml:"SessionIndex,attr"`
+ AuthnContext AuthnContext `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"`
+}
+
+// AuthnContext represents authentication context
+type AuthnContext struct {
+ XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"`
+ AuthnContextClassRef string `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContextClassRef"`
+}
+
+// GetMetadata fetches the SAML IdP metadata
+func (p *SAMLProvider) GetMetadata(ctx context.Context) (*SAMLMetadata, error) {
+ metadataURL := p.config.GetString("SAML_IDP_METADATA_URL")
+ if metadataURL == "" {
+ return nil, errors.NewConfigurationError("SAML_IDP_METADATA_URL not configured")
+ }
+
+ p.logger.Debug("Fetching SAML IdP metadata", zap.String("url", metadataURL))
+
+ req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
+ if err != nil {
+ return nil, errors.NewInternalError("Failed to create metadata request").WithInternal(err)
+ }
+
+ resp, err := p.httpClient.Do(req)
+ if err != nil {
+ return nil, errors.NewInternalError("Failed to fetch IdP metadata").WithInternal(err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, errors.NewInternalError(fmt.Sprintf("Metadata endpoint returned status %d", resp.StatusCode))
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, errors.NewInternalError("Failed to read metadata response").WithInternal(err)
+ }
+
+ var metadata SAMLMetadata
+ if err := xml.Unmarshal(body, &metadata); err != nil {
+ return nil, errors.NewInternalError("Failed to parse SAML metadata").WithInternal(err)
+ }
+
+ p.logger.Debug("SAML IdP metadata fetched successfully",
+ zap.String("entity_id", metadata.EntityID))
+
+ return &metadata, nil
+}
+
+// GenerateAuthRequest generates a SAML authentication request
+func (p *SAMLProvider) GenerateAuthRequest(ctx context.Context, relayState string) (string, string, error) {
+ metadata, err := p.GetMetadata(ctx)
+ if err != nil {
+ return "", "", err
+ }
+
+ // Find SSO endpoint
+ var ssoEndpoint string
+ for _, sso := range metadata.IDPSSODescriptor.SingleSignOnService {
+ if sso.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" {
+ ssoEndpoint = sso.Location
+ break
+ }
+ }
+
+ if ssoEndpoint == "" {
+ return "", "", errors.NewConfigurationError("No HTTP-Redirect SSO endpoint found in IdP metadata")
+ }
+
+ // Generate request ID
+ requestID := "_" + uuid.New().String()
+
+ // Get SP configuration
+ spEntityID := p.config.GetString("SAML_SP_ENTITY_ID")
+ acsURL := p.config.GetString("SAML_SP_ACS_URL")
+
+ if spEntityID == "" {
+ return "", "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured")
+ }
+ if acsURL == "" {
+ return "", "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured")
+ }
+
+ // Create SAML request
+ samlRequest := SAMLRequest{
+ ID: requestID,
+ Version: "2.0",
+ IssueInstant: time.Now().UTC(),
+ Destination: ssoEndpoint,
+ AssertionConsumerServiceURL: acsURL,
+ ProtocolBinding: "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
+ Issuer: Issuer{
+ Value: spEntityID,
+ },
+ NameIDPolicy: NameIDPolicy{
+ Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress",
+ },
+ }
+
+ // Marshal to XML
+ xmlData, err := xml.MarshalIndent(samlRequest, "", " ")
+ if err != nil {
+ return "", "", errors.NewInternalError("Failed to marshal SAML request").WithInternal(err)
+ }
+
+ // Add XML declaration
+ xmlRequest := `` + "\n" + string(xmlData)
+
+ // Base64 encode and URL encode
+ encodedRequest := base64.StdEncoding.EncodeToString([]byte(xmlRequest))
+
+ // Build redirect URL
+ params := url.Values{
+ "SAMLRequest": {encodedRequest},
+ "RelayState": {relayState},
+ }
+
+ redirectURL := ssoEndpoint + "?" + params.Encode()
+
+ p.logger.Debug("Generated SAML authentication request",
+ zap.String("request_id", requestID),
+ zap.String("sso_endpoint", ssoEndpoint))
+
+ return redirectURL, requestID, nil
+}
+
+// ProcessSAMLResponse processes a SAML response and extracts user information
+func (p *SAMLProvider) ProcessSAMLResponse(ctx context.Context, samlResponse string, expectedRequestID string) (*domain.AuthContext, error) {
+ p.logger.Debug("Processing SAML response")
+
+ // Base64 decode the response
+ decodedResponse, err := base64.StdEncoding.DecodeString(samlResponse)
+ if err != nil {
+ return nil, errors.NewValidationError("Failed to decode SAML response").WithInternal(err)
+ }
+
+ // Parse XML
+ var response SAMLResponse
+ if err := xml.Unmarshal(decodedResponse, &response); err != nil {
+ return nil, errors.NewValidationError("Failed to parse SAML response").WithInternal(err)
+ }
+
+ // Validate response
+ if err := p.validateSAMLResponse(&response, expectedRequestID); err != nil {
+ return nil, err
+ }
+
+ // Extract user information from assertion
+ authContext, err := p.extractUserInfo(&response.Assertion)
+ if err != nil {
+ return nil, err
+ }
+
+ p.logger.Debug("SAML response processed successfully",
+ zap.String("user_id", authContext.UserID))
+
+ return authContext, nil
+}
+
+// validateSAMLResponse validates a SAML response
+func (p *SAMLProvider) validateSAMLResponse(response *SAMLResponse, expectedRequestID string) error {
+ // Check status
+ if response.Status.StatusCode.Value != "urn:oasis:names:tc:SAML:2.0:status:Success" {
+ return errors.NewAuthenticationError("SAML authentication failed: " + response.Status.StatusCode.Value)
+ }
+
+ // Validate InResponseTo
+ if expectedRequestID != "" && response.InResponseTo != expectedRequestID {
+ return errors.NewValidationError("SAML response InResponseTo does not match request ID")
+ }
+
+ // Validate assertion conditions
+ assertion := &response.Assertion
+ now := time.Now().UTC()
+
+ if now.Before(assertion.Conditions.NotBefore) {
+ return errors.NewValidationError("SAML assertion not yet valid")
+ }
+
+ if now.After(assertion.Conditions.NotOnOrAfter) {
+ return errors.NewValidationError("SAML assertion has expired")
+ }
+
+ // Validate audience
+ expectedAudience := p.config.GetString("SAML_SP_ENTITY_ID")
+ if assertion.Conditions.AudienceRestriction.Audience.Value != expectedAudience {
+ return errors.NewValidationError("SAML assertion audience mismatch")
+ }
+
+ // In production, you should also validate the signature
+ // This requires implementing XML signature validation
+
+ return nil
+}
+
+// extractUserInfo extracts user information from SAML assertion
+func (p *SAMLProvider) extractUserInfo(assertion *Assertion) (*domain.AuthContext, error) {
+ // Extract user ID from NameID
+ userID := assertion.Subject.NameID.Value
+ if userID == "" {
+ return nil, errors.NewValidationError("SAML assertion missing NameID")
+ }
+
+ // Extract attributes
+ claims := make(map[string]string)
+ claims["sub"] = userID
+ claims["name_id_format"] = assertion.Subject.NameID.Format
+
+ // Process attribute statements
+ for _, attr := range assertion.AttributeStatement.Attribute {
+ if len(attr.AttributeValue) > 0 {
+ // Use the first value if multiple values exist
+ claims[attr.Name] = attr.AttributeValue[0].Value
+ }
+ }
+
+ // Map common attributes to standard claims
+ if email, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress"]; exists {
+ claims["email"] = email
+ }
+ if name, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name"]; exists {
+ claims["name"] = name
+ }
+ if givenName, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname"]; exists {
+ claims["given_name"] = givenName
+ }
+ if surname, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname"]; exists {
+ claims["family_name"] = surname
+ }
+
+ // Extract permissions/roles if available
+ var permissions []string
+ if roles, exists := claims["http://schemas.microsoft.com/ws/2008/06/identity/claims/role"]; exists {
+ permissions = strings.Split(roles, ",")
+ }
+
+ authContext := &domain.AuthContext{
+ UserID: userID,
+ TokenType: domain.TokenTypeUser,
+ Claims: claims,
+ Permissions: permissions,
+ }
+
+ return authContext, nil
+}
+
+// GenerateServiceProviderMetadata generates SP metadata XML
+func (p *SAMLProvider) GenerateServiceProviderMetadata() (string, error) {
+ spEntityID := p.config.GetString("SAML_SP_ENTITY_ID")
+ acsURL := p.config.GetString("SAML_SP_ACS_URL")
+
+ if spEntityID == "" {
+ return "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured")
+ }
+ if acsURL == "" {
+ return "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured")
+ }
+
+ // This is a simplified SP metadata generation
+ // In production, you should use a proper SAML library
+ metadata := fmt.Sprintf(`
+
+
+
+
+`, spEntityID, acsURL)
+
+ return metadata, nil
+}
+
+// loadCredentials loads SP private key and certificate
+func (p *SAMLProvider) loadCredentials() error {
+ // Load private key if configured
+ privateKeyPEM := p.config.GetString("SAML_SP_PRIVATE_KEY")
+ if privateKeyPEM != "" {
+ block, _ := pem.Decode([]byte(privateKeyPEM))
+ if block == nil {
+ return errors.NewConfigurationError("Failed to decode SAML SP private key")
+ }
+
+ privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ if err != nil {
+ // Try PKCS8 format
+ key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
+ if err != nil {
+ return errors.NewConfigurationError("Failed to parse SAML SP private key").WithInternal(err)
+ }
+ var ok bool
+ privateKey, ok = key.(*rsa.PrivateKey)
+ if !ok {
+ return errors.NewConfigurationError("SAML SP private key is not RSA")
+ }
+ }
+ p.privateKey = privateKey
+ }
+
+ // Load certificate if configured
+ certificatePEM := p.config.GetString("SAML_SP_CERTIFICATE")
+ if certificatePEM != "" {
+ block, _ := pem.Decode([]byte(certificatePEM))
+ if block == nil {
+ return errors.NewConfigurationError("Failed to decode SAML SP certificate")
+ }
+
+ certificate, err := x509.ParseCertificate(block.Bytes)
+ if err != nil {
+ return errors.NewConfigurationError("Failed to parse SAML SP certificate").WithInternal(err)
+ }
+ p.certificate = certificate
+ }
+
+ return nil
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index c874a40..6c13b77 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -132,6 +132,12 @@ func (c *Config) setDefaults() {
"IP_BLOCK_DURATION": "1h",
"REQUEST_MAX_AGE": "5m",
"IP_WHITELIST": "",
+ "SAML_ENABLED": "false",
+ "SAML_IDP_METADATA_URL": "",
+ "SAML_SP_ENTITY_ID": "",
+ "SAML_SP_ACS_URL": "",
+ "SAML_SP_PRIVATE_KEY": "",
+ "SAML_SP_CERTIFICATE": "",
}
for key, value := range defaults {
diff --git a/internal/domain/models.go b/internal/domain/models.go
index f70a0cf..ec6508e 100644
--- a/internal/domain/models.go
+++ b/internal/domain/models.go
@@ -188,6 +188,23 @@ type CreateStaticTokenResponse struct {
CreatedAt time.Time `json:"created_at"`
}
+// CreateTokenRequest represents a request to create a token
+type CreateTokenRequest struct {
+ AppID string `json:"app_id" validate:"required"`
+ Type TokenType `json:"type" validate:"required,oneof=static user"`
+ UserID string `json:"user_id,omitempty"` // Required for user tokens
+ Permissions []string `json:"permissions,omitempty"`
+ ExpiresAt *time.Time `json:"expires_at,omitempty"`
+ Metadata map[string]string `json:"metadata,omitempty"`
+}
+
+// CreateTokenResponse represents a response for creating a token
+type CreateTokenResponse struct {
+ Token string `json:"token"`
+ ExpiresAt time.Time `json:"expires_at"`
+ TokenType TokenType `json:"token_type"`
+}
+
// AuthContext represents the authentication context for a request
type AuthContext struct {
UserID string `json:"user_id"`
@@ -196,3 +213,25 @@ type AuthContext struct {
Claims map[string]string `json:"claims"`
AppID string `json:"app_id"`
}
+
+// TokenResponse represents the OAuth2 token response
+type TokenResponse struct {
+ AccessToken string `json:"access_token"`
+ TokenType string `json:"token_type"`
+ ExpiresIn int `json:"expires_in"`
+ RefreshToken string `json:"refresh_token,omitempty"`
+ IDToken string `json:"id_token,omitempty"`
+ Scope string `json:"scope,omitempty"`
+}
+
+// UserInfo represents user information from the OAuth2/OIDC provider
+type UserInfo struct {
+ Sub string `json:"sub"`
+ Email string `json:"email"`
+ EmailVerified bool `json:"email_verified"`
+ Name string `json:"name"`
+ GivenName string `json:"given_name"`
+ FamilyName string `json:"family_name"`
+ Picture string `json:"picture"`
+ PreferredUsername string `json:"preferred_username"`
+}
diff --git a/internal/domain/session.go b/internal/domain/session.go
new file mode 100644
index 0000000..dee5de3
--- /dev/null
+++ b/internal/domain/session.go
@@ -0,0 +1,153 @@
+package domain
+
+import (
+ "time"
+
+ "github.com/google/uuid"
+)
+
+// SessionStatus represents the status of a user session
+type SessionStatus string
+
+const (
+ SessionStatusActive SessionStatus = "active"
+ SessionStatusExpired SessionStatus = "expired"
+ SessionStatusRevoked SessionStatus = "revoked"
+ SessionStatusSuspended SessionStatus = "suspended"
+)
+
+// SessionType represents the type of session
+type SessionType string
+
+const (
+ SessionTypeWeb SessionType = "web"
+ SessionTypeMobile SessionType = "mobile"
+ SessionTypeAPI SessionType = "api"
+)
+
+// UserSession represents a user session in the system
+type UserSession struct {
+ ID uuid.UUID `json:"id" db:"id"`
+ UserID string `json:"user_id" validate:"required" db:"user_id"`
+ AppID string `json:"app_id" validate:"required" db:"app_id"`
+ SessionType SessionType `json:"session_type" validate:"required,oneof=web mobile api" db:"session_type"`
+ Status SessionStatus `json:"status" validate:"required,oneof=active expired revoked suspended" db:"status"`
+ AccessToken string `json:"-" db:"access_token"` // Hidden from JSON for security
+ RefreshToken string `json:"-" db:"refresh_token"` // Hidden from JSON for security
+ IDToken string `json:"-" db:"id_token"` // Hidden from JSON for security
+ IPAddress string `json:"ip_address" db:"ip_address"`
+ UserAgent string `json:"user_agent" db:"user_agent"`
+ LastActivity time.Time `json:"last_activity" db:"last_activity"`
+ ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
+ CreatedAt time.Time `json:"created_at" db:"created_at"`
+ UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
+ RevokedAt *time.Time `json:"revoked_at,omitempty" db:"revoked_at"`
+ RevokedBy *string `json:"revoked_by,omitempty" db:"revoked_by"`
+ Metadata SessionMetadata `json:"metadata" db:"metadata"`
+}
+
+// SessionMetadata contains additional session information
+type SessionMetadata struct {
+ DeviceInfo string `json:"device_info,omitempty"`
+ Location string `json:"location,omitempty"`
+ LoginMethod string `json:"login_method,omitempty"`
+ TenantID string `json:"tenant_id,omitempty"`
+ Permissions []string `json:"permissions,omitempty"`
+ Claims map[string]string `json:"claims,omitempty"`
+ RefreshCount int `json:"refresh_count"`
+ LastRefresh *time.Time `json:"last_refresh,omitempty"`
+}
+
+// CreateSessionRequest represents a request to create a new session
+type CreateSessionRequest struct {
+ UserID string `json:"user_id" validate:"required"`
+ AppID string `json:"app_id" validate:"required"`
+ SessionType SessionType `json:"session_type" validate:"required,oneof=web mobile api"`
+ IPAddress string `json:"ip_address" validate:"required,ip"`
+ UserAgent string `json:"user_agent" validate:"required"`
+ ExpiresAt time.Time `json:"expires_at" validate:"required"`
+ Permissions []string `json:"permissions,omitempty"`
+ Claims map[string]string `json:"claims,omitempty"`
+ TenantID string `json:"tenant_id,omitempty"`
+}
+
+// UpdateSessionRequest represents a request to update a session
+type UpdateSessionRequest struct {
+ Status *SessionStatus `json:"status,omitempty" validate:"omitempty,oneof=active expired revoked suspended"`
+ LastActivity *time.Time `json:"last_activity,omitempty"`
+ ExpiresAt *time.Time `json:"expires_at,omitempty"`
+ IPAddress *string `json:"ip_address,omitempty" validate:"omitempty,ip"`
+ UserAgent *string `json:"user_agent,omitempty"`
+}
+
+// SessionListRequest represents a request to list sessions
+type SessionListRequest struct {
+ UserID string `json:"user_id,omitempty"`
+ AppID string `json:"app_id,omitempty"`
+ Status *SessionStatus `json:"status,omitempty"`
+ SessionType *SessionType `json:"session_type,omitempty"`
+ TenantID string `json:"tenant_id,omitempty"`
+ Limit int `json:"limit" validate:"min=1,max=100"`
+ Offset int `json:"offset" validate:"min=0"`
+}
+
+// SessionListResponse represents a response for listing sessions
+type SessionListResponse struct {
+ Sessions []*UserSession `json:"sessions"`
+ Total int `json:"total"`
+ Limit int `json:"limit"`
+ Offset int `json:"offset"`
+}
+
+// IsActive checks if the session is currently active
+func (s *UserSession) IsActive() bool {
+ return s.Status == SessionStatusActive && time.Now().Before(s.ExpiresAt)
+}
+
+// IsExpired checks if the session has expired
+func (s *UserSession) IsExpired() bool {
+ return time.Now().After(s.ExpiresAt) || s.Status == SessionStatusExpired
+}
+
+// IsRevoked checks if the session has been revoked
+func (s *UserSession) IsRevoked() bool {
+ return s.Status == SessionStatusRevoked
+}
+
+// CanRefresh checks if the session can be refreshed
+func (s *UserSession) CanRefresh() bool {
+ return s.IsActive() && s.RefreshToken != ""
+}
+
+// UpdateActivity updates the last activity timestamp
+func (s *UserSession) UpdateActivity() {
+ s.LastActivity = time.Now()
+ s.UpdatedAt = time.Now()
+}
+
+// Revoke marks the session as revoked
+func (s *UserSession) Revoke(revokedBy string) {
+ now := time.Now()
+ s.Status = SessionStatusRevoked
+ s.RevokedAt = &now
+ s.RevokedBy = &revokedBy
+ s.UpdatedAt = now
+}
+
+// Expire marks the session as expired
+func (s *UserSession) Expire() {
+ s.Status = SessionStatusExpired
+ s.UpdatedAt = time.Now()
+}
+
+// Suspend marks the session as suspended
+func (s *UserSession) Suspend() {
+ s.Status = SessionStatusSuspended
+ s.UpdatedAt = time.Now()
+}
+
+// Activate marks the session as active
+func (s *UserSession) Activate() {
+ s.Status = SessionStatusActive
+ s.UpdatedAt = time.Now()
+}
diff --git a/internal/domain/tenant.go b/internal/domain/tenant.go
new file mode 100644
index 0000000..bf2313f
--- /dev/null
+++ b/internal/domain/tenant.go
@@ -0,0 +1,307 @@
+package domain
+
+import (
+ "time"
+
+ "github.com/google/uuid"
+)
+
+// TenantStatus represents the status of a tenant
+type TenantStatus string
+
+const (
+ TenantStatusActive TenantStatus = "active"
+ TenantStatusSuspended TenantStatus = "suspended"
+ TenantStatusInactive TenantStatus = "inactive"
+)
+
+// Tenant represents a tenant in the multi-tenant system
+type Tenant struct {
+ ID uuid.UUID `json:"id" db:"id"`
+ Name string `json:"name" validate:"required,min=1,max=255" db:"name"`
+ Slug string `json:"slug" validate:"required,min=1,max=100,alphanum" db:"slug"`
+ Status TenantStatus `json:"status" validate:"required,oneof=active suspended inactive" db:"status"`
+ Domain string `json:"domain,omitempty" validate:"omitempty,fqdn" db:"domain"`
+ Description string `json:"description,omitempty" validate:"max=1000" db:"description"`
+ Settings TenantSettings `json:"settings" db:"settings"`
+ CreatedAt time.Time `json:"created_at" db:"created_at"`
+ UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
+ CreatedBy string `json:"created_by" db:"created_by"`
+ UpdatedBy string `json:"updated_by" db:"updated_by"`
+}
+
+// TenantSettings contains tenant-specific configuration
+type TenantSettings struct {
+ // Authentication settings
+ AuthProvider string `json:"auth_provider,omitempty"` // oauth2, saml, header
+ SAMLSettings *SAMLSettings `json:"saml_settings,omitempty"`
+ OAuth2Settings *OAuth2Settings `json:"oauth2_settings,omitempty"`
+
+ // Session settings
+ SessionTimeout time.Duration `json:"session_timeout,omitempty"`
+ MaxConcurrentSessions int `json:"max_concurrent_sessions,omitempty"`
+
+ // Security settings
+ RequireMFA bool `json:"require_mfa"`
+ AllowedIPRanges []string `json:"allowed_ip_ranges,omitempty"`
+ PasswordPolicy *PasswordPolicy `json:"password_policy,omitempty"`
+
+ // Token settings
+ DefaultTokenDuration time.Duration `json:"default_token_duration,omitempty"`
+ MaxTokenDuration time.Duration `json:"max_token_duration,omitempty"`
+
+ // Feature flags
+ Features map[string]bool `json:"features,omitempty"`
+
+ // Custom attributes
+ CustomAttributes map[string]string `json:"custom_attributes,omitempty"`
+}
+
+// SAMLSettings contains SAML-specific configuration for a tenant
+type SAMLSettings struct {
+ IDPMetadataURL string `json:"idp_metadata_url,omitempty"`
+ SPEntityID string `json:"sp_entity_id,omitempty"`
+ ACSURL string `json:"acs_url,omitempty"`
+ SPPrivateKey string `json:"sp_private_key,omitempty"`
+ SPCertificate string `json:"sp_certificate,omitempty"`
+ AttributeMapping map[string]string `json:"attribute_mapping,omitempty"`
+}
+
+// OAuth2Settings contains OAuth2-specific configuration for a tenant
+type OAuth2Settings struct {
+ ProviderURL string `json:"provider_url,omitempty"`
+ ClientID string `json:"client_id,omitempty"`
+ ClientSecret string `json:"client_secret,omitempty"`
+ Scopes []string `json:"scopes,omitempty"`
+ AttributeMapping map[string]string `json:"attribute_mapping,omitempty"`
+}
+
+// PasswordPolicy defines password requirements for a tenant
+type PasswordPolicy struct {
+ MinLength int `json:"min_length"`
+ RequireUppercase bool `json:"require_uppercase"`
+ RequireLowercase bool `json:"require_lowercase"`
+ RequireNumbers bool `json:"require_numbers"`
+ RequireSymbols bool `json:"require_symbols"`
+ MaxAge time.Duration `json:"max_age,omitempty"`
+ PreventReuse int `json:"prevent_reuse"` // Number of previous passwords to prevent reuse
+}
+
+// TenantUser represents a user within a specific tenant
+type TenantUser struct {
+ ID uuid.UUID `json:"id" db:"id"`
+ TenantID uuid.UUID `json:"tenant_id" validate:"required" db:"tenant_id"`
+ UserID string `json:"user_id" validate:"required" db:"user_id"`
+ Email string `json:"email" validate:"required,email" db:"email"`
+ Name string `json:"name" validate:"required" db:"name"`
+ Roles []string `json:"roles" db:"roles"`
+ Permissions []string `json:"permissions" db:"permissions"`
+ Status UserStatus `json:"status" validate:"required,oneof=active inactive suspended" db:"status"`
+ Metadata map[string]string `json:"metadata,omitempty" db:"metadata"`
+ CreatedAt time.Time `json:"created_at" db:"created_at"`
+ UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
+ LastLoginAt *time.Time `json:"last_login_at,omitempty" db:"last_login_at"`
+}
+
+// UserStatus represents the status of a user within a tenant
+type UserStatus string
+
+const (
+ UserStatusActive UserStatus = "active"
+ UserStatusInactive UserStatus = "inactive"
+ UserStatusSuspended UserStatus = "suspended"
+)
+
+// TenantRole represents a role within a tenant
+type TenantRole struct {
+ ID uuid.UUID `json:"id" db:"id"`
+ TenantID uuid.UUID `json:"tenant_id" validate:"required" db:"tenant_id"`
+ Name string `json:"name" validate:"required,min=1,max=100" db:"name"`
+ Description string `json:"description,omitempty" validate:"max=500" db:"description"`
+ Permissions []string `json:"permissions" db:"permissions"`
+ IsSystem bool `json:"is_system" db:"is_system"` // System roles cannot be deleted
+ CreatedAt time.Time `json:"created_at" db:"created_at"`
+ UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
+ CreatedBy string `json:"created_by" db:"created_by"`
+ UpdatedBy string `json:"updated_by" db:"updated_by"`
+}
+
+// CreateTenantRequest represents a request to create a new tenant
+type CreateTenantRequest struct {
+ Name string `json:"name" validate:"required,min=1,max=255"`
+ Slug string `json:"slug" validate:"required,min=1,max=100,alphanum"`
+ Domain string `json:"domain,omitempty" validate:"omitempty,fqdn"`
+ Description string `json:"description,omitempty" validate:"max=1000"`
+ Settings TenantSettings `json:"settings,omitempty"`
+}
+
+// UpdateTenantRequest represents a request to update a tenant
+type UpdateTenantRequest struct {
+ Name *string `json:"name,omitempty" validate:"omitempty,min=1,max=255"`
+ Status *TenantStatus `json:"status,omitempty" validate:"omitempty,oneof=active suspended inactive"`
+ Domain *string `json:"domain,omitempty" validate:"omitempty,fqdn"`
+ Description *string `json:"description,omitempty" validate:"omitempty,max=1000"`
+ Settings *TenantSettings `json:"settings,omitempty"`
+}
+
+// CreateTenantUserRequest represents a request to create a user in a tenant
+type CreateTenantUserRequest struct {
+ TenantID uuid.UUID `json:"tenant_id" validate:"required"`
+ UserID string `json:"user_id" validate:"required"`
+ Email string `json:"email" validate:"required,email"`
+ Name string `json:"name" validate:"required"`
+ Roles []string `json:"roles,omitempty"`
+ Permissions []string `json:"permissions,omitempty"`
+ Metadata map[string]string `json:"metadata,omitempty"`
+}
+
+// UpdateTenantUserRequest represents a request to update a tenant user
+type UpdateTenantUserRequest struct {
+ Email *string `json:"email,omitempty" validate:"omitempty,email"`
+ Name *string `json:"name,omitempty" validate:"omitempty,min=1"`
+ Roles []string `json:"roles,omitempty"`
+ Permissions []string `json:"permissions,omitempty"`
+ Status *UserStatus `json:"status,omitempty" validate:"omitempty,oneof=active inactive suspended"`
+ Metadata map[string]string `json:"metadata,omitempty"`
+}
+
+// CreateTenantRoleRequest represents a request to create a role in a tenant
+type CreateTenantRoleRequest struct {
+ TenantID uuid.UUID `json:"tenant_id" validate:"required"`
+ Name string `json:"name" validate:"required,min=1,max=100"`
+ Description string `json:"description,omitempty" validate:"max=500"`
+ Permissions []string `json:"permissions,omitempty"`
+}
+
+// UpdateTenantRoleRequest represents a request to update a tenant role
+type UpdateTenantRoleRequest struct {
+ Name *string `json:"name,omitempty" validate:"omitempty,min=1,max=100"`
+ Description *string `json:"description,omitempty" validate:"omitempty,max=500"`
+ Permissions []string `json:"permissions,omitempty"`
+}
+
+// TenantListRequest represents a request to list tenants
+type TenantListRequest struct {
+ Status *TenantStatus `json:"status,omitempty"`
+ Domain string `json:"domain,omitempty"`
+ Limit int `json:"limit" validate:"min=1,max=100"`
+ Offset int `json:"offset" validate:"min=0"`
+}
+
+// TenantListResponse represents a response for listing tenants
+type TenantListResponse struct {
+ Tenants []*Tenant `json:"tenants"`
+ Total int `json:"total"`
+ Limit int `json:"limit"`
+ Offset int `json:"offset"`
+}
+
+// IsActive checks if the tenant is active
+func (t *Tenant) IsActive() bool {
+ return t.Status == TenantStatusActive
+}
+
+// IsSuspended checks if the tenant is suspended
+func (t *Tenant) IsSuspended() bool {
+ return t.Status == TenantStatusSuspended
+}
+
+// HasFeature checks if a feature is enabled for the tenant
+func (t *Tenant) HasFeature(feature string) bool {
+ if t.Settings.Features == nil {
+ return false
+ }
+ enabled, exists := t.Settings.Features[feature]
+ return exists && enabled
+}
+
+// GetAuthProvider returns the authentication provider for the tenant
+func (t *Tenant) GetAuthProvider() string {
+ if t.Settings.AuthProvider != "" {
+ return t.Settings.AuthProvider
+ }
+ return "header" // default
+}
+
+// GetSessionTimeout returns the session timeout for the tenant
+func (t *Tenant) GetSessionTimeout() time.Duration {
+ if t.Settings.SessionTimeout > 0 {
+ return t.Settings.SessionTimeout
+ }
+ return 8 * time.Hour // default
+}
+
+// GetMaxConcurrentSessions returns the maximum concurrent sessions for the tenant
+func (t *Tenant) GetMaxConcurrentSessions() int {
+ if t.Settings.MaxConcurrentSessions > 0 {
+ return t.Settings.MaxConcurrentSessions
+ }
+ return 10 // default
+}
+
+// IsActive checks if the tenant user is active
+func (tu *TenantUser) IsActive() bool {
+ return tu.Status == UserStatusActive
+}
+
+// IsSuspended checks if the tenant user is suspended
+func (tu *TenantUser) IsSuspended() bool {
+ return tu.Status == UserStatusSuspended
+}
+
+// HasRole checks if the user has a specific role
+func (tu *TenantUser) HasRole(role string) bool {
+ for _, r := range tu.Roles {
+ if r == role {
+ return true
+ }
+ }
+ return false
+}
+
+// HasPermission checks if the user has a specific permission
+func (tu *TenantUser) HasPermission(permission string) bool {
+ for _, p := range tu.Permissions {
+ if p == permission {
+ return true
+ }
+ }
+ return false
+}
+
+// UpdateLastLogin updates the last login timestamp
+func (tu *TenantUser) UpdateLastLogin() {
+ now := time.Now()
+ tu.LastLoginAt = &now
+ tu.UpdatedAt = now
+}
+
+// IsSystemRole checks if the role is a system role
+func (tr *TenantRole) IsSystemRole() bool {
+ return tr.IsSystem
+}
+
+// HasPermission checks if the role has a specific permission
+func (tr *TenantRole) HasPermission(permission string) bool {
+ for _, p := range tr.Permissions {
+ if p == permission {
+ return true
+ }
+ }
+ return false
+}
+
+// TenantContext represents the tenant context for a request
+type TenantContext struct {
+ TenantID uuid.UUID `json:"tenant_id"`
+ TenantSlug string `json:"tenant_slug"`
+ UserID string `json:"user_id"`
+ Roles []string `json:"roles"`
+ Permissions []string `json:"permissions"`
+}
+
+// MultiTenantAuthContext extends AuthContext with tenant information
+type MultiTenantAuthContext struct {
+ *AuthContext
+ TenantContext *TenantContext `json:"tenant_context,omitempty"`
+}
diff --git a/internal/errors/errors.go b/internal/errors/errors.go
index 902ee5e..a0b8567 100644
--- a/internal/errors/errors.go
+++ b/internal/errors/errors.go
@@ -295,3 +295,66 @@ func (c *Chain) Error() string {
}
return fmt.Sprintf("multiple errors: %s (and %d more)", c.errors[0].Error(), len(c.errors)-1)
}
+
+// Helper functions to check error types
+
+// IsNotFound checks if an error is a not found error
+func IsNotFound(err error) bool {
+ if appErr, ok := err.(*AppError); ok {
+ return appErr.Code == ErrNotFound || appErr.Code == ErrApplicationNotFound ||
+ appErr.Code == ErrTokenNotFound || appErr.Code == ErrPermissionNotFound
+ }
+ return false
+}
+
+// IsValidationError checks if an error is a validation error
+func IsValidationError(err error) bool {
+ if appErr, ok := err.(*AppError); ok {
+ return appErr.Code == ErrValidationFailed || appErr.Code == ErrInvalidInput ||
+ appErr.Code == ErrMissingField || appErr.Code == ErrInvalidFormat
+ }
+ return false
+}
+
+// IsAuthenticationError checks if an error is an authentication error
+func IsAuthenticationError(err error) bool {
+ if appErr, ok := err.(*AppError); ok {
+ return appErr.Code == ErrUnauthorized || appErr.Code == ErrInvalidToken ||
+ appErr.Code == ErrTokenExpired || appErr.Code == ErrInvalidCredentials
+ }
+ return false
+}
+
+// IsAuthorizationError checks if an error is an authorization error
+func IsAuthorizationError(err error) bool {
+ if appErr, ok := err.(*AppError); ok {
+ return appErr.Code == ErrForbidden || appErr.Code == ErrInsufficientPermissions
+ }
+ return false
+}
+
+// IsConflictError checks if an error is a conflict error
+func IsConflictError(err error) bool {
+ if appErr, ok := err.(*AppError); ok {
+ return appErr.Code == ErrAlreadyExists || appErr.Code == ErrConflict
+ }
+ return false
+}
+
+// IsInternalError checks if an error is an internal server error
+func IsInternalError(err error) bool {
+ if appErr, ok := err.(*AppError); ok {
+ return appErr.Code == ErrInternal || appErr.Code == ErrDatabase ||
+ appErr.Code == ErrExternal || appErr.Code == ErrTokenCreationFailed
+ }
+ return false
+}
+
+// IsConfigurationError checks if an error is a configuration error
+func IsConfigurationError(err error) bool {
+ if appErr, ok := err.(*AppError); ok {
+ // Configuration errors are typically mapped to internal errors
+ return appErr.Code == ErrInternal && appErr.Message != ""
+ }
+ return false
+}
diff --git a/internal/handlers/saml.go b/internal/handlers/saml.go
new file mode 100644
index 0000000..860226e
--- /dev/null
+++ b/internal/handlers/saml.go
@@ -0,0 +1,352 @@
+package handlers
+
+import (
+ "encoding/json"
+ "net/http"
+ "time"
+
+ "github.com/gorilla/mux"
+ "go.uber.org/zap"
+
+ "github.com/kms/api-key-service/internal/auth"
+ "github.com/kms/api-key-service/internal/config"
+ "github.com/kms/api-key-service/internal/domain"
+ "github.com/kms/api-key-service/internal/errors"
+ "github.com/kms/api-key-service/internal/services"
+)
+
+// SAMLHandler handles SAML authentication endpoints
+type SAMLHandler struct {
+ samlProvider *auth.SAMLProvider
+ sessionService services.SessionService
+ authService services.AuthenticationService
+ tokenService services.TokenService
+ config config.ConfigProvider
+ logger *zap.Logger
+}
+
+// NewSAMLHandler creates a new SAML handler
+func NewSAMLHandler(
+ config config.ConfigProvider,
+ sessionService services.SessionService,
+ authService services.AuthenticationService,
+ tokenService services.TokenService,
+ logger *zap.Logger,
+) (*SAMLHandler, error) {
+ samlProvider, err := auth.NewSAMLProvider(config, logger)
+ if err != nil {
+ return nil, err
+ }
+
+ return &SAMLHandler{
+ samlProvider: samlProvider,
+ sessionService: sessionService,
+ authService: authService,
+ config: config,
+ logger: logger,
+ }, nil
+}
+
+// RegisterRoutes registers SAML routes
+func (h *SAMLHandler) RegisterRoutes(router *mux.Router) {
+ // SAML endpoints
+ router.HandleFunc("/auth/saml/login", h.InitiateSAMLLogin).Methods("GET")
+ router.HandleFunc("/auth/saml/acs", h.HandleSAMLResponse).Methods("POST")
+ router.HandleFunc("/auth/saml/metadata", h.GetServiceProviderMetadata).Methods("GET")
+ router.HandleFunc("/auth/saml/slo", h.HandleSingleLogout).Methods("GET", "POST")
+}
+
+// InitiateSAMLLogin initiates SAML authentication
+func (h *SAMLHandler) InitiateSAMLLogin(w http.ResponseWriter, r *http.Request) {
+ if !h.config.GetBool("SAML_ENABLED") {
+ h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
+ return
+ }
+
+ // Get query parameters
+ appID := r.URL.Query().Get("app_id")
+ redirectURL := r.URL.Query().Get("redirect_url")
+
+ if appID == "" {
+ h.writeErrorResponse(w, errors.NewValidationError("app_id parameter is required"))
+ return
+ }
+
+ // Generate relay state with app_id and redirect_url
+ relayState := appID
+ if redirectURL != "" {
+ relayState += "|" + redirectURL
+ }
+
+ h.logger.Debug("Initiating SAML login",
+ zap.String("app_id", appID),
+ zap.String("redirect_url", redirectURL))
+
+ // Generate SAML authentication request
+ authURL, requestID, err := h.samlProvider.GenerateAuthRequest(r.Context(), relayState)
+ if err != nil {
+ h.logger.Error("Failed to generate SAML auth request", zap.Error(err))
+ h.writeErrorResponse(w, err)
+ return
+ }
+
+ // Store request ID in session/cache for validation
+ // In production, you should store this securely
+ h.logger.Debug("Generated SAML auth request",
+ zap.String("request_id", requestID),
+ zap.String("auth_url", authURL))
+
+ // Redirect to IdP
+ http.Redirect(w, r, authURL, http.StatusFound)
+}
+
+// HandleSAMLResponse handles SAML assertion consumer service (ACS)
+func (h *SAMLHandler) HandleSAMLResponse(w http.ResponseWriter, r *http.Request) {
+ if !h.config.GetBool("SAML_ENABLED") {
+ h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
+ return
+ }
+
+ h.logger.Debug("Handling SAML response")
+
+ // Parse form data
+ if err := r.ParseForm(); err != nil {
+ h.writeErrorResponse(w, errors.NewValidationError("Failed to parse form data").WithInternal(err))
+ return
+ }
+
+ samlResponse := r.FormValue("SAMLResponse")
+ relayState := r.FormValue("RelayState")
+
+ if samlResponse == "" {
+ h.writeErrorResponse(w, errors.NewValidationError("SAMLResponse is required"))
+ return
+ }
+
+ h.logger.Debug("Processing SAML response", zap.String("relay_state", relayState))
+
+ // Process SAML response
+ // In production, you should retrieve and validate the original request ID
+ authContext, err := h.samlProvider.ProcessSAMLResponse(r.Context(), samlResponse, "")
+ if err != nil {
+ h.logger.Error("Failed to process SAML response", zap.Error(err))
+ h.writeErrorResponse(w, err)
+ return
+ }
+
+ // Parse relay state to get app_id and redirect_url
+ appID, redirectURL := h.parseRelayState(relayState)
+ if appID == "" {
+ h.writeErrorResponse(w, errors.NewValidationError("Invalid relay state: missing app_id"))
+ return
+ }
+
+ // Create user session
+ sessionReq := &domain.CreateSessionRequest{
+ UserID: authContext.UserID,
+ AppID: appID,
+ SessionType: domain.SessionTypeWeb,
+ IPAddress: h.getClientIP(r),
+ UserAgent: r.UserAgent(),
+ ExpiresAt: time.Now().Add(8 * time.Hour), // 8 hour session
+ Permissions: authContext.Permissions,
+ Claims: authContext.Claims,
+ }
+
+ session, err := h.sessionService.CreateSession(r.Context(), sessionReq)
+ if err != nil {
+ h.logger.Error("Failed to create session", zap.Error(err))
+ h.writeErrorResponse(w, err)
+ return
+ }
+
+ // Generate JWT token for the session using the existing token service
+ userToken := &domain.UserToken{
+ AppID: appID,
+ UserID: authContext.UserID,
+ Permissions: authContext.Permissions,
+ IssuedAt: time.Now(),
+ ExpiresAt: session.ExpiresAt,
+ MaxValidAt: session.ExpiresAt,
+ TokenType: domain.TokenTypeUser,
+ Claims: authContext.Claims,
+ }
+
+ tokenString, err := h.authService.GenerateJWTToken(r.Context(), userToken)
+ if err != nil {
+ h.logger.Error("Failed to create JWT token", zap.Error(err))
+ h.writeErrorResponse(w, err)
+ return
+ }
+
+ h.logger.Debug("SAML authentication successful",
+ zap.String("user_id", authContext.UserID),
+ zap.String("session_id", session.ID.String()))
+
+ // If redirect URL is provided, redirect with token
+ if redirectURL != "" {
+ // Add token as query parameter or fragment
+ redirectURL += "?token=" + tokenString
+ http.Redirect(w, r, redirectURL, http.StatusFound)
+ return
+ }
+
+ // Otherwise, return JSON response
+ response := map[string]interface{}{
+ "success": true,
+ "token": tokenString,
+ "user": map[string]interface{}{
+ "id": authContext.UserID,
+ "email": authContext.Claims["email"],
+ "name": authContext.Claims["name"],
+ },
+ "session_id": session.ID.String(),
+ "expires_at": session.ExpiresAt,
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(response)
+}
+
+// GetServiceProviderMetadata returns SP metadata XML
+func (h *SAMLHandler) GetServiceProviderMetadata(w http.ResponseWriter, r *http.Request) {
+ if !h.config.GetBool("SAML_ENABLED") {
+ h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
+ return
+ }
+
+ h.logger.Debug("Generating SP metadata")
+
+ metadata, err := h.samlProvider.GenerateServiceProviderMetadata()
+ if err != nil {
+ h.logger.Error("Failed to generate SP metadata", zap.Error(err))
+ h.writeErrorResponse(w, err)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/xml")
+ w.Write([]byte(metadata))
+}
+
+// HandleSingleLogout handles SAML single logout
+func (h *SAMLHandler) HandleSingleLogout(w http.ResponseWriter, r *http.Request) {
+ if !h.config.GetBool("SAML_ENABLED") {
+ h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
+ return
+ }
+
+ h.logger.Debug("Handling SAML single logout")
+
+ // Get session ID from query parameter or form
+ sessionID := r.URL.Query().Get("session_id")
+ if sessionID == "" && r.Method == "POST" {
+ r.ParseForm()
+ sessionID = r.FormValue("session_id")
+ }
+
+ if sessionID != "" {
+ // Revoke specific session
+ h.logger.Debug("Revoking session", zap.String("session_id", sessionID))
+ // Implementation would depend on how you store session IDs
+ // For now, we'll just log it
+ }
+
+ // In a full implementation, you would:
+ // 1. Parse the SAML LogoutRequest
+ // 2. Validate the request
+ // 3. Revoke the user's sessions
+ // 4. Generate a LogoutResponse
+ // 5. Redirect back to the IdP
+
+ // For now, return a simple success response
+ response := map[string]interface{}{
+ "success": true,
+ "message": "Logout successful",
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(response)
+}
+
+// parseRelayState parses the relay state to extract app_id and redirect_url
+func (h *SAMLHandler) parseRelayState(relayState string) (appID, redirectURL string) {
+ if relayState == "" {
+ return "", ""
+ }
+
+ // RelayState format: "app_id|redirect_url" or just "app_id"
+ parts := []string{relayState}
+ if len(relayState) > 0 && relayState[0] != '|' {
+ // Split on first pipe character
+ for i, char := range relayState {
+ if char == '|' {
+ parts = []string{relayState[:i], relayState[i+1:]}
+ break
+ }
+ }
+ }
+
+ appID = parts[0]
+ if len(parts) > 1 {
+ redirectURL = parts[1]
+ }
+
+ return appID, redirectURL
+}
+
+// getClientIP extracts the client IP address from the request
+func (h *SAMLHandler) getClientIP(r *http.Request) string {
+ // Check X-Forwarded-For header first
+ if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
+ // Take the first IP if multiple are present
+ if idx := len(xff); idx > 0 {
+ for i, char := range xff {
+ if char == ',' {
+ return xff[:i]
+ }
+ }
+ return xff
+ }
+ }
+
+ // Check X-Real-IP header
+ if xri := r.Header.Get("X-Real-IP"); xri != "" {
+ return xri
+ }
+
+ // Fall back to RemoteAddr
+ return r.RemoteAddr
+}
+
+// writeErrorResponse writes an error response
+func (h *SAMLHandler) writeErrorResponse(w http.ResponseWriter, err error) {
+ var statusCode int
+ var errorCode string
+
+ switch {
+ case errors.IsValidationError(err):
+ statusCode = http.StatusBadRequest
+ errorCode = "VALIDATION_ERROR"
+ case errors.IsAuthenticationError(err):
+ statusCode = http.StatusUnauthorized
+ errorCode = "AUTHENTICATION_ERROR"
+ case errors.IsConfigurationError(err):
+ statusCode = http.StatusServiceUnavailable
+ errorCode = "CONFIGURATION_ERROR"
+ default:
+ statusCode = http.StatusInternalServerError
+ errorCode = "INTERNAL_ERROR"
+ }
+
+ response := map[string]interface{}{
+ "success": false,
+ "error": map[string]interface{}{
+ "code": errorCode,
+ "message": err.Error(),
+ },
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(statusCode)
+ json.NewEncoder(w).Encode(response)
+}
diff --git a/internal/middleware/security.go b/internal/middleware/security.go
index 2624b1b..f10cdc1 100644
--- a/internal/middleware/security.go
+++ b/internal/middleware/security.go
@@ -2,7 +2,6 @@ package middleware
import (
"context"
- "fmt"
"net"
"net/http"
"strings"
@@ -14,7 +13,6 @@ import (
"github.com/kms/api-key-service/internal/cache"
"github.com/kms/api-key-service/internal/config"
- "github.com/kms/api-key-service/internal/errors"
)
// SecurityMiddleware provides various security features
@@ -408,8 +406,6 @@ func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool {
// GetSecurityMetrics returns security-related metrics
func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} {
- ctx := context.Background()
-
// This is a simplified version - in production you'd want more comprehensive metrics
metrics := map[string]interface{}{
"active_rate_limiters": len(s.rateLimiters),
diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go
index 0958294..065c23c 100644
--- a/internal/repository/interfaces.go
+++ b/internal/repository/interfaces.go
@@ -104,6 +104,57 @@ type GrantedPermissionRepository interface {
HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error)
}
+// SessionRepository defines the interface for user session data operations
+type SessionRepository interface {
+ // Create creates a new user session
+ Create(ctx context.Context, session *domain.UserSession) error
+
+ // GetByID retrieves a session by its ID
+ GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error)
+
+ // GetByUserID retrieves all sessions for a user
+ GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error)
+
+ // GetByUserAndApp retrieves sessions for a specific user and application
+ GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error)
+
+ // GetActiveByUserID retrieves all active sessions for a user
+ GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error)
+
+ // List retrieves sessions with filtering and pagination
+ List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error)
+
+ // Update updates an existing session
+ Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error
+
+ // UpdateActivity updates the last activity timestamp for a session
+ UpdateActivity(ctx context.Context, sessionID uuid.UUID) error
+
+ // Revoke revokes a session
+ Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error
+
+ // RevokeAllByUser revokes all sessions for a user
+ RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error
+
+ // RevokeAllByUserAndApp revokes all sessions for a user and application
+ RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error
+
+ // ExpireOldSessions marks expired sessions as expired
+ ExpireOldSessions(ctx context.Context) (int, error)
+
+ // DeleteExpiredSessions removes expired sessions older than the specified duration
+ DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error)
+
+ // Exists checks if a session exists
+ Exists(ctx context.Context, sessionID uuid.UUID) (bool, error)
+
+ // GetSessionCount returns the total number of sessions for a user
+ GetSessionCount(ctx context.Context, userID string) (int, error)
+
+ // GetActiveSessionCount returns the number of active sessions for a user
+ GetActiveSessionCount(ctx context.Context, userID string) (int, error)
+}
+
// DatabaseProvider defines the interface for database operations
type DatabaseProvider interface {
// GetDB returns the underlying database connection
diff --git a/internal/repository/postgres/session_repository.go b/internal/repository/postgres/session_repository.go
new file mode 100644
index 0000000..bb5000f
--- /dev/null
+++ b/internal/repository/postgres/session_repository.go
@@ -0,0 +1,624 @@
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/jmoiron/sqlx"
+ "go.uber.org/zap"
+
+ "github.com/kms/api-key-service/internal/domain"
+ "github.com/kms/api-key-service/internal/errors"
+ "github.com/kms/api-key-service/internal/repository"
+)
+
+// sessionRepository implements the SessionRepository interface
+type sessionRepository struct {
+ db *sqlx.DB
+ logger *zap.Logger
+}
+
+// NewSessionRepository creates a new session repository
+func NewSessionRepository(db *sqlx.DB, logger *zap.Logger) repository.SessionRepository {
+ return &sessionRepository{
+ db: db,
+ logger: logger,
+ }
+}
+
+// Create creates a new user session
+func (r *sessionRepository) Create(ctx context.Context, session *domain.UserSession) error {
+ r.logger.Debug("Creating new session",
+ zap.String("user_id", session.UserID),
+ zap.String("app_id", session.AppID),
+ zap.String("session_type", string(session.SessionType)))
+
+ // Generate ID if not provided
+ if session.ID == uuid.Nil {
+ session.ID = uuid.New()
+ }
+
+ // Set timestamps
+ now := time.Now()
+ session.CreatedAt = now
+ session.UpdatedAt = now
+ session.LastActivity = now
+
+ // Serialize metadata
+ metadataJSON, err := json.Marshal(session.Metadata)
+ if err != nil {
+ return errors.NewInternalError("Failed to serialize session metadata").WithInternal(err)
+ }
+
+ query := `
+ INSERT INTO user_sessions (
+ id, user_id, app_id, session_type, status, access_token,
+ refresh_token, id_token, ip_address, user_agent,
+ last_activity, expires_at, created_at, updated_at, metadata
+ ) VALUES (
+ $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
+ )`
+
+ _, err = r.db.ExecContext(ctx, query,
+ session.ID,
+ session.UserID,
+ session.AppID,
+ session.SessionType,
+ session.Status,
+ session.AccessToken,
+ session.RefreshToken,
+ session.IDToken,
+ session.IPAddress,
+ session.UserAgent,
+ session.LastActivity,
+ session.ExpiresAt,
+ session.CreatedAt,
+ session.UpdatedAt,
+ metadataJSON,
+ )
+
+ if err != nil {
+ r.logger.Error("Failed to create session", zap.Error(err))
+ return errors.NewInternalError("Failed to create session").WithInternal(err)
+ }
+
+ r.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
+ return nil
+}
+
+// GetByID retrieves a session by its ID
+func (r *sessionRepository) GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
+ r.logger.Debug("Getting session by ID", zap.String("session_id", sessionID.String()))
+
+ query := `
+ SELECT id, user_id, app_id, session_type, status, access_token,
+ refresh_token, id_token, ip_address, user_agent,
+ last_activity, expires_at, created_at, updated_at,
+ revoked_at, revoked_by, metadata
+ FROM user_sessions
+ WHERE id = $1`
+
+ var session domain.UserSession
+ var metadataJSON []byte
+ var revokedAt sql.NullTime
+ var revokedBy sql.NullString
+
+ err := r.db.QueryRowContext(ctx, query, sessionID).Scan(
+ &session.ID,
+ &session.UserID,
+ &session.AppID,
+ &session.SessionType,
+ &session.Status,
+ &session.AccessToken,
+ &session.RefreshToken,
+ &session.IDToken,
+ &session.IPAddress,
+ &session.UserAgent,
+ &session.LastActivity,
+ &session.ExpiresAt,
+ &session.CreatedAt,
+ &session.UpdatedAt,
+ &revokedAt,
+ &revokedBy,
+ &metadataJSON,
+ )
+
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, errors.NewNotFoundError("Session not found")
+ }
+ r.logger.Error("Failed to get session by ID", zap.Error(err))
+ return nil, errors.NewInternalError("Failed to retrieve session").WithInternal(err)
+ }
+
+ // Handle nullable fields
+ if revokedAt.Valid {
+ session.RevokedAt = &revokedAt.Time
+ }
+ if revokedBy.Valid {
+ session.RevokedBy = &revokedBy.String
+ }
+
+ // Deserialize metadata
+ if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
+ r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
+ session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
+ }
+
+ r.logger.Debug("Session retrieved successfully", zap.String("session_id", sessionID.String()))
+ return &session, nil
+}
+
+// GetByUserID retrieves all sessions for a user
+func (r *sessionRepository) GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
+ r.logger.Debug("Getting sessions by user ID", zap.String("user_id", userID))
+
+ query := `
+ SELECT id, user_id, app_id, session_type, status, access_token,
+ refresh_token, id_token, ip_address, user_agent,
+ last_activity, expires_at, created_at, updated_at,
+ revoked_at, revoked_by, metadata
+ FROM user_sessions
+ WHERE user_id = $1
+ ORDER BY created_at DESC`
+
+ return r.scanSessions(ctx, query, userID)
+}
+
+// GetByUserAndApp retrieves sessions for a specific user and application
+func (r *sessionRepository) GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
+ r.logger.Debug("Getting sessions by user and app",
+ zap.String("user_id", userID),
+ zap.String("app_id", appID))
+
+ query := `
+ SELECT id, user_id, app_id, session_type, status, access_token,
+ refresh_token, id_token, ip_address, user_agent,
+ last_activity, expires_at, created_at, updated_at,
+ revoked_at, revoked_by, metadata
+ FROM user_sessions
+ WHERE user_id = $1 AND app_id = $2
+ ORDER BY created_at DESC`
+
+ return r.scanSessions(ctx, query, userID, appID)
+}
+
+// GetActiveByUserID retrieves all active sessions for a user
+func (r *sessionRepository) GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
+ r.logger.Debug("Getting active sessions by user ID", zap.String("user_id", userID))
+
+ query := `
+ SELECT id, user_id, app_id, session_type, status, access_token,
+ refresh_token, id_token, ip_address, user_agent,
+ last_activity, expires_at, created_at, updated_at,
+ revoked_at, revoked_by, metadata
+ FROM user_sessions
+ WHERE user_id = $1 AND status = $2 AND expires_at > NOW()
+ ORDER BY last_activity DESC`
+
+ return r.scanSessions(ctx, query, userID, domain.SessionStatusActive)
+}
+
+// List retrieves sessions with filtering and pagination
+func (r *sessionRepository) List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
+ r.logger.Debug("Listing sessions with filters",
+ zap.String("user_id", req.UserID),
+ zap.String("app_id", req.AppID),
+ zap.Int("limit", req.Limit),
+ zap.Int("offset", req.Offset))
+
+ // Build WHERE clause dynamically
+ whereClause := "WHERE 1=1"
+ args := []interface{}{}
+ argIndex := 1
+
+ if req.UserID != "" {
+ whereClause += fmt.Sprintf(" AND user_id = $%d", argIndex)
+ args = append(args, req.UserID)
+ argIndex++
+ }
+
+ if req.AppID != "" {
+ whereClause += fmt.Sprintf(" AND app_id = $%d", argIndex)
+ args = append(args, req.AppID)
+ argIndex++
+ }
+
+ if req.Status != nil {
+ whereClause += fmt.Sprintf(" AND status = $%d", argIndex)
+ args = append(args, *req.Status)
+ argIndex++
+ }
+
+ if req.SessionType != nil {
+ whereClause += fmt.Sprintf(" AND session_type = $%d", argIndex)
+ args = append(args, *req.SessionType)
+ argIndex++
+ }
+
+ if req.TenantID != "" {
+ whereClause += fmt.Sprintf(" AND metadata->>'tenant_id' = $%d", argIndex)
+ args = append(args, req.TenantID)
+ argIndex++
+ }
+
+ // Get total count
+ countQuery := fmt.Sprintf("SELECT COUNT(*) FROM user_sessions %s", whereClause)
+ var total int
+ err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
+ if err != nil {
+ r.logger.Error("Failed to get session count", zap.Error(err))
+ return nil, errors.NewInternalError("Failed to count sessions").WithInternal(err)
+ }
+
+ // Get sessions with pagination
+ query := fmt.Sprintf(`
+ SELECT id, user_id, app_id, session_type, status, access_token,
+ refresh_token, id_token, ip_address, user_agent,
+ last_activity, expires_at, created_at, updated_at,
+ revoked_at, revoked_by, metadata
+ FROM user_sessions
+ %s
+ ORDER BY created_at DESC
+ LIMIT $%d OFFSET $%d`, whereClause, argIndex, argIndex+1)
+
+ args = append(args, req.Limit, req.Offset)
+ sessions, err := r.scanSessions(ctx, query, args...)
+ if err != nil {
+ return nil, err
+ }
+
+ return &domain.SessionListResponse{
+ Sessions: sessions,
+ Total: total,
+ Limit: req.Limit,
+ Offset: req.Offset,
+ }, nil
+}
+
+// Update updates an existing session
+func (r *sessionRepository) Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
+ r.logger.Debug("Updating session", zap.String("session_id", sessionID.String()))
+
+ // Build UPDATE clause dynamically
+ setParts := []string{"updated_at = NOW()"}
+ args := []interface{}{}
+ argIndex := 1
+
+ if updates.Status != nil {
+ setParts = append(setParts, fmt.Sprintf("status = $%d", argIndex))
+ args = append(args, *updates.Status)
+ argIndex++
+ }
+
+ if updates.LastActivity != nil {
+ setParts = append(setParts, fmt.Sprintf("last_activity = $%d", argIndex))
+ args = append(args, *updates.LastActivity)
+ argIndex++
+ }
+
+ if updates.ExpiresAt != nil {
+ setParts = append(setParts, fmt.Sprintf("expires_at = $%d", argIndex))
+ args = append(args, *updates.ExpiresAt)
+ argIndex++
+ }
+
+ if updates.IPAddress != nil {
+ setParts = append(setParts, fmt.Sprintf("ip_address = $%d", argIndex))
+ args = append(args, *updates.IPAddress)
+ argIndex++
+ }
+
+ if updates.UserAgent != nil {
+ setParts = append(setParts, fmt.Sprintf("user_agent = $%d", argIndex))
+ args = append(args, *updates.UserAgent)
+ argIndex++
+ }
+
+ if len(setParts) == 1 {
+ return errors.NewValidationError("No fields to update")
+ }
+
+ // Build the complete query
+ setClause := fmt.Sprintf("%s", setParts[0])
+ for i := 1; i < len(setParts); i++ {
+ setClause += fmt.Sprintf(", %s", setParts[i])
+ }
+
+ query := fmt.Sprintf("UPDATE user_sessions SET %s WHERE id = $%d", setClause, argIndex)
+ args = append(args, sessionID)
+
+ result, err := r.db.ExecContext(ctx, query, args...)
+ if err != nil {
+ r.logger.Error("Failed to update session", zap.Error(err))
+ return errors.NewInternalError("Failed to update session").WithInternal(err)
+ }
+
+ rowsAffected, err := result.RowsAffected()
+ if err != nil {
+ return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
+ }
+
+ if rowsAffected == 0 {
+ return errors.NewNotFoundError("Session not found")
+ }
+
+ r.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String()))
+ return nil
+}
+
+// UpdateActivity updates the last activity timestamp for a session
+func (r *sessionRepository) UpdateActivity(ctx context.Context, sessionID uuid.UUID) error {
+ r.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))
+
+ query := `UPDATE user_sessions SET last_activity = NOW(), updated_at = NOW() WHERE id = $1`
+
+ result, err := r.db.ExecContext(ctx, query, sessionID)
+ if err != nil {
+ r.logger.Error("Failed to update session activity", zap.Error(err))
+ return errors.NewInternalError("Failed to update session activity").WithInternal(err)
+ }
+
+ rowsAffected, err := result.RowsAffected()
+ if err != nil {
+ return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
+ }
+
+ if rowsAffected == 0 {
+ return errors.NewNotFoundError("Session not found")
+ }
+
+ return nil
+}
+
+// Revoke revokes a session
+func (r *sessionRepository) Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
+ r.logger.Debug("Revoking session",
+ zap.String("session_id", sessionID.String()),
+ zap.String("revoked_by", revokedBy))
+
+ query := `
+ UPDATE user_sessions
+ SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
+ WHERE id = $3`
+
+ result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, sessionID)
+ if err != nil {
+ r.logger.Error("Failed to revoke session", zap.Error(err))
+ return errors.NewInternalError("Failed to revoke session").WithInternal(err)
+ }
+
+ rowsAffected, err := result.RowsAffected()
+ if err != nil {
+ return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
+ }
+
+ if rowsAffected == 0 {
+ return errors.NewNotFoundError("Session not found")
+ }
+
+ r.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String()))
+ return nil
+}
+
+// RevokeAllByUser revokes all sessions for a user
+func (r *sessionRepository) RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error {
+ r.logger.Debug("Revoking all sessions for user",
+ zap.String("user_id", userID),
+ zap.String("revoked_by", revokedBy))
+
+ query := `
+ UPDATE user_sessions
+ SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
+ WHERE user_id = $3 AND status = $4`
+
+ result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, domain.SessionStatusActive)
+ if err != nil {
+ r.logger.Error("Failed to revoke user sessions", zap.Error(err))
+ return errors.NewInternalError("Failed to revoke user sessions").WithInternal(err)
+ }
+
+ rowsAffected, err := result.RowsAffected()
+ if err != nil {
+ return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
+ }
+
+ r.logger.Debug("User sessions revoked",
+ zap.String("user_id", userID),
+ zap.Int64("sessions_revoked", rowsAffected))
+ return nil
+}
+
+// RevokeAllByUserAndApp revokes all sessions for a user and application
+func (r *sessionRepository) RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error {
+ r.logger.Debug("Revoking all sessions for user and app",
+ zap.String("user_id", userID),
+ zap.String("app_id", appID),
+ zap.String("revoked_by", revokedBy))
+
+ query := `
+ UPDATE user_sessions
+ SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
+ WHERE user_id = $3 AND app_id = $4 AND status = $5`
+
+ result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, appID, domain.SessionStatusActive)
+ if err != nil {
+ r.logger.Error("Failed to revoke user app sessions", zap.Error(err))
+ return errors.NewInternalError("Failed to revoke user app sessions").WithInternal(err)
+ }
+
+ rowsAffected, err := result.RowsAffected()
+ if err != nil {
+ return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
+ }
+
+ r.logger.Debug("User app sessions revoked",
+ zap.String("user_id", userID),
+ zap.String("app_id", appID),
+ zap.Int64("sessions_revoked", rowsAffected))
+ return nil
+}
+
+// ExpireOldSessions marks expired sessions as expired
+func (r *sessionRepository) ExpireOldSessions(ctx context.Context) (int, error) {
+ r.logger.Debug("Expiring old sessions")
+
+ query := `
+ UPDATE user_sessions
+ SET status = $1, updated_at = NOW()
+ WHERE expires_at < NOW() AND status = $2`
+
+ result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, domain.SessionStatusActive)
+ if err != nil {
+ r.logger.Error("Failed to expire old sessions", zap.Error(err))
+ return 0, errors.NewInternalError("Failed to expire old sessions").WithInternal(err)
+ }
+
+ rowsAffected, err := result.RowsAffected()
+ if err != nil {
+ return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
+ }
+
+ r.logger.Debug("Old sessions expired", zap.Int64("sessions_expired", rowsAffected))
+ return int(rowsAffected), nil
+}
+
+// DeleteExpiredSessions removes expired sessions older than the specified duration
+func (r *sessionRepository) DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error) {
+ r.logger.Debug("Deleting expired sessions", zap.Duration("older_than", olderThan))
+
+ cutoffTime := time.Now().Add(-olderThan)
+ query := `DELETE FROM user_sessions WHERE status = $1 AND updated_at < $2`
+
+ result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, cutoffTime)
+ if err != nil {
+ r.logger.Error("Failed to delete expired sessions", zap.Error(err))
+ return 0, errors.NewInternalError("Failed to delete expired sessions").WithInternal(err)
+ }
+
+ rowsAffected, err := result.RowsAffected()
+ if err != nil {
+ return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
+ }
+
+ r.logger.Debug("Expired sessions deleted", zap.Int64("sessions_deleted", rowsAffected))
+ return int(rowsAffected), nil
+}
+
+// Exists checks if a session exists
+func (r *sessionRepository) Exists(ctx context.Context, sessionID uuid.UUID) (bool, error) {
+ r.logger.Debug("Checking if session exists", zap.String("session_id", sessionID.String()))
+
+ query := `SELECT EXISTS(SELECT 1 FROM user_sessions WHERE id = $1)`
+
+ var exists bool
+ err := r.db.QueryRowContext(ctx, query, sessionID).Scan(&exists)
+ if err != nil {
+ r.logger.Error("Failed to check session existence", zap.Error(err))
+ return false, errors.NewInternalError("Failed to check session existence").WithInternal(err)
+ }
+
+ return exists, nil
+}
+
+// GetSessionCount returns the total number of sessions for a user
+func (r *sessionRepository) GetSessionCount(ctx context.Context, userID string) (int, error) {
+ r.logger.Debug("Getting session count for user", zap.String("user_id", userID))
+
+ query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1`
+
+ var count int
+ err := r.db.QueryRowContext(ctx, query, userID).Scan(&count)
+ if err != nil {
+ r.logger.Error("Failed to get session count", zap.Error(err))
+ return 0, errors.NewInternalError("Failed to get session count").WithInternal(err)
+ }
+
+ return count, nil
+}
+
+// GetActiveSessionCount returns the number of active sessions for a user
+func (r *sessionRepository) GetActiveSessionCount(ctx context.Context, userID string) (int, error) {
+ r.logger.Debug("Getting active session count for user", zap.String("user_id", userID))
+
+ query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1 AND status = $2 AND expires_at > NOW()`
+
+ var count int
+ err := r.db.QueryRowContext(ctx, query, userID, domain.SessionStatusActive).Scan(&count)
+ if err != nil {
+ r.logger.Error("Failed to get active session count", zap.Error(err))
+ return 0, errors.NewInternalError("Failed to get active session count").WithInternal(err)
+ }
+
+ return count, nil
+}
+
+// scanSessions is a helper method to scan multiple sessions from query results
+func (r *sessionRepository) scanSessions(ctx context.Context, query string, args ...interface{}) ([]*domain.UserSession, error) {
+ rows, err := r.db.QueryContext(ctx, query, args...)
+ if err != nil {
+ r.logger.Error("Failed to execute session query", zap.Error(err))
+ return nil, errors.NewInternalError("Failed to retrieve sessions").WithInternal(err)
+ }
+ defer rows.Close()
+
+ var sessions []*domain.UserSession
+ for rows.Next() {
+ var session domain.UserSession
+ var metadataJSON []byte
+ var revokedAt sql.NullTime
+ var revokedBy sql.NullString
+
+ err := rows.Scan(
+ &session.ID,
+ &session.UserID,
+ &session.AppID,
+ &session.SessionType,
+ &session.Status,
+ &session.AccessToken,
+ &session.RefreshToken,
+ &session.IDToken,
+ &session.IPAddress,
+ &session.UserAgent,
+ &session.LastActivity,
+ &session.ExpiresAt,
+ &session.CreatedAt,
+ &session.UpdatedAt,
+ &revokedAt,
+ &revokedBy,
+ &metadataJSON,
+ )
+
+ if err != nil {
+ r.logger.Error("Failed to scan session row", zap.Error(err))
+ return nil, errors.NewInternalError("Failed to scan session data").WithInternal(err)
+ }
+
+ // Handle nullable fields
+ if revokedAt.Valid {
+ session.RevokedAt = &revokedAt.Time
+ }
+ if revokedBy.Valid {
+ session.RevokedBy = &revokedBy.String
+ }
+
+ // Deserialize metadata
+ if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
+ r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
+ session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
+ }
+
+ sessions = append(sessions, &session)
+ }
+
+ if err := rows.Err(); err != nil {
+ r.logger.Error("Error iterating session rows", zap.Error(err))
+ return nil, errors.NewInternalError("Failed to iterate session results").WithInternal(err)
+ }
+
+ return sessions, nil
+}
diff --git a/internal/services/interfaces.go b/internal/services/interfaces.go
index 598cd45..c9a22b2 100644
--- a/internal/services/interfaces.go
+++ b/internal/services/interfaces.go
@@ -67,3 +67,54 @@ type AuthenticationService interface {
// RefreshJWTToken refreshes an existing JWT token
RefreshJWTToken(ctx context.Context, tokenString string, newExpiration time.Time) (string, error)
}
+
+// SessionService defines the interface for session management business logic
+type SessionService interface {
+ // CreateSession creates a new user session
+ CreateSession(ctx context.Context, req *domain.CreateSessionRequest) (*domain.UserSession, error)
+
+ // GetSession retrieves a session by its ID
+ GetSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error)
+
+ // GetUserSessions retrieves all sessions for a user
+ GetUserSessions(ctx context.Context, userID string) ([]*domain.UserSession, error)
+
+ // GetUserAppSessions retrieves sessions for a specific user and application
+ GetUserAppSessions(ctx context.Context, userID, appID string) ([]*domain.UserSession, error)
+
+ // GetActiveSessions retrieves all active sessions for a user
+ GetActiveSessions(ctx context.Context, userID string) ([]*domain.UserSession, error)
+
+ // ListSessions retrieves sessions with filtering and pagination
+ ListSessions(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error)
+
+ // UpdateSession updates an existing session
+ UpdateSession(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error
+
+ // UpdateSessionActivity updates the last activity timestamp for a session
+ UpdateSessionActivity(ctx context.Context, sessionID uuid.UUID) error
+
+ // RevokeSession revokes a specific session
+ RevokeSession(ctx context.Context, sessionID uuid.UUID, revokedBy string) error
+
+ // RevokeUserSessions revokes all sessions for a user
+ RevokeUserSessions(ctx context.Context, userID string, revokedBy string) error
+
+ // RevokeUserAppSessions revokes all sessions for a user and application
+ RevokeUserAppSessions(ctx context.Context, userID, appID string, revokedBy string) error
+
+ // ValidateSession validates if a session is active and valid
+ ValidateSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error)
+
+ // RefreshSession refreshes a session's expiration time
+ RefreshSession(ctx context.Context, sessionID uuid.UUID, newExpiration time.Time) error
+
+ // CleanupExpiredSessions marks expired sessions as expired and optionally deletes old ones
+ CleanupExpiredSessions(ctx context.Context, deleteOlderThan *time.Duration) (expired int, deleted int, err error)
+
+ // GetSessionStats returns session statistics for a user
+ GetSessionStats(ctx context.Context, userID string) (total int, active int, err error)
+
+ // CreateOAuth2Session creates a session from OAuth2 authentication flow
+ CreateOAuth2Session(ctx context.Context, userID, appID string, tokenResponse *domain.TokenResponse, userInfo *domain.UserInfo, sessionType domain.SessionType, ipAddress, userAgent string) (*domain.UserSession, error)
+}
diff --git a/internal/services/session_service.go b/internal/services/session_service.go
new file mode 100644
index 0000000..2f5215c
--- /dev/null
+++ b/internal/services/session_service.go
@@ -0,0 +1,414 @@
+package services
+
+import (
+ "context"
+ "time"
+
+ "github.com/google/uuid"
+ "go.uber.org/zap"
+
+ "github.com/kms/api-key-service/internal/config"
+ "github.com/kms/api-key-service/internal/domain"
+ "github.com/kms/api-key-service/internal/errors"
+ "github.com/kms/api-key-service/internal/repository"
+)
+
+// sessionService implements the SessionService interface
+type sessionService struct {
+ sessionRepo repository.SessionRepository
+ appRepo repository.ApplicationRepository
+ config config.ConfigProvider
+ logger *zap.Logger
+}
+
+// NewSessionService creates a new session service
+func NewSessionService(
+ sessionRepo repository.SessionRepository,
+ appRepo repository.ApplicationRepository,
+ config config.ConfigProvider,
+ logger *zap.Logger,
+) SessionService {
+ return &sessionService{
+ sessionRepo: sessionRepo,
+ appRepo: appRepo,
+ config: config,
+ logger: logger,
+ }
+}
+
+// CreateSession creates a new user session
+func (s *sessionService) CreateSession(ctx context.Context, req *domain.CreateSessionRequest) (*domain.UserSession, error) {
+ s.logger.Debug("Creating new session",
+ zap.String("user_id", req.UserID),
+ zap.String("app_id", req.AppID),
+ zap.String("session_type", string(req.SessionType)))
+
+ // Validate application exists
+ app, err := s.appRepo.GetByID(ctx, req.AppID)
+ if err != nil {
+ if errors.IsNotFound(err) {
+ return nil, errors.NewValidationError("Application not found")
+ }
+ return nil, err
+ }
+
+ // Check if application supports user tokens
+ supportsUser := false
+ for _, appType := range app.Type {
+ if appType == domain.ApplicationTypeUser {
+ supportsUser = true
+ break
+ }
+ }
+ if !supportsUser {
+ return nil, errors.NewValidationError("Application does not support user sessions")
+ }
+
+ // Create session object
+ session := &domain.UserSession{
+ ID: uuid.New(),
+ UserID: req.UserID,
+ AppID: req.AppID,
+ SessionType: req.SessionType,
+ Status: domain.SessionStatusActive,
+ IPAddress: req.IPAddress,
+ UserAgent: req.UserAgent,
+ ExpiresAt: req.ExpiresAt,
+ Metadata: domain.SessionMetadata{
+ TenantID: req.TenantID,
+ Permissions: req.Permissions,
+ Claims: req.Claims,
+ LoginMethod: "oauth2",
+ },
+ }
+
+ // Create session in repository
+ if err := s.sessionRepo.Create(ctx, session); err != nil {
+ s.logger.Error("Failed to create session", zap.Error(err))
+ return nil, err
+ }
+
+ s.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
+ return session, nil
+}
+
+// GetSession retrieves a session by its ID
+func (s *sessionService) GetSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
+ s.logger.Debug("Getting session", zap.String("session_id", sessionID.String()))
+
+ session, err := s.sessionRepo.GetByID(ctx, sessionID)
+ if err != nil {
+ return nil, err
+ }
+
+ return session, nil
+}
+
+// GetUserSessions retrieves all sessions for a user
+func (s *sessionService) GetUserSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) {
+ s.logger.Debug("Getting user sessions", zap.String("user_id", userID))
+
+ sessions, err := s.sessionRepo.GetByUserID(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ return sessions, nil
+}
+
+// GetUserAppSessions retrieves sessions for a specific user and application
+func (s *sessionService) GetUserAppSessions(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
+ s.logger.Debug("Getting user app sessions",
+ zap.String("user_id", userID),
+ zap.String("app_id", appID))
+
+ sessions, err := s.sessionRepo.GetByUserAndApp(ctx, userID, appID)
+ if err != nil {
+ return nil, err
+ }
+
+ return sessions, nil
+}
+
+// GetActiveSessions retrieves all active sessions for a user
+func (s *sessionService) GetActiveSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) {
+ s.logger.Debug("Getting active sessions", zap.String("user_id", userID))
+
+ sessions, err := s.sessionRepo.GetActiveByUserID(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ return sessions, nil
+}
+
+// ListSessions retrieves sessions with filtering and pagination
+func (s *sessionService) ListSessions(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
+ s.logger.Debug("Listing sessions",
+ zap.String("user_id", req.UserID),
+ zap.String("app_id", req.AppID),
+ zap.Int("limit", req.Limit),
+ zap.Int("offset", req.Offset))
+
+ // Set default pagination if not provided
+ if req.Limit <= 0 {
+ req.Limit = 50
+ }
+ if req.Limit > 100 {
+ req.Limit = 100
+ }
+
+ response, err := s.sessionRepo.List(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+
+ return response, nil
+}
+
+// UpdateSession updates an existing session
+func (s *sessionService) UpdateSession(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
+ s.logger.Debug("Updating session", zap.String("session_id", sessionID.String()))
+
+ // Validate session exists
+ _, err := s.sessionRepo.GetByID(ctx, sessionID)
+ if err != nil {
+ return err
+ }
+
+ // Update session
+ if err := s.sessionRepo.Update(ctx, sessionID, updates); err != nil {
+ return err
+ }
+
+ s.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String()))
+ return nil
+}
+
+// UpdateSessionActivity updates the last activity timestamp for a session
+func (s *sessionService) UpdateSessionActivity(ctx context.Context, sessionID uuid.UUID) error {
+ s.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))
+
+ if err := s.sessionRepo.UpdateActivity(ctx, sessionID); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// RevokeSession revokes a specific session
+func (s *sessionService) RevokeSession(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
+ s.logger.Debug("Revoking session",
+ zap.String("session_id", sessionID.String()),
+ zap.String("revoked_by", revokedBy))
+
+ // Validate session exists and is active
+ session, err := s.sessionRepo.GetByID(ctx, sessionID)
+ if err != nil {
+ return err
+ }
+
+ if session.Status != domain.SessionStatusActive {
+ return errors.NewValidationError("Session is not active")
+ }
+
+ // Revoke session
+ if err := s.sessionRepo.Revoke(ctx, sessionID, revokedBy); err != nil {
+ return err
+ }
+
+ s.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String()))
+ return nil
+}
+
+// RevokeUserSessions revokes all sessions for a user
+func (s *sessionService) RevokeUserSessions(ctx context.Context, userID string, revokedBy string) error {
+ s.logger.Debug("Revoking user sessions",
+ zap.String("user_id", userID),
+ zap.String("revoked_by", revokedBy))
+
+ if err := s.sessionRepo.RevokeAllByUser(ctx, userID, revokedBy); err != nil {
+ return err
+ }
+
+ s.logger.Debug("User sessions revoked successfully", zap.String("user_id", userID))
+ return nil
+}
+
+// RevokeUserAppSessions revokes all sessions for a user and application
+func (s *sessionService) RevokeUserAppSessions(ctx context.Context, userID, appID string, revokedBy string) error {
+ s.logger.Debug("Revoking user app sessions",
+ zap.String("user_id", userID),
+ zap.String("app_id", appID),
+ zap.String("revoked_by", revokedBy))
+
+ if err := s.sessionRepo.RevokeAllByUserAndApp(ctx, userID, appID, revokedBy); err != nil {
+ return err
+ }
+
+ s.logger.Debug("User app sessions revoked successfully",
+ zap.String("user_id", userID),
+ zap.String("app_id", appID))
+ return nil
+}
+
+// ValidateSession validates if a session is active and valid
+func (s *sessionService) ValidateSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
+ s.logger.Debug("Validating session", zap.String("session_id", sessionID.String()))
+
+ session, err := s.sessionRepo.GetByID(ctx, sessionID)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check if session is active
+ if !session.IsActive() {
+ if session.IsExpired() {
+ return nil, errors.NewAuthenticationError("Session has expired")
+ }
+ if session.IsRevoked() {
+ return nil, errors.NewAuthenticationError("Session has been revoked")
+ }
+ return nil, errors.NewAuthenticationError("Session is not active")
+ }
+
+ // Update last activity
+ if err := s.sessionRepo.UpdateActivity(ctx, sessionID); err != nil {
+ s.logger.Warn("Failed to update session activity", zap.Error(err))
+ // Don't fail validation if we can't update activity
+ }
+
+ s.logger.Debug("Session validated successfully", zap.String("session_id", sessionID.String()))
+ return session, nil
+}
+
+// RefreshSession refreshes a session's expiration time
+func (s *sessionService) RefreshSession(ctx context.Context, sessionID uuid.UUID, newExpiration time.Time) error {
+ s.logger.Debug("Refreshing session",
+ zap.String("session_id", sessionID.String()),
+ zap.Time("new_expiration", newExpiration))
+
+ // Validate session exists and is active
+ session, err := s.sessionRepo.GetByID(ctx, sessionID)
+ if err != nil {
+ return err
+ }
+
+ if !session.IsActive() {
+ return errors.NewValidationError("Cannot refresh inactive session")
+ }
+
+ // Update expiration
+ updates := &domain.UpdateSessionRequest{
+ ExpiresAt: &newExpiration,
+ }
+
+ if err := s.sessionRepo.Update(ctx, sessionID, updates); err != nil {
+ return err
+ }
+
+ s.logger.Debug("Session refreshed successfully", zap.String("session_id", sessionID.String()))
+ return nil
+}
+
+// CleanupExpiredSessions marks expired sessions as expired and optionally deletes old ones
+func (s *sessionService) CleanupExpiredSessions(ctx context.Context, deleteOlderThan *time.Duration) (expired int, deleted int, err error) {
+ s.logger.Debug("Cleaning up expired sessions")
+
+ // Mark expired sessions
+ expired, err = s.sessionRepo.ExpireOldSessions(ctx)
+ if err != nil {
+ s.logger.Error("Failed to expire old sessions", zap.Error(err))
+ return 0, 0, err
+ }
+
+ // Delete old expired sessions if requested
+ if deleteOlderThan != nil {
+ deleted, err = s.sessionRepo.DeleteExpiredSessions(ctx, *deleteOlderThan)
+ if err != nil {
+ s.logger.Error("Failed to delete expired sessions", zap.Error(err))
+ return expired, 0, err
+ }
+ }
+
+ s.logger.Debug("Session cleanup completed",
+ zap.Int("expired", expired),
+ zap.Int("deleted", deleted))
+
+ return expired, deleted, nil
+}
+
+// GetSessionStats returns session statistics for a user
+func (s *sessionService) GetSessionStats(ctx context.Context, userID string) (total int, active int, err error) {
+ s.logger.Debug("Getting session stats", zap.String("user_id", userID))
+
+ total, err = s.sessionRepo.GetSessionCount(ctx, userID)
+ if err != nil {
+ return 0, 0, err
+ }
+
+ active, err = s.sessionRepo.GetActiveSessionCount(ctx, userID)
+ if err != nil {
+ return 0, 0, err
+ }
+
+ return total, active, nil
+}
+
+// CreateOAuth2Session creates a session from OAuth2 authentication flow
+func (s *sessionService) CreateOAuth2Session(ctx context.Context, userID, appID string, tokenResponse *domain.TokenResponse, userInfo *domain.UserInfo, sessionType domain.SessionType, ipAddress, userAgent string) (*domain.UserSession, error) {
+ s.logger.Debug("Creating OAuth2 session",
+ zap.String("user_id", userID),
+ zap.String("app_id", appID),
+ zap.String("session_type", string(sessionType)))
+
+ // Validate application exists
+ app, err := s.appRepo.GetByID(ctx, appID)
+ if err != nil {
+ if errors.IsNotFound(err) {
+ return nil, errors.NewValidationError("Application not found")
+ }
+ return nil, err
+ }
+
+ // Calculate expiration based on token response
+ expiresAt := time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second)
+
+ // Use application's max token duration if shorter
+ maxExpiration := time.Now().Add(app.MaxTokenDuration)
+ if expiresAt.After(maxExpiration) {
+ expiresAt = maxExpiration
+ }
+
+ // Create session object
+ session := &domain.UserSession{
+ ID: uuid.New(),
+ UserID: userID,
+ AppID: appID,
+ SessionType: sessionType,
+ Status: domain.SessionStatusActive,
+ AccessToken: tokenResponse.AccessToken, // In production, encrypt this
+ RefreshToken: tokenResponse.RefreshToken, // In production, encrypt this
+ IDToken: tokenResponse.IDToken, // In production, encrypt this
+ IPAddress: ipAddress,
+ UserAgent: userAgent,
+ ExpiresAt: expiresAt,
+ Metadata: domain.SessionMetadata{
+ LoginMethod: "oauth2",
+ Claims: map[string]string{
+ "sub": userInfo.Sub,
+ "email": userInfo.Email,
+ "name": userInfo.Name,
+ },
+ },
+ }
+
+ // Create session in repository
+ if err := s.sessionRepo.Create(ctx, session); err != nil {
+ s.logger.Error("Failed to create OAuth2 session", zap.Error(err))
+ return nil, err
+ }
+
+ s.logger.Debug("OAuth2 session created successfully", zap.String("session_id", session.ID.String()))
+ return session, nil
+}
diff --git a/migrations/002_user_sessions.down.sql b/migrations/002_user_sessions.down.sql
new file mode 100644
index 0000000..df3d4f3
--- /dev/null
+++ b/migrations/002_user_sessions.down.sql
@@ -0,0 +1,14 @@
+-- Drop user_sessions table and related objects
+DROP INDEX IF EXISTS idx_user_sessions_tenant;
+DROP INDEX IF EXISTS idx_user_sessions_metadata;
+DROP INDEX IF EXISTS idx_user_sessions_active_expires;
+DROP INDEX IF EXISTS idx_user_sessions_active;
+DROP INDEX IF EXISTS idx_user_sessions_created_at;
+DROP INDEX IF EXISTS idx_user_sessions_last_activity;
+DROP INDEX IF EXISTS idx_user_sessions_expires_at;
+DROP INDEX IF EXISTS idx_user_sessions_status;
+DROP INDEX IF EXISTS idx_user_sessions_user_app;
+DROP INDEX IF EXISTS idx_user_sessions_app_id;
+DROP INDEX IF EXISTS idx_user_sessions_user_id;
+
+DROP TABLE IF EXISTS user_sessions;
diff --git a/migrations/002_user_sessions.up.sql b/migrations/002_user_sessions.up.sql
new file mode 100644
index 0000000..5a73a6c
--- /dev/null
+++ b/migrations/002_user_sessions.up.sql
@@ -0,0 +1,60 @@
+-- Create user_sessions table for session management
+CREATE TABLE IF NOT EXISTS user_sessions (
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+ user_id VARCHAR(255) NOT NULL,
+ app_id VARCHAR(255) NOT NULL,
+ session_type VARCHAR(20) NOT NULL CHECK (session_type IN ('web', 'mobile', 'api')),
+ status VARCHAR(20) NOT NULL DEFAULT 'active' CHECK (status IN ('active', 'expired', 'revoked', 'suspended')),
+ access_token TEXT,
+ refresh_token TEXT,
+ id_token TEXT,
+ ip_address INET,
+ user_agent TEXT,
+ last_activity TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
+ expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
+ created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
+ revoked_at TIMESTAMP WITH TIME ZONE,
+ revoked_by VARCHAR(255),
+ metadata JSONB DEFAULT '{}',
+
+ -- Foreign key constraint to applications table
+ CONSTRAINT fk_user_sessions_app_id FOREIGN KEY (app_id) REFERENCES applications(app_id) ON DELETE CASCADE
+);
+
+-- Create indexes for performance
+CREATE INDEX IF NOT EXISTS idx_user_sessions_user_id ON user_sessions(user_id);
+CREATE INDEX IF NOT EXISTS idx_user_sessions_app_id ON user_sessions(app_id);
+CREATE INDEX IF NOT EXISTS idx_user_sessions_user_app ON user_sessions(user_id, app_id);
+CREATE INDEX IF NOT EXISTS idx_user_sessions_status ON user_sessions(status);
+CREATE INDEX IF NOT EXISTS idx_user_sessions_expires_at ON user_sessions(expires_at);
+CREATE INDEX IF NOT EXISTS idx_user_sessions_last_activity ON user_sessions(last_activity);
+CREATE INDEX IF NOT EXISTS idx_user_sessions_created_at ON user_sessions(created_at);
+
+-- Create partial indexes for active sessions (most common queries)
+CREATE INDEX IF NOT EXISTS idx_user_sessions_active ON user_sessions(user_id, app_id) WHERE status = 'active';
+CREATE INDEX IF NOT EXISTS idx_user_sessions_active_expires ON user_sessions(user_id, expires_at) WHERE status = 'active';
+
+-- Create GIN index for metadata JSONB queries
+CREATE INDEX IF NOT EXISTS idx_user_sessions_metadata ON user_sessions USING GIN (metadata);
+
+-- Create index for tenant-based queries (if using multi-tenancy)
+CREATE INDEX IF NOT EXISTS idx_user_sessions_tenant ON user_sessions((metadata->>'tenant_id')) WHERE metadata->>'tenant_id' IS NOT NULL;
+
+-- Add comments for documentation
+COMMENT ON TABLE user_sessions IS 'Stores user session information for authentication and session management';
+COMMENT ON COLUMN user_sessions.id IS 'Unique session identifier';
+COMMENT ON COLUMN user_sessions.user_id IS 'User identifier (email, username, or external ID)';
+COMMENT ON COLUMN user_sessions.app_id IS 'Application identifier this session belongs to';
+COMMENT ON COLUMN user_sessions.session_type IS 'Type of session: web, mobile, or api';
+COMMENT ON COLUMN user_sessions.status IS 'Current session status: active, expired, revoked, or suspended';
+COMMENT ON COLUMN user_sessions.access_token IS 'OAuth2/OIDC access token (encrypted/hashed)';
+COMMENT ON COLUMN user_sessions.refresh_token IS 'OAuth2/OIDC refresh token (encrypted/hashed)';
+COMMENT ON COLUMN user_sessions.id_token IS 'OIDC ID token (encrypted/hashed)';
+COMMENT ON COLUMN user_sessions.ip_address IS 'IP address of the client when session was created';
+COMMENT ON COLUMN user_sessions.user_agent IS 'User agent string of the client';
+COMMENT ON COLUMN user_sessions.last_activity IS 'Timestamp of last session activity';
+COMMENT ON COLUMN user_sessions.expires_at IS 'When the session expires';
+COMMENT ON COLUMN user_sessions.revoked_at IS 'When the session was revoked (if applicable)';
+COMMENT ON COLUMN user_sessions.revoked_by IS 'Who revoked the session (user ID or system)';
+COMMENT ON COLUMN user_sessions.metadata IS 'Additional session metadata (device info, location, etc.)';
diff --git a/test/saml_test.go b/test/saml_test.go
new file mode 100644
index 0000000..78cc304
--- /dev/null
+++ b/test/saml_test.go
@@ -0,0 +1,532 @@
+package test
+
+import (
+ "context"
+ "encoding/base64"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "regexp"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.uber.org/zap/zaptest"
+
+ "github.com/kms/api-key-service/internal/auth"
+ "github.com/kms/api-key-service/internal/domain"
+)
+
+// mockSAMLMetadata returns a mock SAML IdP metadata XML
+func mockSAMLMetadata() string {
+ return `
+
+
+
+
+
+ MIICertificateData
+
+
+
+
+
+
+`
+}
+
+// mockSAMLResponse returns a mock SAML response XML with current timestamps
+func mockSAMLResponse() string {
+ now := time.Now().UTC()
+ issueInstant := now.Format(time.RFC3339)
+ notBefore := now.Add(-5 * time.Minute).Format(time.RFC3339)
+ notOnOrAfter := now.Add(60 * time.Minute).Format(time.RFC3339)
+
+ return fmt.Sprintf(`
+
+ https://idp.example.com
+
+
+
+
+ https://idp.example.com
+
+ user@example.com
+
+
+
+
+
+
+ https://sp.example.com
+
+
+
+
+ user@example.com
+
+
+ Test User
+
+
+ Test
+
+
+ User
+
+
+ admin,user
+
+
+
+
+ urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport
+
+
+
+`, issueInstant, issueInstant, notOnOrAfter, notBefore, notOnOrAfter, issueInstant)
+}
+
+func TestSAMLProvider_GetMetadata(t *testing.T) {
+ tests := []struct {
+ name string
+ metadataURL string
+ serverResponse string
+ serverStatus int
+ expectError bool
+ errorContains string
+ }{
+ {
+ name: "successful metadata fetch",
+ metadataURL: "https://idp.example.com/.well-known/saml-metadata",
+ serverResponse: mockSAMLMetadata(),
+ serverStatus: http.StatusOK,
+ expectError: false,
+ },
+ {
+ name: "missing metadata URL",
+ metadataURL: "",
+ expectError: true,
+ errorContains: "SAML_IDP_METADATA_URL not configured",
+ },
+ {
+ name: "server error",
+ metadataURL: "https://idp.example.com/.well-known/saml-metadata",
+ serverStatus: http.StatusInternalServerError,
+ expectError: true,
+ errorContains: "returned status 500",
+ },
+ {
+ name: "invalid XML",
+ metadataURL: "https://idp.example.com/.well-known/saml-metadata",
+ serverResponse: "invalid xml",
+ serverStatus: http.StatusOK,
+ expectError: true,
+ errorContains: "Failed to parse SAML metadata",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create mock HTTP server
+ var server *httptest.Server
+ if tt.metadataURL != "" && tt.serverStatus > 0 {
+ server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(tt.serverStatus)
+ if tt.serverResponse != "" {
+ w.Write([]byte(tt.serverResponse))
+ }
+ }))
+ defer server.Close()
+ tt.metadataURL = server.URL
+ }
+
+ // Create config
+ cfg := NewTestConfig()
+ cfg.values["SAML_IDP_METADATA_URL"] = tt.metadataURL
+
+ // Create SAML provider
+ logger := zaptest.NewLogger(t)
+ provider, err := auth.NewSAMLProvider(cfg, logger)
+ require.NoError(t, err)
+
+ // Test GetMetadata
+ ctx := context.Background()
+ metadata, err := provider.GetMetadata(ctx)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ if tt.errorContains != "" {
+ assert.Contains(t, err.Error(), tt.errorContains)
+ }
+ assert.Nil(t, metadata)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, metadata)
+ assert.Equal(t, "https://idp.example.com", metadata.EntityID)
+ assert.NotEmpty(t, metadata.IDPSSODescriptor.SingleSignOnService)
+ }
+ })
+ }
+}
+
+func TestSAMLProvider_GenerateAuthRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ spEntityID string
+ acsURL string
+ relayState string
+ expectError bool
+ errorContains string
+ }{
+ {
+ name: "successful auth request generation",
+ spEntityID: "https://sp.example.com",
+ acsURL: "https://sp.example.com/acs",
+ relayState: "test-relay-state",
+ },
+ {
+ name: "missing SP entity ID",
+ spEntityID: "",
+ acsURL: "https://sp.example.com/acs",
+ expectError: true,
+ errorContains: "SAML_SP_ENTITY_ID not configured",
+ },
+ {
+ name: "missing ACS URL",
+ spEntityID: "https://sp.example.com",
+ acsURL: "",
+ expectError: true,
+ errorContains: "SAML_SP_ACS_URL not configured",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create mock HTTP server for metadata
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(mockSAMLMetadata()))
+ }))
+ defer server.Close()
+
+ // Create config
+ cfg := NewTestConfig()
+ cfg.values["SAML_IDP_METADATA_URL"] = server.URL
+ cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
+ cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
+
+ // Create SAML provider
+ logger := zaptest.NewLogger(t)
+ provider, err := auth.NewSAMLProvider(cfg, logger)
+ require.NoError(t, err)
+
+ // Test GenerateAuthRequest
+ ctx := context.Background()
+ authURL, requestID, err := provider.GenerateAuthRequest(ctx, tt.relayState)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ if tt.errorContains != "" {
+ assert.Contains(t, err.Error(), tt.errorContains)
+ }
+ assert.Empty(t, authURL)
+ assert.Empty(t, requestID)
+ } else {
+ assert.NoError(t, err)
+ assert.NotEmpty(t, authURL)
+ assert.NotEmpty(t, requestID)
+ assert.Contains(t, authURL, "https://idp.example.com/sso")
+ assert.Contains(t, authURL, "SAMLRequest=")
+ if tt.relayState != "" {
+ assert.Contains(t, authURL, "RelayState="+tt.relayState)
+ }
+ }
+ })
+ }
+}
+
+func TestSAMLProvider_ProcessSAMLResponse(t *testing.T) {
+ tests := []struct {
+ name string
+ samlResponse string
+ expectedRequestID string
+ spEntityID string
+ expectError bool
+ errorContains string
+ expectedUserID string
+ expectedEmail string
+ expectedName string
+ expectedRoles []string
+ }{
+ {
+ name: "successful SAML response processing",
+ samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
+ expectedRequestID: "_request_id",
+ spEntityID: "https://sp.example.com",
+ expectedUserID: "user@example.com",
+ expectedEmail: "user@example.com",
+ expectedName: "Test User",
+ expectedRoles: []string{"admin", "user"},
+ },
+ {
+ name: "invalid base64 encoding",
+ samlResponse: "invalid-base64",
+ expectError: true,
+ errorContains: "Failed to decode SAML response",
+ },
+ {
+ name: "invalid XML",
+ samlResponse: base64.StdEncoding.EncodeToString([]byte("invalid xml")),
+ expectError: true,
+ errorContains: "Failed to parse SAML response",
+ },
+ {
+ name: "audience mismatch",
+ samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
+ spEntityID: "https://wrong-sp.example.com",
+ expectError: true,
+ errorContains: "audience mismatch",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create config
+ cfg := NewTestConfig()
+ cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
+
+ // Create SAML provider
+ logger := zaptest.NewLogger(t)
+ provider, err := auth.NewSAMLProvider(cfg, logger)
+ require.NoError(t, err)
+
+ // Test ProcessSAMLResponse
+ ctx := context.Background()
+ authContext, err := provider.ProcessSAMLResponse(ctx, tt.samlResponse, tt.expectedRequestID)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ if tt.errorContains != "" {
+ assert.Contains(t, err.Error(), tt.errorContains)
+ }
+ assert.Nil(t, authContext)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, authContext)
+ assert.Equal(t, tt.expectedUserID, authContext.UserID)
+ assert.Equal(t, domain.TokenTypeUser, authContext.TokenType)
+
+ // Check claims
+ if tt.expectedEmail != "" {
+ assert.Equal(t, tt.expectedEmail, authContext.Claims["email"])
+ }
+ if tt.expectedName != "" {
+ assert.Equal(t, tt.expectedName, authContext.Claims["name"])
+ }
+
+ // Check permissions/roles
+ if len(tt.expectedRoles) > 0 {
+ assert.Equal(t, tt.expectedRoles, authContext.Permissions)
+ }
+ }
+ })
+ }
+}
+
+func TestSAMLProvider_GenerateServiceProviderMetadata(t *testing.T) {
+ tests := []struct {
+ name string
+ spEntityID string
+ acsURL string
+ expectError bool
+ errorContains string
+ }{
+ {
+ name: "successful SP metadata generation",
+ spEntityID: "https://sp.example.com",
+ acsURL: "https://sp.example.com/acs",
+ },
+ {
+ name: "missing SP entity ID",
+ spEntityID: "",
+ acsURL: "https://sp.example.com/acs",
+ expectError: true,
+ errorContains: "SAML_SP_ENTITY_ID not configured",
+ },
+ {
+ name: "missing ACS URL",
+ spEntityID: "https://sp.example.com",
+ acsURL: "",
+ expectError: true,
+ errorContains: "SAML_SP_ACS_URL not configured",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create config
+ cfg := NewTestConfig()
+ cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
+ cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
+
+ // Create SAML provider
+ logger := zaptest.NewLogger(t)
+ provider, err := auth.NewSAMLProvider(cfg, logger)
+ require.NoError(t, err)
+
+ // Test GenerateServiceProviderMetadata
+ metadata, err := provider.GenerateServiceProviderMetadata()
+
+ if tt.expectError {
+ assert.Error(t, err)
+ if tt.errorContains != "" {
+ assert.Contains(t, err.Error(), tt.errorContains)
+ }
+ assert.Empty(t, metadata)
+ } else {
+ assert.NoError(t, err)
+ assert.NotEmpty(t, metadata)
+ assert.Contains(t, metadata, tt.spEntityID)
+ assert.Contains(t, metadata, tt.acsURL)
+ assert.Contains(t, metadata, "EntityDescriptor")
+ assert.Contains(t, metadata, "SPSSODescriptor")
+ }
+ })
+ }
+}
+
+// Benchmark tests for SAML operations
+func BenchmarkSAMLProvider_ProcessSAMLResponse(b *testing.B) {
+ // Create config
+ cfg := NewTestConfig()
+ cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
+
+ // Create SAML provider
+ logger := zaptest.NewLogger(b)
+ provider, err := auth.NewSAMLProvider(cfg, logger)
+ require.NoError(b, err)
+
+ // Prepare SAML response
+ samlResponse := base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse()))
+ ctx := context.Background()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkSAMLProvider_GenerateAuthRequest(b *testing.B) {
+ // Create mock HTTP server for metadata
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(mockSAMLMetadata()))
+ }))
+ defer server.Close()
+
+ // Create config
+ cfg := NewTestConfig()
+ cfg.values["SAML_IDP_METADATA_URL"] = server.URL
+ cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
+ cfg.values["SAML_SP_ACS_URL"] = "https://sp.example.com/acs"
+
+ // Create SAML provider
+ logger := zaptest.NewLogger(b)
+ provider, err := auth.NewSAMLProvider(cfg, logger)
+ require.NoError(b, err)
+
+ ctx := context.Background()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, _, err := provider.GenerateAuthRequest(ctx, "test-relay-state")
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+// Test helper functions
+func TestSAMLResponseValidation(t *testing.T) {
+ // Test various SAML response validation scenarios
+ tests := []struct {
+ name string
+ modifyXML func(string) string
+ expectError bool
+ errorContains string
+ }{
+ {
+ name: "expired assertion",
+ modifyXML: func(xml string) string {
+ // Replace all NotOnOrAfter timestamps with past time
+ pastTime := "2020-01-01T13:00:00Z"
+ re := regexp.MustCompile(`NotOnOrAfter="[^"]*"`)
+ return re.ReplaceAllString(xml, `NotOnOrAfter="`+pastTime+`"`)
+ },
+ expectError: true,
+ errorContains: "assertion has expired",
+ },
+ {
+ name: "assertion not yet valid",
+ modifyXML: func(xml string) string {
+ // Replace all NotBefore timestamps with future time
+ futureTime := "2030-01-01T11:55:00Z"
+ re := regexp.MustCompile(`NotBefore="[^"]*"`)
+ return re.ReplaceAllString(xml, `NotBefore="`+futureTime+`"`)
+ },
+ expectError: true,
+ errorContains: "assertion not yet valid",
+ },
+ {
+ name: "failed status",
+ modifyXML: func(xml string) string {
+ return strings.ReplaceAll(xml,
+ "urn:oasis:names:tc:SAML:2.0:status:Success",
+ "urn:oasis:names:tc:SAML:2.0:status:AuthnFailed")
+ },
+ expectError: true,
+ errorContains: "SAML authentication failed",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create config
+ cfg := NewTestConfig()
+ cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
+
+ // Create SAML provider
+ logger := zaptest.NewLogger(t)
+ provider, err := auth.NewSAMLProvider(cfg, logger)
+ require.NoError(t, err)
+
+ // Modify SAML response
+ modifiedXML := tt.modifyXML(mockSAMLResponse())
+ samlResponse := base64.StdEncoding.EncodeToString([]byte(modifiedXML))
+
+ // Test ProcessSAMLResponse
+ ctx := context.Background()
+ authContext, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
+
+ if tt.expectError {
+ assert.Error(t, err)
+ if tt.errorContains != "" {
+ assert.Contains(t, err.Error(), tt.errorContains)
+ }
+ assert.Nil(t, authContext)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, authContext)
+ }
+ })
+ }
+}
diff --git a/test/token_repository_test.go b/test/token_repository_test.go
new file mode 100644
index 0000000..d9c65fb
--- /dev/null
+++ b/test/token_repository_test.go
@@ -0,0 +1,705 @@
+package test
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "time"
+
+ "github.com/DATA-DOG/go-sqlmock"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/kms/api-key-service/internal/domain"
+ "github.com/kms/api-key-service/internal/repository"
+ "github.com/kms/api-key-service/internal/repository/postgres"
+)
+
+// SQLMockDatabaseProvider implements repository.DatabaseProvider for SQL testing
+type SQLMockDatabaseProvider struct {
+ db *sql.DB
+}
+
+func (m *SQLMockDatabaseProvider) GetDB() interface{} {
+ return m.db
+}
+
+func (m *SQLMockDatabaseProvider) Ping(ctx context.Context) error {
+ return m.db.PingContext(ctx)
+}
+
+func (m *SQLMockDatabaseProvider) Close() error {
+ return m.db.Close()
+}
+
+func (m *SQLMockDatabaseProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) {
+ tx, err := m.db.BeginTx(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+ return &SQLMockTransactionProvider{tx: tx}, nil
+}
+
+func (m *SQLMockDatabaseProvider) Migrate(ctx context.Context, migrationPath string) error {
+ return nil
+}
+
+// SQLMockTransactionProvider implements repository.TransactionProvider for SQL testing
+type SQLMockTransactionProvider struct {
+ tx *sql.Tx
+}
+
+func (m *SQLMockTransactionProvider) Commit() error {
+ return m.tx.Commit()
+}
+
+func (m *SQLMockTransactionProvider) Rollback() error {
+ return m.tx.Rollback()
+}
+
+func (m *SQLMockTransactionProvider) GetTx() interface{} {
+ return m.tx
+}
+
+func setupTokenRepositoryTest(t *testing.T) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+
+ mockDB := &SQLMockDatabaseProvider{db: db}
+ repo := postgres.NewStaticTokenRepository(mockDB)
+
+ cleanup := func() {
+ db.Close()
+ }
+
+ return repo.(*postgres.StaticTokenRepository), mock, cleanup
+}
+
+func setupTokenRepositoryTestBenchmark(b *testing.B) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ mockDB := &SQLMockDatabaseProvider{db: db}
+ repo := postgres.NewStaticTokenRepository(mockDB)
+
+ cleanup := func() {
+ db.Close()
+ }
+
+ return repo.(*postgres.StaticTokenRepository), mock, cleanup
+}
+
+func TestStaticTokenRepository_Create(t *testing.T) {
+ tests := []struct {
+ name string
+ token *domain.StaticToken
+ setupMock func(sqlmock.Sqlmock)
+ expectError bool
+ errorMsg string
+ }{
+ {
+ name: "successful creation",
+ token: &domain.StaticToken{
+ ID: uuid.New(),
+ AppID: "test-app",
+ Owner: domain.Owner{
+ Type: domain.OwnerTypeIndividual,
+ Name: "test-user",
+ Owner: "test-owner",
+ },
+ KeyHash: "test-hash",
+ Type: "hmac",
+ },
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectExec(`INSERT INTO static_tokens`).
+ WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+ },
+ expectError: false,
+ },
+ {
+ name: "database error",
+ token: &domain.StaticToken{
+ ID: uuid.New(),
+ AppID: "test-app",
+ Owner: domain.Owner{
+ Type: domain.OwnerTypeIndividual,
+ Name: "test-user",
+ Owner: "test-owner",
+ },
+ KeyHash: "test-hash",
+ Type: "hmac",
+ },
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectExec(`INSERT INTO static_tokens`).
+ WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
+ WillReturnError(sql.ErrConnDone)
+ },
+ expectError: true,
+ errorMsg: "failed to create static token",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ repo, mock, cleanup := setupTokenRepositoryTest(t)
+ defer cleanup()
+
+ tt.setupMock(mock)
+
+ ctx := context.Background()
+ err := repo.Create(ctx, tt.token)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ } else {
+ assert.NoError(t, err)
+ assert.NotZero(t, tt.token.CreatedAt)
+ assert.NotZero(t, tt.token.UpdatedAt)
+ }
+
+ assert.NoError(t, mock.ExpectationsWereMet())
+ })
+ }
+}
+
+func TestStaticTokenRepository_GetByID(t *testing.T) {
+ tokenID := uuid.New()
+ now := time.Now()
+
+ tests := []struct {
+ name string
+ tokenID uuid.UUID
+ setupMock func(sqlmock.Sqlmock)
+ expectError bool
+ errorMsg string
+ expectedToken *domain.StaticToken
+ }{
+ {
+ name: "successful retrieval",
+ tokenID: tokenID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ rows := sqlmock.NewRows([]string{
+ "id", "app_id", "owner_type", "owner_name", "owner_owner",
+ "key_hash", "type", "created_at", "updated_at",
+ }).AddRow(
+ tokenID, "test-app", "individual", "test-user", "test-owner",
+ "test-hash", "user", now, now,
+ )
+ mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
+ WithArgs(tokenID).
+ WillReturnRows(rows)
+ },
+ expectError: false,
+ expectedToken: &domain.StaticToken{
+ ID: tokenID,
+ AppID: "test-app",
+ Owner: domain.Owner{
+ Type: domain.OwnerTypeIndividual,
+ Name: "test-user",
+ Owner: "test-owner",
+ },
+ KeyHash: "test-hash",
+ Type: string(domain.TokenTypeUser),
+ CreatedAt: now,
+ UpdatedAt: now,
+ },
+ },
+ {
+ name: "token not found",
+ tokenID: tokenID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
+ WithArgs(tokenID).
+ WillReturnError(sql.ErrNoRows)
+ },
+ expectError: true,
+ errorMsg: "not found",
+ },
+ {
+ name: "database error",
+ tokenID: tokenID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
+ WithArgs(tokenID).
+ WillReturnError(sql.ErrConnDone)
+ },
+ expectError: true,
+ errorMsg: "failed to get static token",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ repo, mock, cleanup := setupTokenRepositoryTest(t)
+ defer cleanup()
+
+ tt.setupMock(mock)
+
+ ctx := context.Background()
+ token, err := repo.GetByID(ctx, tt.tokenID)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ assert.Nil(t, token)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, token)
+ assert.Equal(t, tt.expectedToken.ID, token.ID)
+ assert.Equal(t, tt.expectedToken.AppID, token.AppID)
+ assert.Equal(t, tt.expectedToken.Owner, token.Owner)
+ assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
+ assert.Equal(t, tt.expectedToken.Type, token.Type)
+ }
+
+ assert.NoError(t, mock.ExpectationsWereMet())
+ })
+ }
+}
+
+func TestStaticTokenRepository_GetByKeyHash(t *testing.T) {
+ tokenID := uuid.New()
+ now := time.Now()
+ keyHash := "test-hash"
+
+ tests := []struct {
+ name string
+ keyHash string
+ setupMock func(sqlmock.Sqlmock)
+ expectError bool
+ errorMsg string
+ expectedToken *domain.StaticToken
+ }{
+ {
+ name: "successful retrieval",
+ keyHash: keyHash,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ rows := sqlmock.NewRows([]string{
+ "id", "app_id", "owner_type", "owner_name", "owner_owner",
+ "key_hash", "type", "created_at", "updated_at",
+ }).AddRow(
+ tokenID, "test-app", "individual", "test-user", "test-owner",
+ keyHash, "user", now, now,
+ )
+ mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
+ WithArgs(keyHash).
+ WillReturnRows(rows)
+ },
+ expectError: false,
+ expectedToken: &domain.StaticToken{
+ ID: tokenID,
+ AppID: "test-app",
+ Owner: domain.Owner{
+ Type: domain.OwnerTypeIndividual,
+ Name: "test-user",
+ Owner: "test-owner",
+ },
+ KeyHash: keyHash,
+ Type: string(domain.TokenTypeUser),
+ CreatedAt: now,
+ UpdatedAt: now,
+ },
+ },
+ {
+ name: "token not found",
+ keyHash: keyHash,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
+ WithArgs(keyHash).
+ WillReturnError(sql.ErrNoRows)
+ },
+ expectError: true,
+ errorMsg: "not found",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ repo, mock, cleanup := setupTokenRepositoryTest(t)
+ defer cleanup()
+
+ tt.setupMock(mock)
+
+ ctx := context.Background()
+ token, err := repo.GetByKeyHash(ctx, tt.keyHash)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ assert.Nil(t, token)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, token)
+ assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
+ }
+
+ assert.NoError(t, mock.ExpectationsWereMet())
+ })
+ }
+}
+
+func TestStaticTokenRepository_GetByAppID(t *testing.T) {
+ tokenID1 := uuid.New()
+ tokenID2 := uuid.New()
+ now := time.Now()
+ appID := "test-app"
+
+ tests := []struct {
+ name string
+ appID string
+ setupMock func(sqlmock.Sqlmock)
+ expectError bool
+ errorMsg string
+ expectedCount int
+ }{
+ {
+ name: "successful retrieval with multiple tokens",
+ appID: appID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ rows := sqlmock.NewRows([]string{
+ "id", "app_id", "owner_type", "owner_name", "owner_owner",
+ "key_hash", "type", "created_at", "updated_at",
+ }).AddRow(
+ tokenID1, appID, "user", "test-user1", "test-owner1",
+ "test-hash1", "user", now, now,
+ ).AddRow(
+ tokenID2, appID, "user", "test-user2", "test-owner2",
+ "test-hash2", "user", now, now,
+ )
+ mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
+ WithArgs(appID).
+ WillReturnRows(rows)
+ },
+ expectError: false,
+ expectedCount: 2,
+ },
+ {
+ name: "no tokens found",
+ appID: appID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ rows := sqlmock.NewRows([]string{
+ "id", "app_id", "owner_type", "owner_name", "owner_owner",
+ "key_hash", "type", "created_at", "updated_at",
+ })
+ mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
+ WithArgs(appID).
+ WillReturnRows(rows)
+ },
+ expectError: false,
+ expectedCount: 0,
+ },
+ {
+ name: "database error",
+ appID: appID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
+ WithArgs(appID).
+ WillReturnError(sql.ErrConnDone)
+ },
+ expectError: true,
+ errorMsg: "failed to query static tokens",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ repo, mock, cleanup := setupTokenRepositoryTest(t)
+ defer cleanup()
+
+ tt.setupMock(mock)
+
+ ctx := context.Background()
+ tokens, err := repo.GetByAppID(ctx, tt.appID)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ assert.Nil(t, tokens)
+ } else {
+ assert.NoError(t, err)
+ assert.Len(t, tokens, tt.expectedCount)
+ }
+
+ assert.NoError(t, mock.ExpectationsWereMet())
+ })
+ }
+}
+
+func TestStaticTokenRepository_List(t *testing.T) {
+ tokenID := uuid.New()
+ now := time.Now()
+
+ tests := []struct {
+ name string
+ limit int
+ offset int
+ setupMock func(sqlmock.Sqlmock)
+ expectError bool
+ errorMsg string
+ expectedCount int
+ }{
+ {
+ name: "successful list with pagination",
+ limit: 10,
+ offset: 0,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ rows := sqlmock.NewRows([]string{
+ "id", "app_id", "owner_type", "owner_name", "owner_owner",
+ "key_hash", "type", "created_at", "updated_at",
+ }).AddRow(
+ tokenID, "test-app", "user", "test-user", "test-owner",
+ "test-hash", "user", now, now,
+ )
+ mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
+ WithArgs(10, 0).
+ WillReturnRows(rows)
+ },
+ expectError: false,
+ expectedCount: 1,
+ },
+ {
+ name: "database error",
+ limit: 10,
+ offset: 0,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
+ WithArgs(10, 0).
+ WillReturnError(sql.ErrConnDone)
+ },
+ expectError: true,
+ errorMsg: "failed to query static tokens",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ repo, mock, cleanup := setupTokenRepositoryTest(t)
+ defer cleanup()
+
+ tt.setupMock(mock)
+
+ ctx := context.Background()
+ tokens, err := repo.List(ctx, tt.limit, tt.offset)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ assert.Nil(t, tokens)
+ } else {
+ assert.NoError(t, err)
+ assert.Len(t, tokens, tt.expectedCount)
+ }
+
+ assert.NoError(t, mock.ExpectationsWereMet())
+ })
+ }
+}
+
+func TestStaticTokenRepository_Delete(t *testing.T) {
+ tokenID := uuid.New()
+
+ tests := []struct {
+ name string
+ tokenID uuid.UUID
+ setupMock func(sqlmock.Sqlmock)
+ expectError bool
+ errorMsg string
+ }{
+ {
+ name: "successful deletion",
+ tokenID: tokenID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
+ WithArgs(tokenID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+ },
+ expectError: false,
+ },
+ {
+ name: "token not found",
+ tokenID: tokenID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
+ WithArgs(tokenID).
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ },
+ expectError: true,
+ errorMsg: "not found",
+ },
+ {
+ name: "database error",
+ tokenID: tokenID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
+ WithArgs(tokenID).
+ WillReturnError(sql.ErrConnDone)
+ },
+ expectError: true,
+ errorMsg: "failed to delete static token",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ repo, mock, cleanup := setupTokenRepositoryTest(t)
+ defer cleanup()
+
+ tt.setupMock(mock)
+
+ ctx := context.Background()
+ err := repo.Delete(ctx, tt.tokenID)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ assert.NoError(t, mock.ExpectationsWereMet())
+ })
+ }
+}
+
+func TestStaticTokenRepository_Exists(t *testing.T) {
+ tokenID := uuid.New()
+
+ tests := []struct {
+ name string
+ tokenID uuid.UUID
+ setupMock func(sqlmock.Sqlmock)
+ expectError bool
+ errorMsg string
+ expectedExists bool
+ }{
+ {
+ name: "token exists",
+ tokenID: tokenID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ rows := sqlmock.NewRows([]string{"exists"}).AddRow(1)
+ mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
+ WithArgs(tokenID).
+ WillReturnRows(rows)
+ },
+ expectError: false,
+ expectedExists: true,
+ },
+ {
+ name: "token does not exist",
+ tokenID: tokenID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
+ WithArgs(tokenID).
+ WillReturnError(sql.ErrNoRows)
+ },
+ expectError: false,
+ expectedExists: false,
+ },
+ {
+ name: "database error",
+ tokenID: tokenID,
+ setupMock: func(mock sqlmock.Sqlmock) {
+ mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
+ WithArgs(tokenID).
+ WillReturnError(sql.ErrConnDone)
+ },
+ expectError: true,
+ errorMsg: "failed to check static token existence",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ repo, mock, cleanup := setupTokenRepositoryTest(t)
+ defer cleanup()
+
+ tt.setupMock(mock)
+
+ ctx := context.Background()
+ exists, err := repo.Exists(ctx, tt.tokenID)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expectedExists, exists)
+ }
+
+ assert.NoError(t, mock.ExpectationsWereMet())
+ })
+ }
+}
+
+// Benchmark tests for repository operations
+func BenchmarkStaticTokenRepository_Create(b *testing.B) {
+ repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
+ defer cleanup()
+
+ token := &domain.StaticToken{
+ ID: uuid.New(),
+ AppID: "test-app",
+ Owner: domain.Owner{
+ Type: domain.OwnerTypeIndividual,
+ Name: "test-user",
+ Owner: "test-owner",
+ },
+ KeyHash: "test-hash",
+ Type: string(domain.TokenTypeUser),
+ }
+
+ // Setup mock expectations for all iterations
+ for i := 0; i < b.N; i++ {
+ mock.ExpectExec(`INSERT INTO static_tokens`).
+ WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "user", sqlmock.AnyArg(), sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+ }
+
+ ctx := context.Background()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ token.ID = uuid.New() // Generate new ID for each iteration
+ err := repo.Create(ctx, token)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkStaticTokenRepository_GetByID(b *testing.B) {
+ repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
+ defer cleanup()
+
+ tokenID := uuid.New()
+ now := time.Now()
+
+ // Setup mock expectations for all iterations
+ for i := 0; i < b.N; i++ {
+ rows := sqlmock.NewRows([]string{
+ "id", "app_id", "owner_type", "owner_name", "owner_owner",
+ "key_hash", "type", "created_at", "updated_at",
+ }).AddRow(
+ tokenID, "test-app", "user", "test-user", "test-owner",
+ "test-hash", "user", now, now,
+ )
+ mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
+ WithArgs(tokenID).
+ WillReturnRows(rows)
+ }
+
+ ctx := context.Background()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ _, err := repo.GetByID(ctx, tokenID)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}