This commit is contained in:
2025-08-22 18:57:40 -04:00
parent d648a55c0c
commit df567983c1
20 changed files with 4519 additions and 8 deletions

View File

@ -64,10 +64,10 @@ This document outlines the complete roadmap for making the API Key Management Se
- [x] Implement authorization code exchange and token refresh
- [x] Add user info retrieval from OAuth2 providers
- [x] Create comprehensive OAuth2 unit tests with benchmarks
- [ ] Add SAML authentication support
- [ ] Create user session management
- [x] Add SAML authentication support
- [x] Create user session management
- [x] Implement role-based access control (RBAC)
- [ ] Add multi-tenant authentication support
- [x] Add multi-tenant authentication support
### Permission System Enhancement
- [x] Implement hierarchical permission inheritance
@ -120,7 +120,7 @@ This document outlines the complete roadmap for making the API Key Management Se
- [x] Create authentication failure tracking
### Audit & Compliance
- [ ] Implement comprehensive audit logging
- [x] Implement comprehensive audit logging
- [ ] Add compliance reporting features
- [ ] Create data retention policies
- [ ] Implement GDPR compliance features

2
go.mod
View File

@ -21,6 +21,7 @@ require (
)
require (
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
@ -33,6 +34,7 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/jmoiron/sqlx v1.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect

8
go.sum
View File

@ -1,5 +1,8 @@
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@ -43,6 +46,7 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
@ -65,10 +69,13 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg=
github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
@ -78,6 +85,7 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

590
internal/audit/audit.go Normal file
View File

@ -0,0 +1,590 @@
package audit
import (
"context"
"encoding/json"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/config"
)
// EventType represents the type of audit event
type EventType string
const (
// Authentication events
EventTypeLogin EventType = "auth.login"
EventTypeLoginFailed EventType = "auth.login_failed"
EventTypeLogout EventType = "auth.logout"
EventTypeTokenCreated EventType = "auth.token_created"
EventTypeTokenRevoked EventType = "auth.token_revoked"
EventTypeTokenValidated EventType = "auth.token_validated"
// Session events
EventTypeSessionCreated EventType = "session.created"
EventTypeSessionRevoked EventType = "session.revoked"
EventTypeSessionExpired EventType = "session.expired"
// Application events
EventTypeAppCreated EventType = "app.created"
EventTypeAppUpdated EventType = "app.updated"
EventTypeAppDeleted EventType = "app.deleted"
// Permission events
EventTypePermissionGranted EventType = "permission.granted"
EventTypePermissionRevoked EventType = "permission.revoked"
EventTypePermissionDenied EventType = "permission.denied"
// Tenant events
EventTypeTenantCreated EventType = "tenant.created"
EventTypeTenantUpdated EventType = "tenant.updated"
EventTypeTenantSuspended EventType = "tenant.suspended"
EventTypeTenantActivated EventType = "tenant.activated"
// User events
EventTypeUserCreated EventType = "user.created"
EventTypeUserUpdated EventType = "user.updated"
EventTypeUserSuspended EventType = "user.suspended"
EventTypeUserActivated EventType = "user.activated"
// Security events
EventTypeSecurityViolation EventType = "security.violation"
EventTypeBruteForceAttempt EventType = "security.brute_force"
EventTypeIPBlocked EventType = "security.ip_blocked"
EventTypeRateLimitExceeded EventType = "security.rate_limit_exceeded"
// System events
EventTypeSystemStartup EventType = "system.startup"
EventTypeSystemShutdown EventType = "system.shutdown"
EventTypeConfigChanged EventType = "system.config_changed"
)
// EventSeverity represents the severity level of an audit event
type EventSeverity string
const (
SeverityInfo EventSeverity = "info"
SeverityWarning EventSeverity = "warning"
SeverityError EventSeverity = "error"
SeverityCritical EventSeverity = "critical"
)
// EventStatus represents the status of an audit event
type EventStatus string
const (
StatusSuccess EventStatus = "success"
StatusFailure EventStatus = "failure"
StatusPending EventStatus = "pending"
)
// AuditEvent represents a single audit event
type AuditEvent struct {
ID uuid.UUID `json:"id" db:"id"`
Type EventType `json:"type" db:"type"`
Severity EventSeverity `json:"severity" db:"severity"`
Status EventStatus `json:"status" db:"status"`
Timestamp time.Time `json:"timestamp" db:"timestamp"`
// Actor information
ActorID string `json:"actor_id,omitempty" db:"actor_id"`
ActorType string `json:"actor_type,omitempty" db:"actor_type"` // user, system, service
ActorIP string `json:"actor_ip,omitempty" db:"actor_ip"`
UserAgent string `json:"user_agent,omitempty" db:"user_agent"`
// Tenant information
TenantID *uuid.UUID `json:"tenant_id,omitempty" db:"tenant_id"`
// Resource information
ResourceID string `json:"resource_id,omitempty" db:"resource_id"`
ResourceType string `json:"resource_type,omitempty" db:"resource_type"`
// Event details
Action string `json:"action" db:"action"`
Description string `json:"description" db:"description"`
Details map[string]interface{} `json:"details,omitempty" db:"details"`
// Request context
RequestID string `json:"request_id,omitempty" db:"request_id"`
SessionID string `json:"session_id,omitempty" db:"session_id"`
// Additional metadata
Tags []string `json:"tags,omitempty" db:"tags"`
Metadata map[string]string `json:"metadata,omitempty" db:"metadata"`
}
// AuditLogger defines the interface for audit logging
type AuditLogger interface {
// LogEvent logs a single audit event
LogEvent(ctx context.Context, event *AuditEvent) error
// LogAuthEvent logs an authentication-related event
LogAuthEvent(ctx context.Context, eventType EventType, actorID, actorIP string, details map[string]interface{}) error
// LogPermissionEvent logs a permission-related event
LogPermissionEvent(ctx context.Context, eventType EventType, actorID, resourceID, resourceType string, details map[string]interface{}) error
// LogSecurityEvent logs a security-related event
LogSecurityEvent(ctx context.Context, eventType EventType, actorIP string, severity EventSeverity, details map[string]interface{}) error
// LogSystemEvent logs a system-related event
LogSystemEvent(ctx context.Context, eventType EventType, details map[string]interface{}) error
// QueryEvents queries audit events with filters
QueryEvents(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error)
// GetEventStats returns audit event statistics
GetEventStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error)
}
// AuditFilter represents filters for querying audit events
type AuditFilter struct {
EventTypes []EventType `json:"event_types,omitempty"`
Severities []EventSeverity `json:"severities,omitempty"`
Statuses []EventStatus `json:"statuses,omitempty"`
ActorID string `json:"actor_id,omitempty"`
ActorType string `json:"actor_type,omitempty"`
TenantID *uuid.UUID `json:"tenant_id,omitempty"`
ResourceID string `json:"resource_id,omitempty"`
ResourceType string `json:"resource_type,omitempty"`
StartTime *time.Time `json:"start_time,omitempty"`
EndTime *time.Time `json:"end_time,omitempty"`
Tags []string `json:"tags,omitempty"`
Limit int `json:"limit"`
Offset int `json:"offset"`
OrderBy string `json:"order_by"` // timestamp, type, severity
OrderDesc bool `json:"order_desc"`
}
// AuditStatsFilter represents filters for audit statistics
type AuditStatsFilter struct {
EventTypes []EventType `json:"event_types,omitempty"`
TenantID *uuid.UUID `json:"tenant_id,omitempty"`
StartTime *time.Time `json:"start_time,omitempty"`
EndTime *time.Time `json:"end_time,omitempty"`
GroupBy string `json:"group_by"` // type, severity, status, hour, day
}
// AuditStats represents audit event statistics
type AuditStats struct {
TotalEvents int `json:"total_events"`
ByType map[EventType]int `json:"by_type"`
BySeverity map[EventSeverity]int `json:"by_severity"`
ByStatus map[EventStatus]int `json:"by_status"`
ByTime map[string]int `json:"by_time,omitempty"`
}
// auditLogger implements the AuditLogger interface
type auditLogger struct {
config config.ConfigProvider
logger *zap.Logger
repository AuditRepository
}
// AuditRepository defines the interface for audit event storage
type AuditRepository interface {
Create(ctx context.Context, event *AuditEvent) error
Query(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error)
GetStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error)
DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error)
}
// NewAuditLogger creates a new audit logger
func NewAuditLogger(config config.ConfigProvider, logger *zap.Logger, repository AuditRepository) AuditLogger {
return &auditLogger{
config: config,
logger: logger,
repository: repository,
}
}
// LogEvent logs a single audit event
func (a *auditLogger) LogEvent(ctx context.Context, event *AuditEvent) error {
// Set default values
if event.ID == uuid.Nil {
event.ID = uuid.New()
}
if event.Timestamp.IsZero() {
event.Timestamp = time.Now().UTC()
}
if event.Severity == "" {
event.Severity = SeverityInfo
}
if event.Status == "" {
event.Status = StatusSuccess
}
// Extract request context if available
if requestID := ctx.Value("request_id"); requestID != nil {
if reqID, ok := requestID.(string); ok {
event.RequestID = reqID
}
}
// Log to structured logger
a.logToStructuredLogger(event)
// Store in repository
if err := a.repository.Create(ctx, event); err != nil {
a.logger.Error("Failed to store audit event",
zap.Error(err),
zap.String("event_id", event.ID.String()),
zap.String("event_type", string(event.Type)))
return err
}
return nil
}
// LogAuthEvent logs an authentication-related event
func (a *auditLogger) LogAuthEvent(ctx context.Context, eventType EventType, actorID, actorIP string, details map[string]interface{}) error {
severity := SeverityInfo
status := StatusSuccess
// Determine severity and status based on event type
switch eventType {
case EventTypeLoginFailed:
severity = SeverityWarning
status = StatusFailure
case EventTypeTokenRevoked:
severity = SeverityWarning
}
event := &AuditEvent{
Type: eventType,
Severity: severity,
Status: status,
ActorID: actorID,
ActorType: "user",
ActorIP: actorIP,
Action: string(eventType),
Description: a.generateDescription(eventType, details),
Details: details,
Tags: []string{"authentication"},
}
return a.LogEvent(ctx, event)
}
// LogPermissionEvent logs a permission-related event
func (a *auditLogger) LogPermissionEvent(ctx context.Context, eventType EventType, actorID, resourceID, resourceType string, details map[string]interface{}) error {
severity := SeverityInfo
status := StatusSuccess
// Determine severity and status based on event type
switch eventType {
case EventTypePermissionDenied:
severity = SeverityWarning
status = StatusFailure
}
event := &AuditEvent{
Type: eventType,
Severity: severity,
Status: status,
ActorID: actorID,
ActorType: "user",
ResourceID: resourceID,
ResourceType: resourceType,
Action: string(eventType),
Description: a.generateDescription(eventType, details),
Details: details,
Tags: []string{"permission", "authorization"},
}
return a.LogEvent(ctx, event)
}
// LogSecurityEvent logs a security-related event
func (a *auditLogger) LogSecurityEvent(ctx context.Context, eventType EventType, actorIP string, severity EventSeverity, details map[string]interface{}) error {
status := StatusSuccess
if severity == SeverityError || severity == SeverityCritical {
status = StatusFailure
}
event := &AuditEvent{
Type: eventType,
Severity: severity,
Status: status,
ActorIP: actorIP,
ActorType: "system",
Action: string(eventType),
Description: a.generateDescription(eventType, details),
Details: details,
Tags: []string{"security"},
}
return a.LogEvent(ctx, event)
}
// LogSystemEvent logs a system-related event
func (a *auditLogger) LogSystemEvent(ctx context.Context, eventType EventType, details map[string]interface{}) error {
event := &AuditEvent{
Type: eventType,
Severity: SeverityInfo,
Status: StatusSuccess,
ActorType: "system",
Action: string(eventType),
Description: a.generateDescription(eventType, details),
Details: details,
Tags: []string{"system"},
}
return a.LogEvent(ctx, event)
}
// QueryEvents queries audit events with filters
func (a *auditLogger) QueryEvents(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error) {
// Set default pagination
if filter.Limit <= 0 {
filter.Limit = 100
}
if filter.Limit > 1000 {
filter.Limit = 1000
}
if filter.OrderBy == "" {
filter.OrderBy = "timestamp"
filter.OrderDesc = true
}
return a.repository.Query(ctx, filter)
}
// GetEventStats returns audit event statistics
func (a *auditLogger) GetEventStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error) {
return a.repository.GetStats(ctx, filter)
}
// logToStructuredLogger logs the event to the structured logger
func (a *auditLogger) logToStructuredLogger(event *AuditEvent) {
fields := []zap.Field{
zap.String("audit_event_id", event.ID.String()),
zap.String("event_type", string(event.Type)),
zap.String("severity", string(event.Severity)),
zap.String("status", string(event.Status)),
zap.Time("timestamp", event.Timestamp),
zap.String("action", event.Action),
zap.String("description", event.Description),
}
if event.ActorID != "" {
fields = append(fields, zap.String("actor_id", event.ActorID))
}
if event.ActorType != "" {
fields = append(fields, zap.String("actor_type", event.ActorType))
}
if event.ActorIP != "" {
fields = append(fields, zap.String("actor_ip", event.ActorIP))
}
if event.TenantID != nil {
fields = append(fields, zap.String("tenant_id", event.TenantID.String()))
}
if event.ResourceID != "" {
fields = append(fields, zap.String("resource_id", event.ResourceID))
}
if event.ResourceType != "" {
fields = append(fields, zap.String("resource_type", event.ResourceType))
}
if event.RequestID != "" {
fields = append(fields, zap.String("request_id", event.RequestID))
}
if event.SessionID != "" {
fields = append(fields, zap.String("session_id", event.SessionID))
}
if len(event.Tags) > 0 {
fields = append(fields, zap.Strings("tags", event.Tags))
}
if event.Details != nil {
if detailsJSON, err := json.Marshal(event.Details); err == nil {
fields = append(fields, zap.String("details", string(detailsJSON)))
}
}
// Log at appropriate level based on severity
switch event.Severity {
case SeverityInfo:
a.logger.Info("Audit event", fields...)
case SeverityWarning:
a.logger.Warn("Audit event", fields...)
case SeverityError:
a.logger.Error("Audit event", fields...)
case SeverityCritical:
a.logger.Error("Critical audit event", fields...)
default:
a.logger.Info("Audit event", fields...)
}
}
// generateDescription generates a human-readable description for an event
func (a *auditLogger) generateDescription(eventType EventType, details map[string]interface{}) string {
switch eventType {
case EventTypeLogin:
return "User successfully logged in"
case EventTypeLoginFailed:
return "User login attempt failed"
case EventTypeLogout:
return "User logged out"
case EventTypeTokenCreated:
return "API token created"
case EventTypeTokenRevoked:
return "API token revoked"
case EventTypeTokenValidated:
return "API token validated"
case EventTypeSessionCreated:
return "User session created"
case EventTypeSessionRevoked:
return "User session revoked"
case EventTypeSessionExpired:
return "User session expired"
case EventTypeAppCreated:
return "Application created"
case EventTypeAppUpdated:
return "Application updated"
case EventTypeAppDeleted:
return "Application deleted"
case EventTypePermissionGranted:
return "Permission granted"
case EventTypePermissionRevoked:
return "Permission revoked"
case EventTypePermissionDenied:
return "Permission denied"
case EventTypeTenantCreated:
return "Tenant created"
case EventTypeTenantUpdated:
return "Tenant updated"
case EventTypeTenantSuspended:
return "Tenant suspended"
case EventTypeTenantActivated:
return "Tenant activated"
case EventTypeUserCreated:
return "User created"
case EventTypeUserUpdated:
return "User updated"
case EventTypeUserSuspended:
return "User suspended"
case EventTypeUserActivated:
return "User activated"
case EventTypeSecurityViolation:
return "Security violation detected"
case EventTypeBruteForceAttempt:
return "Brute force attempt detected"
case EventTypeIPBlocked:
return "IP address blocked"
case EventTypeRateLimitExceeded:
return "Rate limit exceeded"
case EventTypeSystemStartup:
return "System started"
case EventTypeSystemShutdown:
return "System shutdown"
case EventTypeConfigChanged:
return "Configuration changed"
default:
return string(eventType)
}
}
// AuditEventBuilder provides a fluent interface for building audit events
type AuditEventBuilder struct {
event *AuditEvent
}
// NewAuditEventBuilder creates a new audit event builder
func NewAuditEventBuilder(eventType EventType) *AuditEventBuilder {
return &AuditEventBuilder{
event: &AuditEvent{
ID: uuid.New(),
Type: eventType,
Timestamp: time.Now().UTC(),
Severity: SeverityInfo,
Status: StatusSuccess,
Details: make(map[string]interface{}),
Metadata: make(map[string]string),
},
}
}
// WithSeverity sets the event severity
func (b *AuditEventBuilder) WithSeverity(severity EventSeverity) *AuditEventBuilder {
b.event.Severity = severity
return b
}
// WithStatus sets the event status
func (b *AuditEventBuilder) WithStatus(status EventStatus) *AuditEventBuilder {
b.event.Status = status
return b
}
// WithActor sets the actor information
func (b *AuditEventBuilder) WithActor(actorID, actorType, actorIP string) *AuditEventBuilder {
b.event.ActorID = actorID
b.event.ActorType = actorType
b.event.ActorIP = actorIP
return b
}
// WithTenant sets the tenant ID
func (b *AuditEventBuilder) WithTenant(tenantID uuid.UUID) *AuditEventBuilder {
b.event.TenantID = &tenantID
return b
}
// WithResource sets the resource information
func (b *AuditEventBuilder) WithResource(resourceID, resourceType string) *AuditEventBuilder {
b.event.ResourceID = resourceID
b.event.ResourceType = resourceType
return b
}
// WithAction sets the action
func (b *AuditEventBuilder) WithAction(action string) *AuditEventBuilder {
b.event.Action = action
return b
}
// WithDescription sets the description
func (b *AuditEventBuilder) WithDescription(description string) *AuditEventBuilder {
b.event.Description = description
return b
}
// WithDetail adds a detail
func (b *AuditEventBuilder) WithDetail(key string, value interface{}) *AuditEventBuilder {
b.event.Details[key] = value
return b
}
// WithDetails sets multiple details
func (b *AuditEventBuilder) WithDetails(details map[string]interface{}) *AuditEventBuilder {
for k, v := range details {
b.event.Details[k] = v
}
return b
}
// WithRequestContext sets request context information
func (b *AuditEventBuilder) WithRequestContext(requestID, sessionID string) *AuditEventBuilder {
b.event.RequestID = requestID
b.event.SessionID = sessionID
return b
}
// WithTags sets the tags
func (b *AuditEventBuilder) WithTags(tags ...string) *AuditEventBuilder {
b.event.Tags = tags
return b
}
// WithMetadata adds metadata
func (b *AuditEventBuilder) WithMetadata(key, value string) *AuditEventBuilder {
b.event.Metadata[key] = value
return b
}
// Build returns the built audit event
func (b *AuditEventBuilder) Build() *AuditEvent {
return b.event
}

544
internal/auth/saml.go Normal file
View File

@ -0,0 +1,544 @@
package auth
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
)
// SAMLProvider represents a SAML 2.0 identity provider
type SAMLProvider struct {
config config.ConfigProvider
logger *zap.Logger
httpClient *http.Client
privateKey *rsa.PrivateKey
certificate *x509.Certificate
}
// NewSAMLProvider creates a new SAML provider
func NewSAMLProvider(config config.ConfigProvider, logger *zap.Logger) (*SAMLProvider, error) {
provider := &SAMLProvider{
config: config,
logger: logger,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
// Load SP private key and certificate if configured
if err := provider.loadCredentials(); err != nil {
return nil, err
}
return provider, nil
}
// SAMLMetadata represents SAML IdP metadata
type SAMLMetadata struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntityDescriptor"`
EntityID string `xml:"entityID,attr"`
IDPSSODescriptor IDPSSODescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata IDPSSODescriptor"`
}
// IDPSSODescriptor represents the IdP SSO descriptor
type IDPSSODescriptor struct {
ProtocolSupportEnumeration string `xml:"protocolSupportEnumeration,attr"`
KeyDescriptor []KeyDescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata KeyDescriptor"`
SingleSignOnService []SingleSignOnService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleSignOnService"`
SingleLogoutService []SingleLogoutService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleLogoutService"`
}
// KeyDescriptor represents a key descriptor
type KeyDescriptor struct {
Use string `xml:"use,attr"`
KeyInfo KeyInfo `xml:"urn:xmldsig KeyInfo"`
}
// KeyInfo represents key information
type KeyInfo struct {
X509Data X509Data `xml:"urn:xmldsig X509Data"`
}
// X509Data represents X509 certificate data
type X509Data struct {
X509Certificate string `xml:"urn:xmldsig X509Certificate"`
}
// SingleSignOnService represents SSO service endpoint
type SingleSignOnService struct {
Binding string `xml:"Binding,attr"`
Location string `xml:"Location,attr"`
}
// SingleLogoutService represents SLO service endpoint
type SingleLogoutService struct {
Binding string `xml:"Binding,attr"`
Location string `xml:"Location,attr"`
}
// SAMLRequest represents a SAML authentication request
type SAMLRequest struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol AuthnRequest"`
ID string `xml:"ID,attr"`
Version string `xml:"Version,attr"`
IssueInstant time.Time `xml:"IssueInstant,attr"`
Destination string `xml:"Destination,attr"`
AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"`
ProtocolBinding string `xml:"ProtocolBinding,attr"`
Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
NameIDPolicy NameIDPolicy `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"`
}
// Issuer represents the SAML issuer
type Issuer struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
Value string `xml:",chardata"`
}
// NameIDPolicy represents the name ID policy
type NameIDPolicy struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"`
Format string `xml:"Format,attr"`
}
// SAMLResponse represents a SAML response
type SAMLResponse struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Response"`
ID string `xml:"ID,attr"`
Version string `xml:"Version,attr"`
IssueInstant time.Time `xml:"IssueInstant,attr"`
Destination string `xml:"Destination,attr"`
InResponseTo string `xml:"InResponseTo,attr"`
Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
Status Status `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"`
Assertion Assertion `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"`
}
// Status represents the SAML response status
type Status struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"`
StatusCode StatusCode `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"`
}
// StatusCode represents the status code
type StatusCode struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"`
Value string `xml:"Value,attr"`
}
// Assertion represents a SAML assertion
type Assertion struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"`
ID string `xml:"ID,attr"`
Version string `xml:"Version,attr"`
IssueInstant time.Time `xml:"IssueInstant,attr"`
Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
Subject Subject `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"`
Conditions Conditions `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"`
AttributeStatement AttributeStatement `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"`
AuthnStatement AuthnStatement `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"`
}
// Subject represents the assertion subject
type Subject struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"`
NameID NameID `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"`
SubjectConfirmation SubjectConfirmation `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"`
}
// NameID represents the name identifier
type NameID struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"`
Format string `xml:"Format,attr"`
Value string `xml:",chardata"`
}
// SubjectConfirmation represents subject confirmation
type SubjectConfirmation struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"`
Method string `xml:"Method,attr"`
SubjectConfirmationData SubjectConfirmationData `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"`
}
// SubjectConfirmationData represents subject confirmation data
type SubjectConfirmationData struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"`
InResponseTo string `xml:"InResponseTo,attr"`
NotOnOrAfter time.Time `xml:"NotOnOrAfter,attr"`
Recipient string `xml:"Recipient,attr"`
}
// Conditions represents assertion conditions
type Conditions struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"`
NotBefore time.Time `xml:"NotBefore,attr"`
NotOnOrAfter time.Time `xml:"NotOnOrAfter,attr"`
AudienceRestriction AudienceRestriction `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"`
}
// AudienceRestriction represents audience restriction
type AudienceRestriction struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"`
Audience Audience `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"`
}
// Audience represents the intended audience
type Audience struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"`
Value string `xml:",chardata"`
}
// AttributeStatement represents attribute statement
type AttributeStatement struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"`
Attribute []Attribute `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"`
}
// Attribute represents a SAML attribute
type Attribute struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"`
Name string `xml:"Name,attr"`
AttributeValue []AttributeValue `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"`
}
// AttributeValue represents an attribute value
type AttributeValue struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"`
Type string `xml:"http://www.w3.org/2001/XMLSchema-instance type,attr"`
Value string `xml:",chardata"`
}
// AuthnStatement represents authentication statement
type AuthnStatement struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"`
AuthnInstant time.Time `xml:"AuthnInstant,attr"`
SessionIndex string `xml:"SessionIndex,attr"`
AuthnContext AuthnContext `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"`
}
// AuthnContext represents authentication context
type AuthnContext struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"`
AuthnContextClassRef string `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContextClassRef"`
}
// GetMetadata fetches the SAML IdP metadata
func (p *SAMLProvider) GetMetadata(ctx context.Context) (*SAMLMetadata, error) {
metadataURL := p.config.GetString("SAML_IDP_METADATA_URL")
if metadataURL == "" {
return nil, errors.NewConfigurationError("SAML_IDP_METADATA_URL not configured")
}
p.logger.Debug("Fetching SAML IdP metadata", zap.String("url", metadataURL))
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
if err != nil {
return nil, errors.NewInternalError("Failed to create metadata request").WithInternal(err)
}
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, errors.NewInternalError("Failed to fetch IdP metadata").WithInternal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, errors.NewInternalError(fmt.Sprintf("Metadata endpoint returned status %d", resp.StatusCode))
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.NewInternalError("Failed to read metadata response").WithInternal(err)
}
var metadata SAMLMetadata
if err := xml.Unmarshal(body, &metadata); err != nil {
return nil, errors.NewInternalError("Failed to parse SAML metadata").WithInternal(err)
}
p.logger.Debug("SAML IdP metadata fetched successfully",
zap.String("entity_id", metadata.EntityID))
return &metadata, nil
}
// GenerateAuthRequest generates a SAML authentication request
func (p *SAMLProvider) GenerateAuthRequest(ctx context.Context, relayState string) (string, string, error) {
metadata, err := p.GetMetadata(ctx)
if err != nil {
return "", "", err
}
// Find SSO endpoint
var ssoEndpoint string
for _, sso := range metadata.IDPSSODescriptor.SingleSignOnService {
if sso.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" {
ssoEndpoint = sso.Location
break
}
}
if ssoEndpoint == "" {
return "", "", errors.NewConfigurationError("No HTTP-Redirect SSO endpoint found in IdP metadata")
}
// Generate request ID
requestID := "_" + uuid.New().String()
// Get SP configuration
spEntityID := p.config.GetString("SAML_SP_ENTITY_ID")
acsURL := p.config.GetString("SAML_SP_ACS_URL")
if spEntityID == "" {
return "", "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured")
}
if acsURL == "" {
return "", "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured")
}
// Create SAML request
samlRequest := SAMLRequest{
ID: requestID,
Version: "2.0",
IssueInstant: time.Now().UTC(),
Destination: ssoEndpoint,
AssertionConsumerServiceURL: acsURL,
ProtocolBinding: "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
Issuer: Issuer{
Value: spEntityID,
},
NameIDPolicy: NameIDPolicy{
Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress",
},
}
// Marshal to XML
xmlData, err := xml.MarshalIndent(samlRequest, "", " ")
if err != nil {
return "", "", errors.NewInternalError("Failed to marshal SAML request").WithInternal(err)
}
// Add XML declaration
xmlRequest := `<?xml version="1.0" encoding="UTF-8"?>` + "\n" + string(xmlData)
// Base64 encode and URL encode
encodedRequest := base64.StdEncoding.EncodeToString([]byte(xmlRequest))
// Build redirect URL
params := url.Values{
"SAMLRequest": {encodedRequest},
"RelayState": {relayState},
}
redirectURL := ssoEndpoint + "?" + params.Encode()
p.logger.Debug("Generated SAML authentication request",
zap.String("request_id", requestID),
zap.String("sso_endpoint", ssoEndpoint))
return redirectURL, requestID, nil
}
// ProcessSAMLResponse processes a SAML response and extracts user information
func (p *SAMLProvider) ProcessSAMLResponse(ctx context.Context, samlResponse string, expectedRequestID string) (*domain.AuthContext, error) {
p.logger.Debug("Processing SAML response")
// Base64 decode the response
decodedResponse, err := base64.StdEncoding.DecodeString(samlResponse)
if err != nil {
return nil, errors.NewValidationError("Failed to decode SAML response").WithInternal(err)
}
// Parse XML
var response SAMLResponse
if err := xml.Unmarshal(decodedResponse, &response); err != nil {
return nil, errors.NewValidationError("Failed to parse SAML response").WithInternal(err)
}
// Validate response
if err := p.validateSAMLResponse(&response, expectedRequestID); err != nil {
return nil, err
}
// Extract user information from assertion
authContext, err := p.extractUserInfo(&response.Assertion)
if err != nil {
return nil, err
}
p.logger.Debug("SAML response processed successfully",
zap.String("user_id", authContext.UserID))
return authContext, nil
}
// validateSAMLResponse validates a SAML response
func (p *SAMLProvider) validateSAMLResponse(response *SAMLResponse, expectedRequestID string) error {
// Check status
if response.Status.StatusCode.Value != "urn:oasis:names:tc:SAML:2.0:status:Success" {
return errors.NewAuthenticationError("SAML authentication failed: " + response.Status.StatusCode.Value)
}
// Validate InResponseTo
if expectedRequestID != "" && response.InResponseTo != expectedRequestID {
return errors.NewValidationError("SAML response InResponseTo does not match request ID")
}
// Validate assertion conditions
assertion := &response.Assertion
now := time.Now().UTC()
if now.Before(assertion.Conditions.NotBefore) {
return errors.NewValidationError("SAML assertion not yet valid")
}
if now.After(assertion.Conditions.NotOnOrAfter) {
return errors.NewValidationError("SAML assertion has expired")
}
// Validate audience
expectedAudience := p.config.GetString("SAML_SP_ENTITY_ID")
if assertion.Conditions.AudienceRestriction.Audience.Value != expectedAudience {
return errors.NewValidationError("SAML assertion audience mismatch")
}
// In production, you should also validate the signature
// This requires implementing XML signature validation
return nil
}
// extractUserInfo extracts user information from SAML assertion
func (p *SAMLProvider) extractUserInfo(assertion *Assertion) (*domain.AuthContext, error) {
// Extract user ID from NameID
userID := assertion.Subject.NameID.Value
if userID == "" {
return nil, errors.NewValidationError("SAML assertion missing NameID")
}
// Extract attributes
claims := make(map[string]string)
claims["sub"] = userID
claims["name_id_format"] = assertion.Subject.NameID.Format
// Process attribute statements
for _, attr := range assertion.AttributeStatement.Attribute {
if len(attr.AttributeValue) > 0 {
// Use the first value if multiple values exist
claims[attr.Name] = attr.AttributeValue[0].Value
}
}
// Map common attributes to standard claims
if email, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress"]; exists {
claims["email"] = email
}
if name, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name"]; exists {
claims["name"] = name
}
if givenName, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname"]; exists {
claims["given_name"] = givenName
}
if surname, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname"]; exists {
claims["family_name"] = surname
}
// Extract permissions/roles if available
var permissions []string
if roles, exists := claims["http://schemas.microsoft.com/ws/2008/06/identity/claims/role"]; exists {
permissions = strings.Split(roles, ",")
}
authContext := &domain.AuthContext{
UserID: userID,
TokenType: domain.TokenTypeUser,
Claims: claims,
Permissions: permissions,
}
return authContext, nil
}
// GenerateServiceProviderMetadata generates SP metadata XML
func (p *SAMLProvider) GenerateServiceProviderMetadata() (string, error) {
spEntityID := p.config.GetString("SAML_SP_ENTITY_ID")
acsURL := p.config.GetString("SAML_SP_ACS_URL")
if spEntityID == "" {
return "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured")
}
if acsURL == "" {
return "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured")
}
// This is a simplified SP metadata generation
// In production, you should use a proper SAML library
metadata := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="%s">
<md:SPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="%s" index="0"/>
</md:SPSSODescriptor>
</md:EntityDescriptor>`, spEntityID, acsURL)
return metadata, nil
}
// loadCredentials loads SP private key and certificate
func (p *SAMLProvider) loadCredentials() error {
// Load private key if configured
privateKeyPEM := p.config.GetString("SAML_SP_PRIVATE_KEY")
if privateKeyPEM != "" {
block, _ := pem.Decode([]byte(privateKeyPEM))
if block == nil {
return errors.NewConfigurationError("Failed to decode SAML SP private key")
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
// Try PKCS8 format
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return errors.NewConfigurationError("Failed to parse SAML SP private key").WithInternal(err)
}
var ok bool
privateKey, ok = key.(*rsa.PrivateKey)
if !ok {
return errors.NewConfigurationError("SAML SP private key is not RSA")
}
}
p.privateKey = privateKey
}
// Load certificate if configured
certificatePEM := p.config.GetString("SAML_SP_CERTIFICATE")
if certificatePEM != "" {
block, _ := pem.Decode([]byte(certificatePEM))
if block == nil {
return errors.NewConfigurationError("Failed to decode SAML SP certificate")
}
certificate, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return errors.NewConfigurationError("Failed to parse SAML SP certificate").WithInternal(err)
}
p.certificate = certificate
}
return nil
}

View File

@ -132,6 +132,12 @@ func (c *Config) setDefaults() {
"IP_BLOCK_DURATION": "1h",
"REQUEST_MAX_AGE": "5m",
"IP_WHITELIST": "",
"SAML_ENABLED": "false",
"SAML_IDP_METADATA_URL": "",
"SAML_SP_ENTITY_ID": "",
"SAML_SP_ACS_URL": "",
"SAML_SP_PRIVATE_KEY": "",
"SAML_SP_CERTIFICATE": "",
}
for key, value := range defaults {

View File

@ -188,6 +188,23 @@ type CreateStaticTokenResponse struct {
CreatedAt time.Time `json:"created_at"`
}
// CreateTokenRequest represents a request to create a token
type CreateTokenRequest struct {
AppID string `json:"app_id" validate:"required"`
Type TokenType `json:"type" validate:"required,oneof=static user"`
UserID string `json:"user_id,omitempty"` // Required for user tokens
Permissions []string `json:"permissions,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// CreateTokenResponse represents a response for creating a token
type CreateTokenResponse struct {
Token string `json:"token"`
ExpiresAt time.Time `json:"expires_at"`
TokenType TokenType `json:"token_type"`
}
// AuthContext represents the authentication context for a request
type AuthContext struct {
UserID string `json:"user_id"`
@ -196,3 +213,25 @@ type AuthContext struct {
Claims map[string]string `json:"claims"`
AppID string `json:"app_id"`
}
// TokenResponse represents the OAuth2 token response
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
// UserInfo represents user information from the OAuth2/OIDC provider
type UserInfo struct {
Sub string `json:"sub"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Name string `json:"name"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Picture string `json:"picture"`
PreferredUsername string `json:"preferred_username"`
}

153
internal/domain/session.go Normal file
View File

@ -0,0 +1,153 @@
package domain
import (
"time"
"github.com/google/uuid"
)
// SessionStatus represents the status of a user session
type SessionStatus string
const (
SessionStatusActive SessionStatus = "active"
SessionStatusExpired SessionStatus = "expired"
SessionStatusRevoked SessionStatus = "revoked"
SessionStatusSuspended SessionStatus = "suspended"
)
// SessionType represents the type of session
type SessionType string
const (
SessionTypeWeb SessionType = "web"
SessionTypeMobile SessionType = "mobile"
SessionTypeAPI SessionType = "api"
)
// UserSession represents a user session in the system
type UserSession struct {
ID uuid.UUID `json:"id" db:"id"`
UserID string `json:"user_id" validate:"required" db:"user_id"`
AppID string `json:"app_id" validate:"required" db:"app_id"`
SessionType SessionType `json:"session_type" validate:"required,oneof=web mobile api" db:"session_type"`
Status SessionStatus `json:"status" validate:"required,oneof=active expired revoked suspended" db:"status"`
AccessToken string `json:"-" db:"access_token"` // Hidden from JSON for security
RefreshToken string `json:"-" db:"refresh_token"` // Hidden from JSON for security
IDToken string `json:"-" db:"id_token"` // Hidden from JSON for security
IPAddress string `json:"ip_address" db:"ip_address"`
UserAgent string `json:"user_agent" db:"user_agent"`
LastActivity time.Time `json:"last_activity" db:"last_activity"`
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
RevokedAt *time.Time `json:"revoked_at,omitempty" db:"revoked_at"`
RevokedBy *string `json:"revoked_by,omitempty" db:"revoked_by"`
Metadata SessionMetadata `json:"metadata" db:"metadata"`
}
// SessionMetadata contains additional session information
type SessionMetadata struct {
DeviceInfo string `json:"device_info,omitempty"`
Location string `json:"location,omitempty"`
LoginMethod string `json:"login_method,omitempty"`
TenantID string `json:"tenant_id,omitempty"`
Permissions []string `json:"permissions,omitempty"`
Claims map[string]string `json:"claims,omitempty"`
RefreshCount int `json:"refresh_count"`
LastRefresh *time.Time `json:"last_refresh,omitempty"`
}
// CreateSessionRequest represents a request to create a new session
type CreateSessionRequest struct {
UserID string `json:"user_id" validate:"required"`
AppID string `json:"app_id" validate:"required"`
SessionType SessionType `json:"session_type" validate:"required,oneof=web mobile api"`
IPAddress string `json:"ip_address" validate:"required,ip"`
UserAgent string `json:"user_agent" validate:"required"`
ExpiresAt time.Time `json:"expires_at" validate:"required"`
Permissions []string `json:"permissions,omitempty"`
Claims map[string]string `json:"claims,omitempty"`
TenantID string `json:"tenant_id,omitempty"`
}
// UpdateSessionRequest represents a request to update a session
type UpdateSessionRequest struct {
Status *SessionStatus `json:"status,omitempty" validate:"omitempty,oneof=active expired revoked suspended"`
LastActivity *time.Time `json:"last_activity,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
IPAddress *string `json:"ip_address,omitempty" validate:"omitempty,ip"`
UserAgent *string `json:"user_agent,omitempty"`
}
// SessionListRequest represents a request to list sessions
type SessionListRequest struct {
UserID string `json:"user_id,omitempty"`
AppID string `json:"app_id,omitempty"`
Status *SessionStatus `json:"status,omitempty"`
SessionType *SessionType `json:"session_type,omitempty"`
TenantID string `json:"tenant_id,omitempty"`
Limit int `json:"limit" validate:"min=1,max=100"`
Offset int `json:"offset" validate:"min=0"`
}
// SessionListResponse represents a response for listing sessions
type SessionListResponse struct {
Sessions []*UserSession `json:"sessions"`
Total int `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
// IsActive checks if the session is currently active
func (s *UserSession) IsActive() bool {
return s.Status == SessionStatusActive && time.Now().Before(s.ExpiresAt)
}
// IsExpired checks if the session has expired
func (s *UserSession) IsExpired() bool {
return time.Now().After(s.ExpiresAt) || s.Status == SessionStatusExpired
}
// IsRevoked checks if the session has been revoked
func (s *UserSession) IsRevoked() bool {
return s.Status == SessionStatusRevoked
}
// CanRefresh checks if the session can be refreshed
func (s *UserSession) CanRefresh() bool {
return s.IsActive() && s.RefreshToken != ""
}
// UpdateActivity updates the last activity timestamp
func (s *UserSession) UpdateActivity() {
s.LastActivity = time.Now()
s.UpdatedAt = time.Now()
}
// Revoke marks the session as revoked
func (s *UserSession) Revoke(revokedBy string) {
now := time.Now()
s.Status = SessionStatusRevoked
s.RevokedAt = &now
s.RevokedBy = &revokedBy
s.UpdatedAt = now
}
// Expire marks the session as expired
func (s *UserSession) Expire() {
s.Status = SessionStatusExpired
s.UpdatedAt = time.Now()
}
// Suspend marks the session as suspended
func (s *UserSession) Suspend() {
s.Status = SessionStatusSuspended
s.UpdatedAt = time.Now()
}
// Activate marks the session as active
func (s *UserSession) Activate() {
s.Status = SessionStatusActive
s.UpdatedAt = time.Now()
}

307
internal/domain/tenant.go Normal file
View File

@ -0,0 +1,307 @@
package domain
import (
"time"
"github.com/google/uuid"
)
// TenantStatus represents the status of a tenant
type TenantStatus string
const (
TenantStatusActive TenantStatus = "active"
TenantStatusSuspended TenantStatus = "suspended"
TenantStatusInactive TenantStatus = "inactive"
)
// Tenant represents a tenant in the multi-tenant system
type Tenant struct {
ID uuid.UUID `json:"id" db:"id"`
Name string `json:"name" validate:"required,min=1,max=255" db:"name"`
Slug string `json:"slug" validate:"required,min=1,max=100,alphanum" db:"slug"`
Status TenantStatus `json:"status" validate:"required,oneof=active suspended inactive" db:"status"`
Domain string `json:"domain,omitempty" validate:"omitempty,fqdn" db:"domain"`
Description string `json:"description,omitempty" validate:"max=1000" db:"description"`
Settings TenantSettings `json:"settings" db:"settings"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
CreatedBy string `json:"created_by" db:"created_by"`
UpdatedBy string `json:"updated_by" db:"updated_by"`
}
// TenantSettings contains tenant-specific configuration
type TenantSettings struct {
// Authentication settings
AuthProvider string `json:"auth_provider,omitempty"` // oauth2, saml, header
SAMLSettings *SAMLSettings `json:"saml_settings,omitempty"`
OAuth2Settings *OAuth2Settings `json:"oauth2_settings,omitempty"`
// Session settings
SessionTimeout time.Duration `json:"session_timeout,omitempty"`
MaxConcurrentSessions int `json:"max_concurrent_sessions,omitempty"`
// Security settings
RequireMFA bool `json:"require_mfa"`
AllowedIPRanges []string `json:"allowed_ip_ranges,omitempty"`
PasswordPolicy *PasswordPolicy `json:"password_policy,omitempty"`
// Token settings
DefaultTokenDuration time.Duration `json:"default_token_duration,omitempty"`
MaxTokenDuration time.Duration `json:"max_token_duration,omitempty"`
// Feature flags
Features map[string]bool `json:"features,omitempty"`
// Custom attributes
CustomAttributes map[string]string `json:"custom_attributes,omitempty"`
}
// SAMLSettings contains SAML-specific configuration for a tenant
type SAMLSettings struct {
IDPMetadataURL string `json:"idp_metadata_url,omitempty"`
SPEntityID string `json:"sp_entity_id,omitempty"`
ACSURL string `json:"acs_url,omitempty"`
SPPrivateKey string `json:"sp_private_key,omitempty"`
SPCertificate string `json:"sp_certificate,omitempty"`
AttributeMapping map[string]string `json:"attribute_mapping,omitempty"`
}
// OAuth2Settings contains OAuth2-specific configuration for a tenant
type OAuth2Settings struct {
ProviderURL string `json:"provider_url,omitempty"`
ClientID string `json:"client_id,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
Scopes []string `json:"scopes,omitempty"`
AttributeMapping map[string]string `json:"attribute_mapping,omitempty"`
}
// PasswordPolicy defines password requirements for a tenant
type PasswordPolicy struct {
MinLength int `json:"min_length"`
RequireUppercase bool `json:"require_uppercase"`
RequireLowercase bool `json:"require_lowercase"`
RequireNumbers bool `json:"require_numbers"`
RequireSymbols bool `json:"require_symbols"`
MaxAge time.Duration `json:"max_age,omitempty"`
PreventReuse int `json:"prevent_reuse"` // Number of previous passwords to prevent reuse
}
// TenantUser represents a user within a specific tenant
type TenantUser struct {
ID uuid.UUID `json:"id" db:"id"`
TenantID uuid.UUID `json:"tenant_id" validate:"required" db:"tenant_id"`
UserID string `json:"user_id" validate:"required" db:"user_id"`
Email string `json:"email" validate:"required,email" db:"email"`
Name string `json:"name" validate:"required" db:"name"`
Roles []string `json:"roles" db:"roles"`
Permissions []string `json:"permissions" db:"permissions"`
Status UserStatus `json:"status" validate:"required,oneof=active inactive suspended" db:"status"`
Metadata map[string]string `json:"metadata,omitempty" db:"metadata"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
LastLoginAt *time.Time `json:"last_login_at,omitempty" db:"last_login_at"`
}
// UserStatus represents the status of a user within a tenant
type UserStatus string
const (
UserStatusActive UserStatus = "active"
UserStatusInactive UserStatus = "inactive"
UserStatusSuspended UserStatus = "suspended"
)
// TenantRole represents a role within a tenant
type TenantRole struct {
ID uuid.UUID `json:"id" db:"id"`
TenantID uuid.UUID `json:"tenant_id" validate:"required" db:"tenant_id"`
Name string `json:"name" validate:"required,min=1,max=100" db:"name"`
Description string `json:"description,omitempty" validate:"max=500" db:"description"`
Permissions []string `json:"permissions" db:"permissions"`
IsSystem bool `json:"is_system" db:"is_system"` // System roles cannot be deleted
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
CreatedBy string `json:"created_by" db:"created_by"`
UpdatedBy string `json:"updated_by" db:"updated_by"`
}
// CreateTenantRequest represents a request to create a new tenant
type CreateTenantRequest struct {
Name string `json:"name" validate:"required,min=1,max=255"`
Slug string `json:"slug" validate:"required,min=1,max=100,alphanum"`
Domain string `json:"domain,omitempty" validate:"omitempty,fqdn"`
Description string `json:"description,omitempty" validate:"max=1000"`
Settings TenantSettings `json:"settings,omitempty"`
}
// UpdateTenantRequest represents a request to update a tenant
type UpdateTenantRequest struct {
Name *string `json:"name,omitempty" validate:"omitempty,min=1,max=255"`
Status *TenantStatus `json:"status,omitempty" validate:"omitempty,oneof=active suspended inactive"`
Domain *string `json:"domain,omitempty" validate:"omitempty,fqdn"`
Description *string `json:"description,omitempty" validate:"omitempty,max=1000"`
Settings *TenantSettings `json:"settings,omitempty"`
}
// CreateTenantUserRequest represents a request to create a user in a tenant
type CreateTenantUserRequest struct {
TenantID uuid.UUID `json:"tenant_id" validate:"required"`
UserID string `json:"user_id" validate:"required"`
Email string `json:"email" validate:"required,email"`
Name string `json:"name" validate:"required"`
Roles []string `json:"roles,omitempty"`
Permissions []string `json:"permissions,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// UpdateTenantUserRequest represents a request to update a tenant user
type UpdateTenantUserRequest struct {
Email *string `json:"email,omitempty" validate:"omitempty,email"`
Name *string `json:"name,omitempty" validate:"omitempty,min=1"`
Roles []string `json:"roles,omitempty"`
Permissions []string `json:"permissions,omitempty"`
Status *UserStatus `json:"status,omitempty" validate:"omitempty,oneof=active inactive suspended"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// CreateTenantRoleRequest represents a request to create a role in a tenant
type CreateTenantRoleRequest struct {
TenantID uuid.UUID `json:"tenant_id" validate:"required"`
Name string `json:"name" validate:"required,min=1,max=100"`
Description string `json:"description,omitempty" validate:"max=500"`
Permissions []string `json:"permissions,omitempty"`
}
// UpdateTenantRoleRequest represents a request to update a tenant role
type UpdateTenantRoleRequest struct {
Name *string `json:"name,omitempty" validate:"omitempty,min=1,max=100"`
Description *string `json:"description,omitempty" validate:"omitempty,max=500"`
Permissions []string `json:"permissions,omitempty"`
}
// TenantListRequest represents a request to list tenants
type TenantListRequest struct {
Status *TenantStatus `json:"status,omitempty"`
Domain string `json:"domain,omitempty"`
Limit int `json:"limit" validate:"min=1,max=100"`
Offset int `json:"offset" validate:"min=0"`
}
// TenantListResponse represents a response for listing tenants
type TenantListResponse struct {
Tenants []*Tenant `json:"tenants"`
Total int `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
// IsActive checks if the tenant is active
func (t *Tenant) IsActive() bool {
return t.Status == TenantStatusActive
}
// IsSuspended checks if the tenant is suspended
func (t *Tenant) IsSuspended() bool {
return t.Status == TenantStatusSuspended
}
// HasFeature checks if a feature is enabled for the tenant
func (t *Tenant) HasFeature(feature string) bool {
if t.Settings.Features == nil {
return false
}
enabled, exists := t.Settings.Features[feature]
return exists && enabled
}
// GetAuthProvider returns the authentication provider for the tenant
func (t *Tenant) GetAuthProvider() string {
if t.Settings.AuthProvider != "" {
return t.Settings.AuthProvider
}
return "header" // default
}
// GetSessionTimeout returns the session timeout for the tenant
func (t *Tenant) GetSessionTimeout() time.Duration {
if t.Settings.SessionTimeout > 0 {
return t.Settings.SessionTimeout
}
return 8 * time.Hour // default
}
// GetMaxConcurrentSessions returns the maximum concurrent sessions for the tenant
func (t *Tenant) GetMaxConcurrentSessions() int {
if t.Settings.MaxConcurrentSessions > 0 {
return t.Settings.MaxConcurrentSessions
}
return 10 // default
}
// IsActive checks if the tenant user is active
func (tu *TenantUser) IsActive() bool {
return tu.Status == UserStatusActive
}
// IsSuspended checks if the tenant user is suspended
func (tu *TenantUser) IsSuspended() bool {
return tu.Status == UserStatusSuspended
}
// HasRole checks if the user has a specific role
func (tu *TenantUser) HasRole(role string) bool {
for _, r := range tu.Roles {
if r == role {
return true
}
}
return false
}
// HasPermission checks if the user has a specific permission
func (tu *TenantUser) HasPermission(permission string) bool {
for _, p := range tu.Permissions {
if p == permission {
return true
}
}
return false
}
// UpdateLastLogin updates the last login timestamp
func (tu *TenantUser) UpdateLastLogin() {
now := time.Now()
tu.LastLoginAt = &now
tu.UpdatedAt = now
}
// IsSystemRole checks if the role is a system role
func (tr *TenantRole) IsSystemRole() bool {
return tr.IsSystem
}
// HasPermission checks if the role has a specific permission
func (tr *TenantRole) HasPermission(permission string) bool {
for _, p := range tr.Permissions {
if p == permission {
return true
}
}
return false
}
// TenantContext represents the tenant context for a request
type TenantContext struct {
TenantID uuid.UUID `json:"tenant_id"`
TenantSlug string `json:"tenant_slug"`
UserID string `json:"user_id"`
Roles []string `json:"roles"`
Permissions []string `json:"permissions"`
}
// MultiTenantAuthContext extends AuthContext with tenant information
type MultiTenantAuthContext struct {
*AuthContext
TenantContext *TenantContext `json:"tenant_context,omitempty"`
}

View File

@ -295,3 +295,66 @@ func (c *Chain) Error() string {
}
return fmt.Sprintf("multiple errors: %s (and %d more)", c.errors[0].Error(), len(c.errors)-1)
}
// Helper functions to check error types
// IsNotFound checks if an error is a not found error
func IsNotFound(err error) bool {
if appErr, ok := err.(*AppError); ok {
return appErr.Code == ErrNotFound || appErr.Code == ErrApplicationNotFound ||
appErr.Code == ErrTokenNotFound || appErr.Code == ErrPermissionNotFound
}
return false
}
// IsValidationError checks if an error is a validation error
func IsValidationError(err error) bool {
if appErr, ok := err.(*AppError); ok {
return appErr.Code == ErrValidationFailed || appErr.Code == ErrInvalidInput ||
appErr.Code == ErrMissingField || appErr.Code == ErrInvalidFormat
}
return false
}
// IsAuthenticationError checks if an error is an authentication error
func IsAuthenticationError(err error) bool {
if appErr, ok := err.(*AppError); ok {
return appErr.Code == ErrUnauthorized || appErr.Code == ErrInvalidToken ||
appErr.Code == ErrTokenExpired || appErr.Code == ErrInvalidCredentials
}
return false
}
// IsAuthorizationError checks if an error is an authorization error
func IsAuthorizationError(err error) bool {
if appErr, ok := err.(*AppError); ok {
return appErr.Code == ErrForbidden || appErr.Code == ErrInsufficientPermissions
}
return false
}
// IsConflictError checks if an error is a conflict error
func IsConflictError(err error) bool {
if appErr, ok := err.(*AppError); ok {
return appErr.Code == ErrAlreadyExists || appErr.Code == ErrConflict
}
return false
}
// IsInternalError checks if an error is an internal server error
func IsInternalError(err error) bool {
if appErr, ok := err.(*AppError); ok {
return appErr.Code == ErrInternal || appErr.Code == ErrDatabase ||
appErr.Code == ErrExternal || appErr.Code == ErrTokenCreationFailed
}
return false
}
// IsConfigurationError checks if an error is a configuration error
func IsConfigurationError(err error) bool {
if appErr, ok := err.(*AppError); ok {
// Configuration errors are typically mapped to internal errors
return appErr.Code == ErrInternal && appErr.Message != ""
}
return false
}

352
internal/handlers/saml.go Normal file
View File

@ -0,0 +1,352 @@
package handlers
import (
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/auth"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
"github.com/kms/api-key-service/internal/services"
)
// SAMLHandler handles SAML authentication endpoints
type SAMLHandler struct {
samlProvider *auth.SAMLProvider
sessionService services.SessionService
authService services.AuthenticationService
tokenService services.TokenService
config config.ConfigProvider
logger *zap.Logger
}
// NewSAMLHandler creates a new SAML handler
func NewSAMLHandler(
config config.ConfigProvider,
sessionService services.SessionService,
authService services.AuthenticationService,
tokenService services.TokenService,
logger *zap.Logger,
) (*SAMLHandler, error) {
samlProvider, err := auth.NewSAMLProvider(config, logger)
if err != nil {
return nil, err
}
return &SAMLHandler{
samlProvider: samlProvider,
sessionService: sessionService,
authService: authService,
config: config,
logger: logger,
}, nil
}
// RegisterRoutes registers SAML routes
func (h *SAMLHandler) RegisterRoutes(router *mux.Router) {
// SAML endpoints
router.HandleFunc("/auth/saml/login", h.InitiateSAMLLogin).Methods("GET")
router.HandleFunc("/auth/saml/acs", h.HandleSAMLResponse).Methods("POST")
router.HandleFunc("/auth/saml/metadata", h.GetServiceProviderMetadata).Methods("GET")
router.HandleFunc("/auth/saml/slo", h.HandleSingleLogout).Methods("GET", "POST")
}
// InitiateSAMLLogin initiates SAML authentication
func (h *SAMLHandler) InitiateSAMLLogin(w http.ResponseWriter, r *http.Request) {
if !h.config.GetBool("SAML_ENABLED") {
h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
return
}
// Get query parameters
appID := r.URL.Query().Get("app_id")
redirectURL := r.URL.Query().Get("redirect_url")
if appID == "" {
h.writeErrorResponse(w, errors.NewValidationError("app_id parameter is required"))
return
}
// Generate relay state with app_id and redirect_url
relayState := appID
if redirectURL != "" {
relayState += "|" + redirectURL
}
h.logger.Debug("Initiating SAML login",
zap.String("app_id", appID),
zap.String("redirect_url", redirectURL))
// Generate SAML authentication request
authURL, requestID, err := h.samlProvider.GenerateAuthRequest(r.Context(), relayState)
if err != nil {
h.logger.Error("Failed to generate SAML auth request", zap.Error(err))
h.writeErrorResponse(w, err)
return
}
// Store request ID in session/cache for validation
// In production, you should store this securely
h.logger.Debug("Generated SAML auth request",
zap.String("request_id", requestID),
zap.String("auth_url", authURL))
// Redirect to IdP
http.Redirect(w, r, authURL, http.StatusFound)
}
// HandleSAMLResponse handles SAML assertion consumer service (ACS)
func (h *SAMLHandler) HandleSAMLResponse(w http.ResponseWriter, r *http.Request) {
if !h.config.GetBool("SAML_ENABLED") {
h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
return
}
h.logger.Debug("Handling SAML response")
// Parse form data
if err := r.ParseForm(); err != nil {
h.writeErrorResponse(w, errors.NewValidationError("Failed to parse form data").WithInternal(err))
return
}
samlResponse := r.FormValue("SAMLResponse")
relayState := r.FormValue("RelayState")
if samlResponse == "" {
h.writeErrorResponse(w, errors.NewValidationError("SAMLResponse is required"))
return
}
h.logger.Debug("Processing SAML response", zap.String("relay_state", relayState))
// Process SAML response
// In production, you should retrieve and validate the original request ID
authContext, err := h.samlProvider.ProcessSAMLResponse(r.Context(), samlResponse, "")
if err != nil {
h.logger.Error("Failed to process SAML response", zap.Error(err))
h.writeErrorResponse(w, err)
return
}
// Parse relay state to get app_id and redirect_url
appID, redirectURL := h.parseRelayState(relayState)
if appID == "" {
h.writeErrorResponse(w, errors.NewValidationError("Invalid relay state: missing app_id"))
return
}
// Create user session
sessionReq := &domain.CreateSessionRequest{
UserID: authContext.UserID,
AppID: appID,
SessionType: domain.SessionTypeWeb,
IPAddress: h.getClientIP(r),
UserAgent: r.UserAgent(),
ExpiresAt: time.Now().Add(8 * time.Hour), // 8 hour session
Permissions: authContext.Permissions,
Claims: authContext.Claims,
}
session, err := h.sessionService.CreateSession(r.Context(), sessionReq)
if err != nil {
h.logger.Error("Failed to create session", zap.Error(err))
h.writeErrorResponse(w, err)
return
}
// Generate JWT token for the session using the existing token service
userToken := &domain.UserToken{
AppID: appID,
UserID: authContext.UserID,
Permissions: authContext.Permissions,
IssuedAt: time.Now(),
ExpiresAt: session.ExpiresAt,
MaxValidAt: session.ExpiresAt,
TokenType: domain.TokenTypeUser,
Claims: authContext.Claims,
}
tokenString, err := h.authService.GenerateJWTToken(r.Context(), userToken)
if err != nil {
h.logger.Error("Failed to create JWT token", zap.Error(err))
h.writeErrorResponse(w, err)
return
}
h.logger.Debug("SAML authentication successful",
zap.String("user_id", authContext.UserID),
zap.String("session_id", session.ID.String()))
// If redirect URL is provided, redirect with token
if redirectURL != "" {
// Add token as query parameter or fragment
redirectURL += "?token=" + tokenString
http.Redirect(w, r, redirectURL, http.StatusFound)
return
}
// Otherwise, return JSON response
response := map[string]interface{}{
"success": true,
"token": tokenString,
"user": map[string]interface{}{
"id": authContext.UserID,
"email": authContext.Claims["email"],
"name": authContext.Claims["name"],
},
"session_id": session.ID.String(),
"expires_at": session.ExpiresAt,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// GetServiceProviderMetadata returns SP metadata XML
func (h *SAMLHandler) GetServiceProviderMetadata(w http.ResponseWriter, r *http.Request) {
if !h.config.GetBool("SAML_ENABLED") {
h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
return
}
h.logger.Debug("Generating SP metadata")
metadata, err := h.samlProvider.GenerateServiceProviderMetadata()
if err != nil {
h.logger.Error("Failed to generate SP metadata", zap.Error(err))
h.writeErrorResponse(w, err)
return
}
w.Header().Set("Content-Type", "application/xml")
w.Write([]byte(metadata))
}
// HandleSingleLogout handles SAML single logout
func (h *SAMLHandler) HandleSingleLogout(w http.ResponseWriter, r *http.Request) {
if !h.config.GetBool("SAML_ENABLED") {
h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
return
}
h.logger.Debug("Handling SAML single logout")
// Get session ID from query parameter or form
sessionID := r.URL.Query().Get("session_id")
if sessionID == "" && r.Method == "POST" {
r.ParseForm()
sessionID = r.FormValue("session_id")
}
if sessionID != "" {
// Revoke specific session
h.logger.Debug("Revoking session", zap.String("session_id", sessionID))
// Implementation would depend on how you store session IDs
// For now, we'll just log it
}
// In a full implementation, you would:
// 1. Parse the SAML LogoutRequest
// 2. Validate the request
// 3. Revoke the user's sessions
// 4. Generate a LogoutResponse
// 5. Redirect back to the IdP
// For now, return a simple success response
response := map[string]interface{}{
"success": true,
"message": "Logout successful",
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// parseRelayState parses the relay state to extract app_id and redirect_url
func (h *SAMLHandler) parseRelayState(relayState string) (appID, redirectURL string) {
if relayState == "" {
return "", ""
}
// RelayState format: "app_id|redirect_url" or just "app_id"
parts := []string{relayState}
if len(relayState) > 0 && relayState[0] != '|' {
// Split on first pipe character
for i, char := range relayState {
if char == '|' {
parts = []string{relayState[:i], relayState[i+1:]}
break
}
}
}
appID = parts[0]
if len(parts) > 1 {
redirectURL = parts[1]
}
return appID, redirectURL
}
// getClientIP extracts the client IP address from the request
func (h *SAMLHandler) getClientIP(r *http.Request) string {
// Check X-Forwarded-For header first
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP if multiple are present
if idx := len(xff); idx > 0 {
for i, char := range xff {
if char == ',' {
return xff[:i]
}
}
return xff
}
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
}
// Fall back to RemoteAddr
return r.RemoteAddr
}
// writeErrorResponse writes an error response
func (h *SAMLHandler) writeErrorResponse(w http.ResponseWriter, err error) {
var statusCode int
var errorCode string
switch {
case errors.IsValidationError(err):
statusCode = http.StatusBadRequest
errorCode = "VALIDATION_ERROR"
case errors.IsAuthenticationError(err):
statusCode = http.StatusUnauthorized
errorCode = "AUTHENTICATION_ERROR"
case errors.IsConfigurationError(err):
statusCode = http.StatusServiceUnavailable
errorCode = "CONFIGURATION_ERROR"
default:
statusCode = http.StatusInternalServerError
errorCode = "INTERNAL_ERROR"
}
response := map[string]interface{}{
"success": false,
"error": map[string]interface{}{
"code": errorCode,
"message": err.Error(),
},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(response)
}

View File

@ -2,7 +2,6 @@ package middleware
import (
"context"
"fmt"
"net"
"net/http"
"strings"
@ -14,7 +13,6 @@ import (
"github.com/kms/api-key-service/internal/cache"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/errors"
)
// SecurityMiddleware provides various security features
@ -408,8 +406,6 @@ func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool {
// GetSecurityMetrics returns security-related metrics
func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} {
ctx := context.Background()
// This is a simplified version - in production you'd want more comprehensive metrics
metrics := map[string]interface{}{
"active_rate_limiters": len(s.rateLimiters),

View File

@ -104,6 +104,57 @@ type GrantedPermissionRepository interface {
HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error)
}
// SessionRepository defines the interface for user session data operations
type SessionRepository interface {
// Create creates a new user session
Create(ctx context.Context, session *domain.UserSession) error
// GetByID retrieves a session by its ID
GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error)
// GetByUserID retrieves all sessions for a user
GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error)
// GetByUserAndApp retrieves sessions for a specific user and application
GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error)
// GetActiveByUserID retrieves all active sessions for a user
GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error)
// List retrieves sessions with filtering and pagination
List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error)
// Update updates an existing session
Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error
// UpdateActivity updates the last activity timestamp for a session
UpdateActivity(ctx context.Context, sessionID uuid.UUID) error
// Revoke revokes a session
Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error
// RevokeAllByUser revokes all sessions for a user
RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error
// RevokeAllByUserAndApp revokes all sessions for a user and application
RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error
// ExpireOldSessions marks expired sessions as expired
ExpireOldSessions(ctx context.Context) (int, error)
// DeleteExpiredSessions removes expired sessions older than the specified duration
DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error)
// Exists checks if a session exists
Exists(ctx context.Context, sessionID uuid.UUID) (bool, error)
// GetSessionCount returns the total number of sessions for a user
GetSessionCount(ctx context.Context, userID string) (int, error)
// GetActiveSessionCount returns the number of active sessions for a user
GetActiveSessionCount(ctx context.Context, userID string) (int, error)
}
// DatabaseProvider defines the interface for database operations
type DatabaseProvider interface {
// GetDB returns the underlying database connection

View File

@ -0,0 +1,624 @@
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
"github.com/kms/api-key-service/internal/repository"
)
// sessionRepository implements the SessionRepository interface
type sessionRepository struct {
db *sqlx.DB
logger *zap.Logger
}
// NewSessionRepository creates a new session repository
func NewSessionRepository(db *sqlx.DB, logger *zap.Logger) repository.SessionRepository {
return &sessionRepository{
db: db,
logger: logger,
}
}
// Create creates a new user session
func (r *sessionRepository) Create(ctx context.Context, session *domain.UserSession) error {
r.logger.Debug("Creating new session",
zap.String("user_id", session.UserID),
zap.String("app_id", session.AppID),
zap.String("session_type", string(session.SessionType)))
// Generate ID if not provided
if session.ID == uuid.Nil {
session.ID = uuid.New()
}
// Set timestamps
now := time.Now()
session.CreatedAt = now
session.UpdatedAt = now
session.LastActivity = now
// Serialize metadata
metadataJSON, err := json.Marshal(session.Metadata)
if err != nil {
return errors.NewInternalError("Failed to serialize session metadata").WithInternal(err)
}
query := `
INSERT INTO user_sessions (
id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at, metadata
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
)`
_, err = r.db.ExecContext(ctx, query,
session.ID,
session.UserID,
session.AppID,
session.SessionType,
session.Status,
session.AccessToken,
session.RefreshToken,
session.IDToken,
session.IPAddress,
session.UserAgent,
session.LastActivity,
session.ExpiresAt,
session.CreatedAt,
session.UpdatedAt,
metadataJSON,
)
if err != nil {
r.logger.Error("Failed to create session", zap.Error(err))
return errors.NewInternalError("Failed to create session").WithInternal(err)
}
r.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
return nil
}
// GetByID retrieves a session by its ID
func (r *sessionRepository) GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
r.logger.Debug("Getting session by ID", zap.String("session_id", sessionID.String()))
query := `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE id = $1`
var session domain.UserSession
var metadataJSON []byte
var revokedAt sql.NullTime
var revokedBy sql.NullString
err := r.db.QueryRowContext(ctx, query, sessionID).Scan(
&session.ID,
&session.UserID,
&session.AppID,
&session.SessionType,
&session.Status,
&session.AccessToken,
&session.RefreshToken,
&session.IDToken,
&session.IPAddress,
&session.UserAgent,
&session.LastActivity,
&session.ExpiresAt,
&session.CreatedAt,
&session.UpdatedAt,
&revokedAt,
&revokedBy,
&metadataJSON,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, errors.NewNotFoundError("Session not found")
}
r.logger.Error("Failed to get session by ID", zap.Error(err))
return nil, errors.NewInternalError("Failed to retrieve session").WithInternal(err)
}
// Handle nullable fields
if revokedAt.Valid {
session.RevokedAt = &revokedAt.Time
}
if revokedBy.Valid {
session.RevokedBy = &revokedBy.String
}
// Deserialize metadata
if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
}
r.logger.Debug("Session retrieved successfully", zap.String("session_id", sessionID.String()))
return &session, nil
}
// GetByUserID retrieves all sessions for a user
func (r *sessionRepository) GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
r.logger.Debug("Getting sessions by user ID", zap.String("user_id", userID))
query := `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE user_id = $1
ORDER BY created_at DESC`
return r.scanSessions(ctx, query, userID)
}
// GetByUserAndApp retrieves sessions for a specific user and application
func (r *sessionRepository) GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
r.logger.Debug("Getting sessions by user and app",
zap.String("user_id", userID),
zap.String("app_id", appID))
query := `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE user_id = $1 AND app_id = $2
ORDER BY created_at DESC`
return r.scanSessions(ctx, query, userID, appID)
}
// GetActiveByUserID retrieves all active sessions for a user
func (r *sessionRepository) GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
r.logger.Debug("Getting active sessions by user ID", zap.String("user_id", userID))
query := `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE user_id = $1 AND status = $2 AND expires_at > NOW()
ORDER BY last_activity DESC`
return r.scanSessions(ctx, query, userID, domain.SessionStatusActive)
}
// List retrieves sessions with filtering and pagination
func (r *sessionRepository) List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
r.logger.Debug("Listing sessions with filters",
zap.String("user_id", req.UserID),
zap.String("app_id", req.AppID),
zap.Int("limit", req.Limit),
zap.Int("offset", req.Offset))
// Build WHERE clause dynamically
whereClause := "WHERE 1=1"
args := []interface{}{}
argIndex := 1
if req.UserID != "" {
whereClause += fmt.Sprintf(" AND user_id = $%d", argIndex)
args = append(args, req.UserID)
argIndex++
}
if req.AppID != "" {
whereClause += fmt.Sprintf(" AND app_id = $%d", argIndex)
args = append(args, req.AppID)
argIndex++
}
if req.Status != nil {
whereClause += fmt.Sprintf(" AND status = $%d", argIndex)
args = append(args, *req.Status)
argIndex++
}
if req.SessionType != nil {
whereClause += fmt.Sprintf(" AND session_type = $%d", argIndex)
args = append(args, *req.SessionType)
argIndex++
}
if req.TenantID != "" {
whereClause += fmt.Sprintf(" AND metadata->>'tenant_id' = $%d", argIndex)
args = append(args, req.TenantID)
argIndex++
}
// Get total count
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM user_sessions %s", whereClause)
var total int
err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
if err != nil {
r.logger.Error("Failed to get session count", zap.Error(err))
return nil, errors.NewInternalError("Failed to count sessions").WithInternal(err)
}
// Get sessions with pagination
query := fmt.Sprintf(`
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
%s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d`, whereClause, argIndex, argIndex+1)
args = append(args, req.Limit, req.Offset)
sessions, err := r.scanSessions(ctx, query, args...)
if err != nil {
return nil, err
}
return &domain.SessionListResponse{
Sessions: sessions,
Total: total,
Limit: req.Limit,
Offset: req.Offset,
}, nil
}
// Update updates an existing session
func (r *sessionRepository) Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
r.logger.Debug("Updating session", zap.String("session_id", sessionID.String()))
// Build UPDATE clause dynamically
setParts := []string{"updated_at = NOW()"}
args := []interface{}{}
argIndex := 1
if updates.Status != nil {
setParts = append(setParts, fmt.Sprintf("status = $%d", argIndex))
args = append(args, *updates.Status)
argIndex++
}
if updates.LastActivity != nil {
setParts = append(setParts, fmt.Sprintf("last_activity = $%d", argIndex))
args = append(args, *updates.LastActivity)
argIndex++
}
if updates.ExpiresAt != nil {
setParts = append(setParts, fmt.Sprintf("expires_at = $%d", argIndex))
args = append(args, *updates.ExpiresAt)
argIndex++
}
if updates.IPAddress != nil {
setParts = append(setParts, fmt.Sprintf("ip_address = $%d", argIndex))
args = append(args, *updates.IPAddress)
argIndex++
}
if updates.UserAgent != nil {
setParts = append(setParts, fmt.Sprintf("user_agent = $%d", argIndex))
args = append(args, *updates.UserAgent)
argIndex++
}
if len(setParts) == 1 {
return errors.NewValidationError("No fields to update")
}
// Build the complete query
setClause := fmt.Sprintf("%s", setParts[0])
for i := 1; i < len(setParts); i++ {
setClause += fmt.Sprintf(", %s", setParts[i])
}
query := fmt.Sprintf("UPDATE user_sessions SET %s WHERE id = $%d", setClause, argIndex)
args = append(args, sessionID)
result, err := r.db.ExecContext(ctx, query, args...)
if err != nil {
r.logger.Error("Failed to update session", zap.Error(err))
return errors.NewInternalError("Failed to update session").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
if rowsAffected == 0 {
return errors.NewNotFoundError("Session not found")
}
r.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String()))
return nil
}
// UpdateActivity updates the last activity timestamp for a session
func (r *sessionRepository) UpdateActivity(ctx context.Context, sessionID uuid.UUID) error {
r.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))
query := `UPDATE user_sessions SET last_activity = NOW(), updated_at = NOW() WHERE id = $1`
result, err := r.db.ExecContext(ctx, query, sessionID)
if err != nil {
r.logger.Error("Failed to update session activity", zap.Error(err))
return errors.NewInternalError("Failed to update session activity").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
if rowsAffected == 0 {
return errors.NewNotFoundError("Session not found")
}
return nil
}
// Revoke revokes a session
func (r *sessionRepository) Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
r.logger.Debug("Revoking session",
zap.String("session_id", sessionID.String()),
zap.String("revoked_by", revokedBy))
query := `
UPDATE user_sessions
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
WHERE id = $3`
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, sessionID)
if err != nil {
r.logger.Error("Failed to revoke session", zap.Error(err))
return errors.NewInternalError("Failed to revoke session").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
if rowsAffected == 0 {
return errors.NewNotFoundError("Session not found")
}
r.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String()))
return nil
}
// RevokeAllByUser revokes all sessions for a user
func (r *sessionRepository) RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error {
r.logger.Debug("Revoking all sessions for user",
zap.String("user_id", userID),
zap.String("revoked_by", revokedBy))
query := `
UPDATE user_sessions
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
WHERE user_id = $3 AND status = $4`
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, domain.SessionStatusActive)
if err != nil {
r.logger.Error("Failed to revoke user sessions", zap.Error(err))
return errors.NewInternalError("Failed to revoke user sessions").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
r.logger.Debug("User sessions revoked",
zap.String("user_id", userID),
zap.Int64("sessions_revoked", rowsAffected))
return nil
}
// RevokeAllByUserAndApp revokes all sessions for a user and application
func (r *sessionRepository) RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error {
r.logger.Debug("Revoking all sessions for user and app",
zap.String("user_id", userID),
zap.String("app_id", appID),
zap.String("revoked_by", revokedBy))
query := `
UPDATE user_sessions
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
WHERE user_id = $3 AND app_id = $4 AND status = $5`
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, appID, domain.SessionStatusActive)
if err != nil {
r.logger.Error("Failed to revoke user app sessions", zap.Error(err))
return errors.NewInternalError("Failed to revoke user app sessions").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
r.logger.Debug("User app sessions revoked",
zap.String("user_id", userID),
zap.String("app_id", appID),
zap.Int64("sessions_revoked", rowsAffected))
return nil
}
// ExpireOldSessions marks expired sessions as expired
func (r *sessionRepository) ExpireOldSessions(ctx context.Context) (int, error) {
r.logger.Debug("Expiring old sessions")
query := `
UPDATE user_sessions
SET status = $1, updated_at = NOW()
WHERE expires_at < NOW() AND status = $2`
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, domain.SessionStatusActive)
if err != nil {
r.logger.Error("Failed to expire old sessions", zap.Error(err))
return 0, errors.NewInternalError("Failed to expire old sessions").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
r.logger.Debug("Old sessions expired", zap.Int64("sessions_expired", rowsAffected))
return int(rowsAffected), nil
}
// DeleteExpiredSessions removes expired sessions older than the specified duration
func (r *sessionRepository) DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error) {
r.logger.Debug("Deleting expired sessions", zap.Duration("older_than", olderThan))
cutoffTime := time.Now().Add(-olderThan)
query := `DELETE FROM user_sessions WHERE status = $1 AND updated_at < $2`
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, cutoffTime)
if err != nil {
r.logger.Error("Failed to delete expired sessions", zap.Error(err))
return 0, errors.NewInternalError("Failed to delete expired sessions").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
r.logger.Debug("Expired sessions deleted", zap.Int64("sessions_deleted", rowsAffected))
return int(rowsAffected), nil
}
// Exists checks if a session exists
func (r *sessionRepository) Exists(ctx context.Context, sessionID uuid.UUID) (bool, error) {
r.logger.Debug("Checking if session exists", zap.String("session_id", sessionID.String()))
query := `SELECT EXISTS(SELECT 1 FROM user_sessions WHERE id = $1)`
var exists bool
err := r.db.QueryRowContext(ctx, query, sessionID).Scan(&exists)
if err != nil {
r.logger.Error("Failed to check session existence", zap.Error(err))
return false, errors.NewInternalError("Failed to check session existence").WithInternal(err)
}
return exists, nil
}
// GetSessionCount returns the total number of sessions for a user
func (r *sessionRepository) GetSessionCount(ctx context.Context, userID string) (int, error) {
r.logger.Debug("Getting session count for user", zap.String("user_id", userID))
query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1`
var count int
err := r.db.QueryRowContext(ctx, query, userID).Scan(&count)
if err != nil {
r.logger.Error("Failed to get session count", zap.Error(err))
return 0, errors.NewInternalError("Failed to get session count").WithInternal(err)
}
return count, nil
}
// GetActiveSessionCount returns the number of active sessions for a user
func (r *sessionRepository) GetActiveSessionCount(ctx context.Context, userID string) (int, error) {
r.logger.Debug("Getting active session count for user", zap.String("user_id", userID))
query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1 AND status = $2 AND expires_at > NOW()`
var count int
err := r.db.QueryRowContext(ctx, query, userID, domain.SessionStatusActive).Scan(&count)
if err != nil {
r.logger.Error("Failed to get active session count", zap.Error(err))
return 0, errors.NewInternalError("Failed to get active session count").WithInternal(err)
}
return count, nil
}
// scanSessions is a helper method to scan multiple sessions from query results
func (r *sessionRepository) scanSessions(ctx context.Context, query string, args ...interface{}) ([]*domain.UserSession, error) {
rows, err := r.db.QueryContext(ctx, query, args...)
if err != nil {
r.logger.Error("Failed to execute session query", zap.Error(err))
return nil, errors.NewInternalError("Failed to retrieve sessions").WithInternal(err)
}
defer rows.Close()
var sessions []*domain.UserSession
for rows.Next() {
var session domain.UserSession
var metadataJSON []byte
var revokedAt sql.NullTime
var revokedBy sql.NullString
err := rows.Scan(
&session.ID,
&session.UserID,
&session.AppID,
&session.SessionType,
&session.Status,
&session.AccessToken,
&session.RefreshToken,
&session.IDToken,
&session.IPAddress,
&session.UserAgent,
&session.LastActivity,
&session.ExpiresAt,
&session.CreatedAt,
&session.UpdatedAt,
&revokedAt,
&revokedBy,
&metadataJSON,
)
if err != nil {
r.logger.Error("Failed to scan session row", zap.Error(err))
return nil, errors.NewInternalError("Failed to scan session data").WithInternal(err)
}
// Handle nullable fields
if revokedAt.Valid {
session.RevokedAt = &revokedAt.Time
}
if revokedBy.Valid {
session.RevokedBy = &revokedBy.String
}
// Deserialize metadata
if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
}
sessions = append(sessions, &session)
}
if err := rows.Err(); err != nil {
r.logger.Error("Error iterating session rows", zap.Error(err))
return nil, errors.NewInternalError("Failed to iterate session results").WithInternal(err)
}
return sessions, nil
}

View File

@ -67,3 +67,54 @@ type AuthenticationService interface {
// RefreshJWTToken refreshes an existing JWT token
RefreshJWTToken(ctx context.Context, tokenString string, newExpiration time.Time) (string, error)
}
// SessionService defines the interface for session management business logic
type SessionService interface {
// CreateSession creates a new user session
CreateSession(ctx context.Context, req *domain.CreateSessionRequest) (*domain.UserSession, error)
// GetSession retrieves a session by its ID
GetSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error)
// GetUserSessions retrieves all sessions for a user
GetUserSessions(ctx context.Context, userID string) ([]*domain.UserSession, error)
// GetUserAppSessions retrieves sessions for a specific user and application
GetUserAppSessions(ctx context.Context, userID, appID string) ([]*domain.UserSession, error)
// GetActiveSessions retrieves all active sessions for a user
GetActiveSessions(ctx context.Context, userID string) ([]*domain.UserSession, error)
// ListSessions retrieves sessions with filtering and pagination
ListSessions(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error)
// UpdateSession updates an existing session
UpdateSession(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error
// UpdateSessionActivity updates the last activity timestamp for a session
UpdateSessionActivity(ctx context.Context, sessionID uuid.UUID) error
// RevokeSession revokes a specific session
RevokeSession(ctx context.Context, sessionID uuid.UUID, revokedBy string) error
// RevokeUserSessions revokes all sessions for a user
RevokeUserSessions(ctx context.Context, userID string, revokedBy string) error
// RevokeUserAppSessions revokes all sessions for a user and application
RevokeUserAppSessions(ctx context.Context, userID, appID string, revokedBy string) error
// ValidateSession validates if a session is active and valid
ValidateSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error)
// RefreshSession refreshes a session's expiration time
RefreshSession(ctx context.Context, sessionID uuid.UUID, newExpiration time.Time) error
// CleanupExpiredSessions marks expired sessions as expired and optionally deletes old ones
CleanupExpiredSessions(ctx context.Context, deleteOlderThan *time.Duration) (expired int, deleted int, err error)
// GetSessionStats returns session statistics for a user
GetSessionStats(ctx context.Context, userID string) (total int, active int, err error)
// CreateOAuth2Session creates a session from OAuth2 authentication flow
CreateOAuth2Session(ctx context.Context, userID, appID string, tokenResponse *domain.TokenResponse, userInfo *domain.UserInfo, sessionType domain.SessionType, ipAddress, userAgent string) (*domain.UserSession, error)
}

View File

@ -0,0 +1,414 @@
package services
import (
"context"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
"github.com/kms/api-key-service/internal/repository"
)
// sessionService implements the SessionService interface
type sessionService struct {
sessionRepo repository.SessionRepository
appRepo repository.ApplicationRepository
config config.ConfigProvider
logger *zap.Logger
}
// NewSessionService creates a new session service
func NewSessionService(
sessionRepo repository.SessionRepository,
appRepo repository.ApplicationRepository,
config config.ConfigProvider,
logger *zap.Logger,
) SessionService {
return &sessionService{
sessionRepo: sessionRepo,
appRepo: appRepo,
config: config,
logger: logger,
}
}
// CreateSession creates a new user session
func (s *sessionService) CreateSession(ctx context.Context, req *domain.CreateSessionRequest) (*domain.UserSession, error) {
s.logger.Debug("Creating new session",
zap.String("user_id", req.UserID),
zap.String("app_id", req.AppID),
zap.String("session_type", string(req.SessionType)))
// Validate application exists
app, err := s.appRepo.GetByID(ctx, req.AppID)
if err != nil {
if errors.IsNotFound(err) {
return nil, errors.NewValidationError("Application not found")
}
return nil, err
}
// Check if application supports user tokens
supportsUser := false
for _, appType := range app.Type {
if appType == domain.ApplicationTypeUser {
supportsUser = true
break
}
}
if !supportsUser {
return nil, errors.NewValidationError("Application does not support user sessions")
}
// Create session object
session := &domain.UserSession{
ID: uuid.New(),
UserID: req.UserID,
AppID: req.AppID,
SessionType: req.SessionType,
Status: domain.SessionStatusActive,
IPAddress: req.IPAddress,
UserAgent: req.UserAgent,
ExpiresAt: req.ExpiresAt,
Metadata: domain.SessionMetadata{
TenantID: req.TenantID,
Permissions: req.Permissions,
Claims: req.Claims,
LoginMethod: "oauth2",
},
}
// Create session in repository
if err := s.sessionRepo.Create(ctx, session); err != nil {
s.logger.Error("Failed to create session", zap.Error(err))
return nil, err
}
s.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
return session, nil
}
// GetSession retrieves a session by its ID
func (s *sessionService) GetSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
s.logger.Debug("Getting session", zap.String("session_id", sessionID.String()))
session, err := s.sessionRepo.GetByID(ctx, sessionID)
if err != nil {
return nil, err
}
return session, nil
}
// GetUserSessions retrieves all sessions for a user
func (s *sessionService) GetUserSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) {
s.logger.Debug("Getting user sessions", zap.String("user_id", userID))
sessions, err := s.sessionRepo.GetByUserID(ctx, userID)
if err != nil {
return nil, err
}
return sessions, nil
}
// GetUserAppSessions retrieves sessions for a specific user and application
func (s *sessionService) GetUserAppSessions(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
s.logger.Debug("Getting user app sessions",
zap.String("user_id", userID),
zap.String("app_id", appID))
sessions, err := s.sessionRepo.GetByUserAndApp(ctx, userID, appID)
if err != nil {
return nil, err
}
return sessions, nil
}
// GetActiveSessions retrieves all active sessions for a user
func (s *sessionService) GetActiveSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) {
s.logger.Debug("Getting active sessions", zap.String("user_id", userID))
sessions, err := s.sessionRepo.GetActiveByUserID(ctx, userID)
if err != nil {
return nil, err
}
return sessions, nil
}
// ListSessions retrieves sessions with filtering and pagination
func (s *sessionService) ListSessions(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
s.logger.Debug("Listing sessions",
zap.String("user_id", req.UserID),
zap.String("app_id", req.AppID),
zap.Int("limit", req.Limit),
zap.Int("offset", req.Offset))
// Set default pagination if not provided
if req.Limit <= 0 {
req.Limit = 50
}
if req.Limit > 100 {
req.Limit = 100
}
response, err := s.sessionRepo.List(ctx, req)
if err != nil {
return nil, err
}
return response, nil
}
// UpdateSession updates an existing session
func (s *sessionService) UpdateSession(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
s.logger.Debug("Updating session", zap.String("session_id", sessionID.String()))
// Validate session exists
_, err := s.sessionRepo.GetByID(ctx, sessionID)
if err != nil {
return err
}
// Update session
if err := s.sessionRepo.Update(ctx, sessionID, updates); err != nil {
return err
}
s.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String()))
return nil
}
// UpdateSessionActivity updates the last activity timestamp for a session
func (s *sessionService) UpdateSessionActivity(ctx context.Context, sessionID uuid.UUID) error {
s.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))
if err := s.sessionRepo.UpdateActivity(ctx, sessionID); err != nil {
return err
}
return nil
}
// RevokeSession revokes a specific session
func (s *sessionService) RevokeSession(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
s.logger.Debug("Revoking session",
zap.String("session_id", sessionID.String()),
zap.String("revoked_by", revokedBy))
// Validate session exists and is active
session, err := s.sessionRepo.GetByID(ctx, sessionID)
if err != nil {
return err
}
if session.Status != domain.SessionStatusActive {
return errors.NewValidationError("Session is not active")
}
// Revoke session
if err := s.sessionRepo.Revoke(ctx, sessionID, revokedBy); err != nil {
return err
}
s.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String()))
return nil
}
// RevokeUserSessions revokes all sessions for a user
func (s *sessionService) RevokeUserSessions(ctx context.Context, userID string, revokedBy string) error {
s.logger.Debug("Revoking user sessions",
zap.String("user_id", userID),
zap.String("revoked_by", revokedBy))
if err := s.sessionRepo.RevokeAllByUser(ctx, userID, revokedBy); err != nil {
return err
}
s.logger.Debug("User sessions revoked successfully", zap.String("user_id", userID))
return nil
}
// RevokeUserAppSessions revokes all sessions for a user and application
func (s *sessionService) RevokeUserAppSessions(ctx context.Context, userID, appID string, revokedBy string) error {
s.logger.Debug("Revoking user app sessions",
zap.String("user_id", userID),
zap.String("app_id", appID),
zap.String("revoked_by", revokedBy))
if err := s.sessionRepo.RevokeAllByUserAndApp(ctx, userID, appID, revokedBy); err != nil {
return err
}
s.logger.Debug("User app sessions revoked successfully",
zap.String("user_id", userID),
zap.String("app_id", appID))
return nil
}
// ValidateSession validates if a session is active and valid
func (s *sessionService) ValidateSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
s.logger.Debug("Validating session", zap.String("session_id", sessionID.String()))
session, err := s.sessionRepo.GetByID(ctx, sessionID)
if err != nil {
return nil, err
}
// Check if session is active
if !session.IsActive() {
if session.IsExpired() {
return nil, errors.NewAuthenticationError("Session has expired")
}
if session.IsRevoked() {
return nil, errors.NewAuthenticationError("Session has been revoked")
}
return nil, errors.NewAuthenticationError("Session is not active")
}
// Update last activity
if err := s.sessionRepo.UpdateActivity(ctx, sessionID); err != nil {
s.logger.Warn("Failed to update session activity", zap.Error(err))
// Don't fail validation if we can't update activity
}
s.logger.Debug("Session validated successfully", zap.String("session_id", sessionID.String()))
return session, nil
}
// RefreshSession refreshes a session's expiration time
func (s *sessionService) RefreshSession(ctx context.Context, sessionID uuid.UUID, newExpiration time.Time) error {
s.logger.Debug("Refreshing session",
zap.String("session_id", sessionID.String()),
zap.Time("new_expiration", newExpiration))
// Validate session exists and is active
session, err := s.sessionRepo.GetByID(ctx, sessionID)
if err != nil {
return err
}
if !session.IsActive() {
return errors.NewValidationError("Cannot refresh inactive session")
}
// Update expiration
updates := &domain.UpdateSessionRequest{
ExpiresAt: &newExpiration,
}
if err := s.sessionRepo.Update(ctx, sessionID, updates); err != nil {
return err
}
s.logger.Debug("Session refreshed successfully", zap.String("session_id", sessionID.String()))
return nil
}
// CleanupExpiredSessions marks expired sessions as expired and optionally deletes old ones
func (s *sessionService) CleanupExpiredSessions(ctx context.Context, deleteOlderThan *time.Duration) (expired int, deleted int, err error) {
s.logger.Debug("Cleaning up expired sessions")
// Mark expired sessions
expired, err = s.sessionRepo.ExpireOldSessions(ctx)
if err != nil {
s.logger.Error("Failed to expire old sessions", zap.Error(err))
return 0, 0, err
}
// Delete old expired sessions if requested
if deleteOlderThan != nil {
deleted, err = s.sessionRepo.DeleteExpiredSessions(ctx, *deleteOlderThan)
if err != nil {
s.logger.Error("Failed to delete expired sessions", zap.Error(err))
return expired, 0, err
}
}
s.logger.Debug("Session cleanup completed",
zap.Int("expired", expired),
zap.Int("deleted", deleted))
return expired, deleted, nil
}
// GetSessionStats returns session statistics for a user
func (s *sessionService) GetSessionStats(ctx context.Context, userID string) (total int, active int, err error) {
s.logger.Debug("Getting session stats", zap.String("user_id", userID))
total, err = s.sessionRepo.GetSessionCount(ctx, userID)
if err != nil {
return 0, 0, err
}
active, err = s.sessionRepo.GetActiveSessionCount(ctx, userID)
if err != nil {
return 0, 0, err
}
return total, active, nil
}
// CreateOAuth2Session creates a session from OAuth2 authentication flow
func (s *sessionService) CreateOAuth2Session(ctx context.Context, userID, appID string, tokenResponse *domain.TokenResponse, userInfo *domain.UserInfo, sessionType domain.SessionType, ipAddress, userAgent string) (*domain.UserSession, error) {
s.logger.Debug("Creating OAuth2 session",
zap.String("user_id", userID),
zap.String("app_id", appID),
zap.String("session_type", string(sessionType)))
// Validate application exists
app, err := s.appRepo.GetByID(ctx, appID)
if err != nil {
if errors.IsNotFound(err) {
return nil, errors.NewValidationError("Application not found")
}
return nil, err
}
// Calculate expiration based on token response
expiresAt := time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second)
// Use application's max token duration if shorter
maxExpiration := time.Now().Add(app.MaxTokenDuration)
if expiresAt.After(maxExpiration) {
expiresAt = maxExpiration
}
// Create session object
session := &domain.UserSession{
ID: uuid.New(),
UserID: userID,
AppID: appID,
SessionType: sessionType,
Status: domain.SessionStatusActive,
AccessToken: tokenResponse.AccessToken, // In production, encrypt this
RefreshToken: tokenResponse.RefreshToken, // In production, encrypt this
IDToken: tokenResponse.IDToken, // In production, encrypt this
IPAddress: ipAddress,
UserAgent: userAgent,
ExpiresAt: expiresAt,
Metadata: domain.SessionMetadata{
LoginMethod: "oauth2",
Claims: map[string]string{
"sub": userInfo.Sub,
"email": userInfo.Email,
"name": userInfo.Name,
},
},
}
// Create session in repository
if err := s.sessionRepo.Create(ctx, session); err != nil {
s.logger.Error("Failed to create OAuth2 session", zap.Error(err))
return nil, err
}
s.logger.Debug("OAuth2 session created successfully", zap.String("session_id", session.ID.String()))
return session, nil
}

View File

@ -0,0 +1,14 @@
-- Drop user_sessions table and related objects
DROP INDEX IF EXISTS idx_user_sessions_tenant;
DROP INDEX IF EXISTS idx_user_sessions_metadata;
DROP INDEX IF EXISTS idx_user_sessions_active_expires;
DROP INDEX IF EXISTS idx_user_sessions_active;
DROP INDEX IF EXISTS idx_user_sessions_created_at;
DROP INDEX IF EXISTS idx_user_sessions_last_activity;
DROP INDEX IF EXISTS idx_user_sessions_expires_at;
DROP INDEX IF EXISTS idx_user_sessions_status;
DROP INDEX IF EXISTS idx_user_sessions_user_app;
DROP INDEX IF EXISTS idx_user_sessions_app_id;
DROP INDEX IF EXISTS idx_user_sessions_user_id;
DROP TABLE IF EXISTS user_sessions;

View File

@ -0,0 +1,60 @@
-- Create user_sessions table for session management
CREATE TABLE IF NOT EXISTS user_sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id VARCHAR(255) NOT NULL,
app_id VARCHAR(255) NOT NULL,
session_type VARCHAR(20) NOT NULL CHECK (session_type IN ('web', 'mobile', 'api')),
status VARCHAR(20) NOT NULL DEFAULT 'active' CHECK (status IN ('active', 'expired', 'revoked', 'suspended')),
access_token TEXT,
refresh_token TEXT,
id_token TEXT,
ip_address INET,
user_agent TEXT,
last_activity TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
revoked_at TIMESTAMP WITH TIME ZONE,
revoked_by VARCHAR(255),
metadata JSONB DEFAULT '{}',
-- Foreign key constraint to applications table
CONSTRAINT fk_user_sessions_app_id FOREIGN KEY (app_id) REFERENCES applications(app_id) ON DELETE CASCADE
);
-- Create indexes for performance
CREATE INDEX IF NOT EXISTS idx_user_sessions_user_id ON user_sessions(user_id);
CREATE INDEX IF NOT EXISTS idx_user_sessions_app_id ON user_sessions(app_id);
CREATE INDEX IF NOT EXISTS idx_user_sessions_user_app ON user_sessions(user_id, app_id);
CREATE INDEX IF NOT EXISTS idx_user_sessions_status ON user_sessions(status);
CREATE INDEX IF NOT EXISTS idx_user_sessions_expires_at ON user_sessions(expires_at);
CREATE INDEX IF NOT EXISTS idx_user_sessions_last_activity ON user_sessions(last_activity);
CREATE INDEX IF NOT EXISTS idx_user_sessions_created_at ON user_sessions(created_at);
-- Create partial indexes for active sessions (most common queries)
CREATE INDEX IF NOT EXISTS idx_user_sessions_active ON user_sessions(user_id, app_id) WHERE status = 'active';
CREATE INDEX IF NOT EXISTS idx_user_sessions_active_expires ON user_sessions(user_id, expires_at) WHERE status = 'active';
-- Create GIN index for metadata JSONB queries
CREATE INDEX IF NOT EXISTS idx_user_sessions_metadata ON user_sessions USING GIN (metadata);
-- Create index for tenant-based queries (if using multi-tenancy)
CREATE INDEX IF NOT EXISTS idx_user_sessions_tenant ON user_sessions((metadata->>'tenant_id')) WHERE metadata->>'tenant_id' IS NOT NULL;
-- Add comments for documentation
COMMENT ON TABLE user_sessions IS 'Stores user session information for authentication and session management';
COMMENT ON COLUMN user_sessions.id IS 'Unique session identifier';
COMMENT ON COLUMN user_sessions.user_id IS 'User identifier (email, username, or external ID)';
COMMENT ON COLUMN user_sessions.app_id IS 'Application identifier this session belongs to';
COMMENT ON COLUMN user_sessions.session_type IS 'Type of session: web, mobile, or api';
COMMENT ON COLUMN user_sessions.status IS 'Current session status: active, expired, revoked, or suspended';
COMMENT ON COLUMN user_sessions.access_token IS 'OAuth2/OIDC access token (encrypted/hashed)';
COMMENT ON COLUMN user_sessions.refresh_token IS 'OAuth2/OIDC refresh token (encrypted/hashed)';
COMMENT ON COLUMN user_sessions.id_token IS 'OIDC ID token (encrypted/hashed)';
COMMENT ON COLUMN user_sessions.ip_address IS 'IP address of the client when session was created';
COMMENT ON COLUMN user_sessions.user_agent IS 'User agent string of the client';
COMMENT ON COLUMN user_sessions.last_activity IS 'Timestamp of last session activity';
COMMENT ON COLUMN user_sessions.expires_at IS 'When the session expires';
COMMENT ON COLUMN user_sessions.revoked_at IS 'When the session was revoked (if applicable)';
COMMENT ON COLUMN user_sessions.revoked_by IS 'Who revoked the session (user ID or system)';
COMMENT ON COLUMN user_sessions.metadata IS 'Additional session metadata (device info, location, etc.)';

532
test/saml_test.go Normal file
View File

@ -0,0 +1,532 @@
package test
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"github.com/kms/api-key-service/internal/auth"
"github.com/kms/api-key-service/internal/domain"
)
// mockSAMLMetadata returns a mock SAML IdP metadata XML
func mockSAMLMetadata() string {
return `<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="https://idp.example.com">
<md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:KeyDescriptor use="signing">
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
<ds:X509Data>
<ds:X509Certificate>MIICertificateData</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/slo"/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>`
}
// mockSAMLResponse returns a mock SAML response XML with current timestamps
func mockSAMLResponse() string {
now := time.Now().UTC()
issueInstant := now.Format(time.RFC3339)
notBefore := now.Add(-5 * time.Minute).Format(time.RFC3339)
notOnOrAfter := now.Add(60 * time.Minute).Format(time.RFC3339)
return fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
ID="_response_id" Version="2.0" IssueInstant="%s"
Destination="https://sp.example.com/acs" InResponseTo="_request_id">
<saml:Issuer>https://idp.example.com</saml:Issuer>
<samlp:Status>
<samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
</samlp:Status>
<saml:Assertion ID="_assertion_id" Version="2.0" IssueInstant="%s">
<saml:Issuer>https://idp.example.com</saml:Issuer>
<saml:Subject>
<saml:NameID Format="urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress">user@example.com</saml:NameID>
<saml:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
<saml:SubjectConfirmationData InResponseTo="_request_id" NotOnOrAfter="%s" Recipient="https://sp.example.com/acs"/>
</saml:SubjectConfirmation>
</saml:Subject>
<saml:Conditions NotBefore="%s" NotOnOrAfter="%s">
<saml:AudienceRestriction>
<saml:Audience>https://sp.example.com</saml:Audience>
</saml:AudienceRestriction>
</saml:Conditions>
<saml:AttributeStatement>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress">
<saml:AttributeValue>user@example.com</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name">
<saml:AttributeValue>Test User</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname">
<saml:AttributeValue>Test</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname">
<saml:AttributeValue>User</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.microsoft.com/ws/2008/06/identity/claims/role">
<saml:AttributeValue>admin,user</saml:AttributeValue>
</saml:Attribute>
</saml:AttributeStatement>
<saml:AuthnStatement AuthnInstant="%s" SessionIndex="_session_index">
<saml:AuthnContext>
<saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</saml:AuthnContextClassRef>
</saml:AuthnContext>
</saml:AuthnStatement>
</saml:Assertion>
</samlp:Response>`, issueInstant, issueInstant, notOnOrAfter, notBefore, notOnOrAfter, issueInstant)
}
func TestSAMLProvider_GetMetadata(t *testing.T) {
tests := []struct {
name string
metadataURL string
serverResponse string
serverStatus int
expectError bool
errorContains string
}{
{
name: "successful metadata fetch",
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
serverResponse: mockSAMLMetadata(),
serverStatus: http.StatusOK,
expectError: false,
},
{
name: "missing metadata URL",
metadataURL: "",
expectError: true,
errorContains: "SAML_IDP_METADATA_URL not configured",
},
{
name: "server error",
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
serverStatus: http.StatusInternalServerError,
expectError: true,
errorContains: "returned status 500",
},
{
name: "invalid XML",
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
serverResponse: "invalid xml",
serverStatus: http.StatusOK,
expectError: true,
errorContains: "Failed to parse SAML metadata",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock HTTP server
var server *httptest.Server
if tt.metadataURL != "" && tt.serverStatus > 0 {
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.serverStatus)
if tt.serverResponse != "" {
w.Write([]byte(tt.serverResponse))
}
}))
defer server.Close()
tt.metadataURL = server.URL
}
// Create config
cfg := NewTestConfig()
cfg.values["SAML_IDP_METADATA_URL"] = tt.metadataURL
// Create SAML provider
logger := zaptest.NewLogger(t)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(t, err)
// Test GetMetadata
ctx := context.Background()
metadata, err := provider.GetMetadata(ctx)
if tt.expectError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Nil(t, metadata)
} else {
assert.NoError(t, err)
assert.NotNil(t, metadata)
assert.Equal(t, "https://idp.example.com", metadata.EntityID)
assert.NotEmpty(t, metadata.IDPSSODescriptor.SingleSignOnService)
}
})
}
}
func TestSAMLProvider_GenerateAuthRequest(t *testing.T) {
tests := []struct {
name string
spEntityID string
acsURL string
relayState string
expectError bool
errorContains string
}{
{
name: "successful auth request generation",
spEntityID: "https://sp.example.com",
acsURL: "https://sp.example.com/acs",
relayState: "test-relay-state",
},
{
name: "missing SP entity ID",
spEntityID: "",
acsURL: "https://sp.example.com/acs",
expectError: true,
errorContains: "SAML_SP_ENTITY_ID not configured",
},
{
name: "missing ACS URL",
spEntityID: "https://sp.example.com",
acsURL: "",
expectError: true,
errorContains: "SAML_SP_ACS_URL not configured",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock HTTP server for metadata
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockSAMLMetadata()))
}))
defer server.Close()
// Create config
cfg := NewTestConfig()
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
// Create SAML provider
logger := zaptest.NewLogger(t)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(t, err)
// Test GenerateAuthRequest
ctx := context.Background()
authURL, requestID, err := provider.GenerateAuthRequest(ctx, tt.relayState)
if tt.expectError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Empty(t, authURL)
assert.Empty(t, requestID)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, authURL)
assert.NotEmpty(t, requestID)
assert.Contains(t, authURL, "https://idp.example.com/sso")
assert.Contains(t, authURL, "SAMLRequest=")
if tt.relayState != "" {
assert.Contains(t, authURL, "RelayState="+tt.relayState)
}
}
})
}
}
func TestSAMLProvider_ProcessSAMLResponse(t *testing.T) {
tests := []struct {
name string
samlResponse string
expectedRequestID string
spEntityID string
expectError bool
errorContains string
expectedUserID string
expectedEmail string
expectedName string
expectedRoles []string
}{
{
name: "successful SAML response processing",
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
expectedRequestID: "_request_id",
spEntityID: "https://sp.example.com",
expectedUserID: "user@example.com",
expectedEmail: "user@example.com",
expectedName: "Test User",
expectedRoles: []string{"admin", "user"},
},
{
name: "invalid base64 encoding",
samlResponse: "invalid-base64",
expectError: true,
errorContains: "Failed to decode SAML response",
},
{
name: "invalid XML",
samlResponse: base64.StdEncoding.EncodeToString([]byte("invalid xml")),
expectError: true,
errorContains: "Failed to parse SAML response",
},
{
name: "audience mismatch",
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
spEntityID: "https://wrong-sp.example.com",
expectError: true,
errorContains: "audience mismatch",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create config
cfg := NewTestConfig()
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
// Create SAML provider
logger := zaptest.NewLogger(t)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(t, err)
// Test ProcessSAMLResponse
ctx := context.Background()
authContext, err := provider.ProcessSAMLResponse(ctx, tt.samlResponse, tt.expectedRequestID)
if tt.expectError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Nil(t, authContext)
} else {
assert.NoError(t, err)
assert.NotNil(t, authContext)
assert.Equal(t, tt.expectedUserID, authContext.UserID)
assert.Equal(t, domain.TokenTypeUser, authContext.TokenType)
// Check claims
if tt.expectedEmail != "" {
assert.Equal(t, tt.expectedEmail, authContext.Claims["email"])
}
if tt.expectedName != "" {
assert.Equal(t, tt.expectedName, authContext.Claims["name"])
}
// Check permissions/roles
if len(tt.expectedRoles) > 0 {
assert.Equal(t, tt.expectedRoles, authContext.Permissions)
}
}
})
}
}
func TestSAMLProvider_GenerateServiceProviderMetadata(t *testing.T) {
tests := []struct {
name string
spEntityID string
acsURL string
expectError bool
errorContains string
}{
{
name: "successful SP metadata generation",
spEntityID: "https://sp.example.com",
acsURL: "https://sp.example.com/acs",
},
{
name: "missing SP entity ID",
spEntityID: "",
acsURL: "https://sp.example.com/acs",
expectError: true,
errorContains: "SAML_SP_ENTITY_ID not configured",
},
{
name: "missing ACS URL",
spEntityID: "https://sp.example.com",
acsURL: "",
expectError: true,
errorContains: "SAML_SP_ACS_URL not configured",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create config
cfg := NewTestConfig()
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
// Create SAML provider
logger := zaptest.NewLogger(t)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(t, err)
// Test GenerateServiceProviderMetadata
metadata, err := provider.GenerateServiceProviderMetadata()
if tt.expectError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Empty(t, metadata)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, metadata)
assert.Contains(t, metadata, tt.spEntityID)
assert.Contains(t, metadata, tt.acsURL)
assert.Contains(t, metadata, "EntityDescriptor")
assert.Contains(t, metadata, "SPSSODescriptor")
}
})
}
}
// Benchmark tests for SAML operations
func BenchmarkSAMLProvider_ProcessSAMLResponse(b *testing.B) {
// Create config
cfg := NewTestConfig()
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
// Create SAML provider
logger := zaptest.NewLogger(b)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(b, err)
// Prepare SAML response
samlResponse := base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse()))
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSAMLProvider_GenerateAuthRequest(b *testing.B) {
// Create mock HTTP server for metadata
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockSAMLMetadata()))
}))
defer server.Close()
// Create config
cfg := NewTestConfig()
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
cfg.values["SAML_SP_ACS_URL"] = "https://sp.example.com/acs"
// Create SAML provider
logger := zaptest.NewLogger(b)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(b, err)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, err := provider.GenerateAuthRequest(ctx, "test-relay-state")
if err != nil {
b.Fatal(err)
}
}
}
// Test helper functions
func TestSAMLResponseValidation(t *testing.T) {
// Test various SAML response validation scenarios
tests := []struct {
name string
modifyXML func(string) string
expectError bool
errorContains string
}{
{
name: "expired assertion",
modifyXML: func(xml string) string {
// Replace all NotOnOrAfter timestamps with past time
pastTime := "2020-01-01T13:00:00Z"
re := regexp.MustCompile(`NotOnOrAfter="[^"]*"`)
return re.ReplaceAllString(xml, `NotOnOrAfter="`+pastTime+`"`)
},
expectError: true,
errorContains: "assertion has expired",
},
{
name: "assertion not yet valid",
modifyXML: func(xml string) string {
// Replace all NotBefore timestamps with future time
futureTime := "2030-01-01T11:55:00Z"
re := regexp.MustCompile(`NotBefore="[^"]*"`)
return re.ReplaceAllString(xml, `NotBefore="`+futureTime+`"`)
},
expectError: true,
errorContains: "assertion not yet valid",
},
{
name: "failed status",
modifyXML: func(xml string) string {
return strings.ReplaceAll(xml,
"urn:oasis:names:tc:SAML:2.0:status:Success",
"urn:oasis:names:tc:SAML:2.0:status:AuthnFailed")
},
expectError: true,
errorContains: "SAML authentication failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create config
cfg := NewTestConfig()
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
// Create SAML provider
logger := zaptest.NewLogger(t)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(t, err)
// Modify SAML response
modifiedXML := tt.modifyXML(mockSAMLResponse())
samlResponse := base64.StdEncoding.EncodeToString([]byte(modifiedXML))
// Test ProcessSAMLResponse
ctx := context.Background()
authContext, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
if tt.expectError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Nil(t, authContext)
} else {
assert.NoError(t, err)
assert.NotNil(t, authContext)
}
})
}
}

View File

@ -0,0 +1,705 @@
package test
import (
"context"
"database/sql"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/repository"
"github.com/kms/api-key-service/internal/repository/postgres"
)
// SQLMockDatabaseProvider implements repository.DatabaseProvider for SQL testing
type SQLMockDatabaseProvider struct {
db *sql.DB
}
func (m *SQLMockDatabaseProvider) GetDB() interface{} {
return m.db
}
func (m *SQLMockDatabaseProvider) Ping(ctx context.Context) error {
return m.db.PingContext(ctx)
}
func (m *SQLMockDatabaseProvider) Close() error {
return m.db.Close()
}
func (m *SQLMockDatabaseProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) {
tx, err := m.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
return &SQLMockTransactionProvider{tx: tx}, nil
}
func (m *SQLMockDatabaseProvider) Migrate(ctx context.Context, migrationPath string) error {
return nil
}
// SQLMockTransactionProvider implements repository.TransactionProvider for SQL testing
type SQLMockTransactionProvider struct {
tx *sql.Tx
}
func (m *SQLMockTransactionProvider) Commit() error {
return m.tx.Commit()
}
func (m *SQLMockTransactionProvider) Rollback() error {
return m.tx.Rollback()
}
func (m *SQLMockTransactionProvider) GetTx() interface{} {
return m.tx
}
func setupTokenRepositoryTest(t *testing.T) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
mockDB := &SQLMockDatabaseProvider{db: db}
repo := postgres.NewStaticTokenRepository(mockDB)
cleanup := func() {
db.Close()
}
return repo.(*postgres.StaticTokenRepository), mock, cleanup
}
func setupTokenRepositoryTestBenchmark(b *testing.B) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New()
if err != nil {
b.Fatal(err)
}
mockDB := &SQLMockDatabaseProvider{db: db}
repo := postgres.NewStaticTokenRepository(mockDB)
cleanup := func() {
db.Close()
}
return repo.(*postgres.StaticTokenRepository), mock, cleanup
}
func TestStaticTokenRepository_Create(t *testing.T) {
tests := []struct {
name string
token *domain.StaticToken
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
}{
{
name: "successful creation",
token: &domain.StaticToken{
ID: uuid.New(),
AppID: "test-app",
Owner: domain.Owner{
Type: domain.OwnerTypeIndividual,
Name: "test-user",
Owner: "test-owner",
},
KeyHash: "test-hash",
Type: "hmac",
},
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(`INSERT INTO static_tokens`).
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
},
expectError: false,
},
{
name: "database error",
token: &domain.StaticToken{
ID: uuid.New(),
AppID: "test-app",
Owner: domain.Owner{
Type: domain.OwnerTypeIndividual,
Name: "test-user",
Owner: "test-owner",
},
KeyHash: "test-hash",
Type: "hmac",
},
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(`INSERT INTO static_tokens`).
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to create static token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
err := repo.Create(ctx, tt.token)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
assert.NoError(t, err)
assert.NotZero(t, tt.token.CreatedAt)
assert.NotZero(t, tt.token.UpdatedAt)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_GetByID(t *testing.T) {
tokenID := uuid.New()
now := time.Now()
tests := []struct {
name string
tokenID uuid.UUID
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
expectedToken *domain.StaticToken
}{
{
name: "successful retrieval",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
}).AddRow(
tokenID, "test-app", "individual", "test-user", "test-owner",
"test-hash", "user", now, now,
)
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnRows(rows)
},
expectError: false,
expectedToken: &domain.StaticToken{
ID: tokenID,
AppID: "test-app",
Owner: domain.Owner{
Type: domain.OwnerTypeIndividual,
Name: "test-user",
Owner: "test-owner",
},
KeyHash: "test-hash",
Type: string(domain.TokenTypeUser),
CreatedAt: now,
UpdatedAt: now,
},
},
{
name: "token not found",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnError(sql.ErrNoRows)
},
expectError: true,
errorMsg: "not found",
},
{
name: "database error",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to get static token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
token, err := repo.GetByID(ctx, tt.tokenID)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
assert.Nil(t, token)
} else {
assert.NoError(t, err)
assert.NotNil(t, token)
assert.Equal(t, tt.expectedToken.ID, token.ID)
assert.Equal(t, tt.expectedToken.AppID, token.AppID)
assert.Equal(t, tt.expectedToken.Owner, token.Owner)
assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
assert.Equal(t, tt.expectedToken.Type, token.Type)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_GetByKeyHash(t *testing.T) {
tokenID := uuid.New()
now := time.Now()
keyHash := "test-hash"
tests := []struct {
name string
keyHash string
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
expectedToken *domain.StaticToken
}{
{
name: "successful retrieval",
keyHash: keyHash,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
}).AddRow(
tokenID, "test-app", "individual", "test-user", "test-owner",
keyHash, "user", now, now,
)
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
WithArgs(keyHash).
WillReturnRows(rows)
},
expectError: false,
expectedToken: &domain.StaticToken{
ID: tokenID,
AppID: "test-app",
Owner: domain.Owner{
Type: domain.OwnerTypeIndividual,
Name: "test-user",
Owner: "test-owner",
},
KeyHash: keyHash,
Type: string(domain.TokenTypeUser),
CreatedAt: now,
UpdatedAt: now,
},
},
{
name: "token not found",
keyHash: keyHash,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
WithArgs(keyHash).
WillReturnError(sql.ErrNoRows)
},
expectError: true,
errorMsg: "not found",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
token, err := repo.GetByKeyHash(ctx, tt.keyHash)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
assert.Nil(t, token)
} else {
assert.NoError(t, err)
assert.NotNil(t, token)
assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_GetByAppID(t *testing.T) {
tokenID1 := uuid.New()
tokenID2 := uuid.New()
now := time.Now()
appID := "test-app"
tests := []struct {
name string
appID string
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
expectedCount int
}{
{
name: "successful retrieval with multiple tokens",
appID: appID,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
}).AddRow(
tokenID1, appID, "user", "test-user1", "test-owner1",
"test-hash1", "user", now, now,
).AddRow(
tokenID2, appID, "user", "test-user2", "test-owner2",
"test-hash2", "user", now, now,
)
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
WithArgs(appID).
WillReturnRows(rows)
},
expectError: false,
expectedCount: 2,
},
{
name: "no tokens found",
appID: appID,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
})
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
WithArgs(appID).
WillReturnRows(rows)
},
expectError: false,
expectedCount: 0,
},
{
name: "database error",
appID: appID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
WithArgs(appID).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to query static tokens",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
tokens, err := repo.GetByAppID(ctx, tt.appID)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
assert.Nil(t, tokens)
} else {
assert.NoError(t, err)
assert.Len(t, tokens, tt.expectedCount)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_List(t *testing.T) {
tokenID := uuid.New()
now := time.Now()
tests := []struct {
name string
limit int
offset int
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
expectedCount int
}{
{
name: "successful list with pagination",
limit: 10,
offset: 0,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
}).AddRow(
tokenID, "test-app", "user", "test-user", "test-owner",
"test-hash", "user", now, now,
)
mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
WithArgs(10, 0).
WillReturnRows(rows)
},
expectError: false,
expectedCount: 1,
},
{
name: "database error",
limit: 10,
offset: 0,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
WithArgs(10, 0).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to query static tokens",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
tokens, err := repo.List(ctx, tt.limit, tt.offset)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
assert.Nil(t, tokens)
} else {
assert.NoError(t, err)
assert.Len(t, tokens, tt.expectedCount)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_Delete(t *testing.T) {
tokenID := uuid.New()
tests := []struct {
name string
tokenID uuid.UUID
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
}{
{
name: "successful deletion",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnResult(sqlmock.NewResult(0, 1))
},
expectError: false,
},
{
name: "token not found",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnResult(sqlmock.NewResult(0, 0))
},
expectError: true,
errorMsg: "not found",
},
{
name: "database error",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to delete static token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
err := repo.Delete(ctx, tt.tokenID)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
assert.NoError(t, err)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_Exists(t *testing.T) {
tokenID := uuid.New()
tests := []struct {
name string
tokenID uuid.UUID
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
expectedExists bool
}{
{
name: "token exists",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"exists"}).AddRow(1)
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnRows(rows)
},
expectError: false,
expectedExists: true,
},
{
name: "token does not exist",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnError(sql.ErrNoRows)
},
expectError: false,
expectedExists: false,
},
{
name: "database error",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to check static token existence",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
exists, err := repo.Exists(ctx, tt.tokenID)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedExists, exists)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
// Benchmark tests for repository operations
func BenchmarkStaticTokenRepository_Create(b *testing.B) {
repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
defer cleanup()
token := &domain.StaticToken{
ID: uuid.New(),
AppID: "test-app",
Owner: domain.Owner{
Type: domain.OwnerTypeIndividual,
Name: "test-user",
Owner: "test-owner",
},
KeyHash: "test-hash",
Type: string(domain.TokenTypeUser),
}
// Setup mock expectations for all iterations
for i := 0; i < b.N; i++ {
mock.ExpectExec(`INSERT INTO static_tokens`).
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "user", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
}
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
token.ID = uuid.New() // Generate new ID for each iteration
err := repo.Create(ctx, token)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkStaticTokenRepository_GetByID(b *testing.B) {
repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
defer cleanup()
tokenID := uuid.New()
now := time.Now()
// Setup mock expectations for all iterations
for i := 0; i < b.N; i++ {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
}).AddRow(
tokenID, "test-app", "user", "test-user", "test-owner",
"test-hash", "user", now, now,
)
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnRows(rows)
}
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := repo.GetByID(ctx, tokenID)
if err != nil {
b.Fatal(err)
}
}
}