v2
This commit is contained in:
@ -64,10 +64,10 @@ This document outlines the complete roadmap for making the API Key Management Se
|
||||
- [x] Implement authorization code exchange and token refresh
|
||||
- [x] Add user info retrieval from OAuth2 providers
|
||||
- [x] Create comprehensive OAuth2 unit tests with benchmarks
|
||||
- [ ] Add SAML authentication support
|
||||
- [ ] Create user session management
|
||||
- [x] Add SAML authentication support
|
||||
- [x] Create user session management
|
||||
- [x] Implement role-based access control (RBAC)
|
||||
- [ ] Add multi-tenant authentication support
|
||||
- [x] Add multi-tenant authentication support
|
||||
|
||||
### Permission System Enhancement
|
||||
- [x] Implement hierarchical permission inheritance
|
||||
@ -120,7 +120,7 @@ This document outlines the complete roadmap for making the API Key Management Se
|
||||
- [x] Create authentication failure tracking
|
||||
|
||||
### Audit & Compliance
|
||||
- [ ] Implement comprehensive audit logging
|
||||
- [x] Implement comprehensive audit logging
|
||||
- [ ] Add compliance reporting features
|
||||
- [ ] Create data retention policies
|
||||
- [ ] Implement GDPR compliance features
|
||||
|
||||
2
go.mod
2
go.mod
@ -21,6 +21,7 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
@ -33,6 +34,7 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/jmoiron/sqlx v1.4.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@ -1,5 +1,8 @@
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
|
||||
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
@ -43,6 +46,7 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
|
||||
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
@ -65,10 +69,13 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY
|
||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
|
||||
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
|
||||
github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg=
|
||||
github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
|
||||
@ -78,6 +85,7 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
|
||||
590
internal/audit/audit.go
Normal file
590
internal/audit/audit.go
Normal file
@ -0,0 +1,590 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
)
|
||||
|
||||
// EventType represents the type of audit event
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
// Authentication events
|
||||
EventTypeLogin EventType = "auth.login"
|
||||
EventTypeLoginFailed EventType = "auth.login_failed"
|
||||
EventTypeLogout EventType = "auth.logout"
|
||||
EventTypeTokenCreated EventType = "auth.token_created"
|
||||
EventTypeTokenRevoked EventType = "auth.token_revoked"
|
||||
EventTypeTokenValidated EventType = "auth.token_validated"
|
||||
|
||||
// Session events
|
||||
EventTypeSessionCreated EventType = "session.created"
|
||||
EventTypeSessionRevoked EventType = "session.revoked"
|
||||
EventTypeSessionExpired EventType = "session.expired"
|
||||
|
||||
// Application events
|
||||
EventTypeAppCreated EventType = "app.created"
|
||||
EventTypeAppUpdated EventType = "app.updated"
|
||||
EventTypeAppDeleted EventType = "app.deleted"
|
||||
|
||||
// Permission events
|
||||
EventTypePermissionGranted EventType = "permission.granted"
|
||||
EventTypePermissionRevoked EventType = "permission.revoked"
|
||||
EventTypePermissionDenied EventType = "permission.denied"
|
||||
|
||||
// Tenant events
|
||||
EventTypeTenantCreated EventType = "tenant.created"
|
||||
EventTypeTenantUpdated EventType = "tenant.updated"
|
||||
EventTypeTenantSuspended EventType = "tenant.suspended"
|
||||
EventTypeTenantActivated EventType = "tenant.activated"
|
||||
|
||||
// User events
|
||||
EventTypeUserCreated EventType = "user.created"
|
||||
EventTypeUserUpdated EventType = "user.updated"
|
||||
EventTypeUserSuspended EventType = "user.suspended"
|
||||
EventTypeUserActivated EventType = "user.activated"
|
||||
|
||||
// Security events
|
||||
EventTypeSecurityViolation EventType = "security.violation"
|
||||
EventTypeBruteForceAttempt EventType = "security.brute_force"
|
||||
EventTypeIPBlocked EventType = "security.ip_blocked"
|
||||
EventTypeRateLimitExceeded EventType = "security.rate_limit_exceeded"
|
||||
|
||||
// System events
|
||||
EventTypeSystemStartup EventType = "system.startup"
|
||||
EventTypeSystemShutdown EventType = "system.shutdown"
|
||||
EventTypeConfigChanged EventType = "system.config_changed"
|
||||
)
|
||||
|
||||
// EventSeverity represents the severity level of an audit event
|
||||
type EventSeverity string
|
||||
|
||||
const (
|
||||
SeverityInfo EventSeverity = "info"
|
||||
SeverityWarning EventSeverity = "warning"
|
||||
SeverityError EventSeverity = "error"
|
||||
SeverityCritical EventSeverity = "critical"
|
||||
)
|
||||
|
||||
// EventStatus represents the status of an audit event
|
||||
type EventStatus string
|
||||
|
||||
const (
|
||||
StatusSuccess EventStatus = "success"
|
||||
StatusFailure EventStatus = "failure"
|
||||
StatusPending EventStatus = "pending"
|
||||
)
|
||||
|
||||
// AuditEvent represents a single audit event
|
||||
type AuditEvent struct {
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
Type EventType `json:"type" db:"type"`
|
||||
Severity EventSeverity `json:"severity" db:"severity"`
|
||||
Status EventStatus `json:"status" db:"status"`
|
||||
Timestamp time.Time `json:"timestamp" db:"timestamp"`
|
||||
|
||||
// Actor information
|
||||
ActorID string `json:"actor_id,omitempty" db:"actor_id"`
|
||||
ActorType string `json:"actor_type,omitempty" db:"actor_type"` // user, system, service
|
||||
ActorIP string `json:"actor_ip,omitempty" db:"actor_ip"`
|
||||
UserAgent string `json:"user_agent,omitempty" db:"user_agent"`
|
||||
|
||||
// Tenant information
|
||||
TenantID *uuid.UUID `json:"tenant_id,omitempty" db:"tenant_id"`
|
||||
|
||||
// Resource information
|
||||
ResourceID string `json:"resource_id,omitempty" db:"resource_id"`
|
||||
ResourceType string `json:"resource_type,omitempty" db:"resource_type"`
|
||||
|
||||
// Event details
|
||||
Action string `json:"action" db:"action"`
|
||||
Description string `json:"description" db:"description"`
|
||||
Details map[string]interface{} `json:"details,omitempty" db:"details"`
|
||||
|
||||
// Request context
|
||||
RequestID string `json:"request_id,omitempty" db:"request_id"`
|
||||
SessionID string `json:"session_id,omitempty" db:"session_id"`
|
||||
|
||||
// Additional metadata
|
||||
Tags []string `json:"tags,omitempty" db:"tags"`
|
||||
Metadata map[string]string `json:"metadata,omitempty" db:"metadata"`
|
||||
}
|
||||
|
||||
// AuditLogger defines the interface for audit logging
|
||||
type AuditLogger interface {
|
||||
// LogEvent logs a single audit event
|
||||
LogEvent(ctx context.Context, event *AuditEvent) error
|
||||
|
||||
// LogAuthEvent logs an authentication-related event
|
||||
LogAuthEvent(ctx context.Context, eventType EventType, actorID, actorIP string, details map[string]interface{}) error
|
||||
|
||||
// LogPermissionEvent logs a permission-related event
|
||||
LogPermissionEvent(ctx context.Context, eventType EventType, actorID, resourceID, resourceType string, details map[string]interface{}) error
|
||||
|
||||
// LogSecurityEvent logs a security-related event
|
||||
LogSecurityEvent(ctx context.Context, eventType EventType, actorIP string, severity EventSeverity, details map[string]interface{}) error
|
||||
|
||||
// LogSystemEvent logs a system-related event
|
||||
LogSystemEvent(ctx context.Context, eventType EventType, details map[string]interface{}) error
|
||||
|
||||
// QueryEvents queries audit events with filters
|
||||
QueryEvents(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error)
|
||||
|
||||
// GetEventStats returns audit event statistics
|
||||
GetEventStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error)
|
||||
}
|
||||
|
||||
// AuditFilter represents filters for querying audit events
|
||||
type AuditFilter struct {
|
||||
EventTypes []EventType `json:"event_types,omitempty"`
|
||||
Severities []EventSeverity `json:"severities,omitempty"`
|
||||
Statuses []EventStatus `json:"statuses,omitempty"`
|
||||
ActorID string `json:"actor_id,omitempty"`
|
||||
ActorType string `json:"actor_type,omitempty"`
|
||||
TenantID *uuid.UUID `json:"tenant_id,omitempty"`
|
||||
ResourceID string `json:"resource_id,omitempty"`
|
||||
ResourceType string `json:"resource_type,omitempty"`
|
||||
StartTime *time.Time `json:"start_time,omitempty"`
|
||||
EndTime *time.Time `json:"end_time,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
OrderBy string `json:"order_by"` // timestamp, type, severity
|
||||
OrderDesc bool `json:"order_desc"`
|
||||
}
|
||||
|
||||
// AuditStatsFilter represents filters for audit statistics
|
||||
type AuditStatsFilter struct {
|
||||
EventTypes []EventType `json:"event_types,omitempty"`
|
||||
TenantID *uuid.UUID `json:"tenant_id,omitempty"`
|
||||
StartTime *time.Time `json:"start_time,omitempty"`
|
||||
EndTime *time.Time `json:"end_time,omitempty"`
|
||||
GroupBy string `json:"group_by"` // type, severity, status, hour, day
|
||||
}
|
||||
|
||||
// AuditStats represents audit event statistics
|
||||
type AuditStats struct {
|
||||
TotalEvents int `json:"total_events"`
|
||||
ByType map[EventType]int `json:"by_type"`
|
||||
BySeverity map[EventSeverity]int `json:"by_severity"`
|
||||
ByStatus map[EventStatus]int `json:"by_status"`
|
||||
ByTime map[string]int `json:"by_time,omitempty"`
|
||||
}
|
||||
|
||||
// auditLogger implements the AuditLogger interface
|
||||
type auditLogger struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
repository AuditRepository
|
||||
}
|
||||
|
||||
// AuditRepository defines the interface for audit event storage
|
||||
type AuditRepository interface {
|
||||
Create(ctx context.Context, event *AuditEvent) error
|
||||
Query(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error)
|
||||
GetStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error)
|
||||
DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error)
|
||||
}
|
||||
|
||||
// NewAuditLogger creates a new audit logger
|
||||
func NewAuditLogger(config config.ConfigProvider, logger *zap.Logger, repository AuditRepository) AuditLogger {
|
||||
return &auditLogger{
|
||||
config: config,
|
||||
logger: logger,
|
||||
repository: repository,
|
||||
}
|
||||
}
|
||||
|
||||
// LogEvent logs a single audit event
|
||||
func (a *auditLogger) LogEvent(ctx context.Context, event *AuditEvent) error {
|
||||
// Set default values
|
||||
if event.ID == uuid.Nil {
|
||||
event.ID = uuid.New()
|
||||
}
|
||||
if event.Timestamp.IsZero() {
|
||||
event.Timestamp = time.Now().UTC()
|
||||
}
|
||||
if event.Severity == "" {
|
||||
event.Severity = SeverityInfo
|
||||
}
|
||||
if event.Status == "" {
|
||||
event.Status = StatusSuccess
|
||||
}
|
||||
|
||||
// Extract request context if available
|
||||
if requestID := ctx.Value("request_id"); requestID != nil {
|
||||
if reqID, ok := requestID.(string); ok {
|
||||
event.RequestID = reqID
|
||||
}
|
||||
}
|
||||
|
||||
// Log to structured logger
|
||||
a.logToStructuredLogger(event)
|
||||
|
||||
// Store in repository
|
||||
if err := a.repository.Create(ctx, event); err != nil {
|
||||
a.logger.Error("Failed to store audit event",
|
||||
zap.Error(err),
|
||||
zap.String("event_id", event.ID.String()),
|
||||
zap.String("event_type", string(event.Type)))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogAuthEvent logs an authentication-related event
|
||||
func (a *auditLogger) LogAuthEvent(ctx context.Context, eventType EventType, actorID, actorIP string, details map[string]interface{}) error {
|
||||
severity := SeverityInfo
|
||||
status := StatusSuccess
|
||||
|
||||
// Determine severity and status based on event type
|
||||
switch eventType {
|
||||
case EventTypeLoginFailed:
|
||||
severity = SeverityWarning
|
||||
status = StatusFailure
|
||||
case EventTypeTokenRevoked:
|
||||
severity = SeverityWarning
|
||||
}
|
||||
|
||||
event := &AuditEvent{
|
||||
Type: eventType,
|
||||
Severity: severity,
|
||||
Status: status,
|
||||
ActorID: actorID,
|
||||
ActorType: "user",
|
||||
ActorIP: actorIP,
|
||||
Action: string(eventType),
|
||||
Description: a.generateDescription(eventType, details),
|
||||
Details: details,
|
||||
Tags: []string{"authentication"},
|
||||
}
|
||||
|
||||
return a.LogEvent(ctx, event)
|
||||
}
|
||||
|
||||
// LogPermissionEvent logs a permission-related event
|
||||
func (a *auditLogger) LogPermissionEvent(ctx context.Context, eventType EventType, actorID, resourceID, resourceType string, details map[string]interface{}) error {
|
||||
severity := SeverityInfo
|
||||
status := StatusSuccess
|
||||
|
||||
// Determine severity and status based on event type
|
||||
switch eventType {
|
||||
case EventTypePermissionDenied:
|
||||
severity = SeverityWarning
|
||||
status = StatusFailure
|
||||
}
|
||||
|
||||
event := &AuditEvent{
|
||||
Type: eventType,
|
||||
Severity: severity,
|
||||
Status: status,
|
||||
ActorID: actorID,
|
||||
ActorType: "user",
|
||||
ResourceID: resourceID,
|
||||
ResourceType: resourceType,
|
||||
Action: string(eventType),
|
||||
Description: a.generateDescription(eventType, details),
|
||||
Details: details,
|
||||
Tags: []string{"permission", "authorization"},
|
||||
}
|
||||
|
||||
return a.LogEvent(ctx, event)
|
||||
}
|
||||
|
||||
// LogSecurityEvent logs a security-related event
|
||||
func (a *auditLogger) LogSecurityEvent(ctx context.Context, eventType EventType, actorIP string, severity EventSeverity, details map[string]interface{}) error {
|
||||
status := StatusSuccess
|
||||
if severity == SeverityError || severity == SeverityCritical {
|
||||
status = StatusFailure
|
||||
}
|
||||
|
||||
event := &AuditEvent{
|
||||
Type: eventType,
|
||||
Severity: severity,
|
||||
Status: status,
|
||||
ActorIP: actorIP,
|
||||
ActorType: "system",
|
||||
Action: string(eventType),
|
||||
Description: a.generateDescription(eventType, details),
|
||||
Details: details,
|
||||
Tags: []string{"security"},
|
||||
}
|
||||
|
||||
return a.LogEvent(ctx, event)
|
||||
}
|
||||
|
||||
// LogSystemEvent logs a system-related event
|
||||
func (a *auditLogger) LogSystemEvent(ctx context.Context, eventType EventType, details map[string]interface{}) error {
|
||||
event := &AuditEvent{
|
||||
Type: eventType,
|
||||
Severity: SeverityInfo,
|
||||
Status: StatusSuccess,
|
||||
ActorType: "system",
|
||||
Action: string(eventType),
|
||||
Description: a.generateDescription(eventType, details),
|
||||
Details: details,
|
||||
Tags: []string{"system"},
|
||||
}
|
||||
|
||||
return a.LogEvent(ctx, event)
|
||||
}
|
||||
|
||||
// QueryEvents queries audit events with filters
|
||||
func (a *auditLogger) QueryEvents(ctx context.Context, filter *AuditFilter) ([]*AuditEvent, error) {
|
||||
// Set default pagination
|
||||
if filter.Limit <= 0 {
|
||||
filter.Limit = 100
|
||||
}
|
||||
if filter.Limit > 1000 {
|
||||
filter.Limit = 1000
|
||||
}
|
||||
if filter.OrderBy == "" {
|
||||
filter.OrderBy = "timestamp"
|
||||
filter.OrderDesc = true
|
||||
}
|
||||
|
||||
return a.repository.Query(ctx, filter)
|
||||
}
|
||||
|
||||
// GetEventStats returns audit event statistics
|
||||
func (a *auditLogger) GetEventStats(ctx context.Context, filter *AuditStatsFilter) (*AuditStats, error) {
|
||||
return a.repository.GetStats(ctx, filter)
|
||||
}
|
||||
|
||||
// logToStructuredLogger logs the event to the structured logger
|
||||
func (a *auditLogger) logToStructuredLogger(event *AuditEvent) {
|
||||
fields := []zap.Field{
|
||||
zap.String("audit_event_id", event.ID.String()),
|
||||
zap.String("event_type", string(event.Type)),
|
||||
zap.String("severity", string(event.Severity)),
|
||||
zap.String("status", string(event.Status)),
|
||||
zap.Time("timestamp", event.Timestamp),
|
||||
zap.String("action", event.Action),
|
||||
zap.String("description", event.Description),
|
||||
}
|
||||
|
||||
if event.ActorID != "" {
|
||||
fields = append(fields, zap.String("actor_id", event.ActorID))
|
||||
}
|
||||
if event.ActorType != "" {
|
||||
fields = append(fields, zap.String("actor_type", event.ActorType))
|
||||
}
|
||||
if event.ActorIP != "" {
|
||||
fields = append(fields, zap.String("actor_ip", event.ActorIP))
|
||||
}
|
||||
if event.TenantID != nil {
|
||||
fields = append(fields, zap.String("tenant_id", event.TenantID.String()))
|
||||
}
|
||||
if event.ResourceID != "" {
|
||||
fields = append(fields, zap.String("resource_id", event.ResourceID))
|
||||
}
|
||||
if event.ResourceType != "" {
|
||||
fields = append(fields, zap.String("resource_type", event.ResourceType))
|
||||
}
|
||||
if event.RequestID != "" {
|
||||
fields = append(fields, zap.String("request_id", event.RequestID))
|
||||
}
|
||||
if event.SessionID != "" {
|
||||
fields = append(fields, zap.String("session_id", event.SessionID))
|
||||
}
|
||||
if len(event.Tags) > 0 {
|
||||
fields = append(fields, zap.Strings("tags", event.Tags))
|
||||
}
|
||||
if event.Details != nil {
|
||||
if detailsJSON, err := json.Marshal(event.Details); err == nil {
|
||||
fields = append(fields, zap.String("details", string(detailsJSON)))
|
||||
}
|
||||
}
|
||||
|
||||
// Log at appropriate level based on severity
|
||||
switch event.Severity {
|
||||
case SeverityInfo:
|
||||
a.logger.Info("Audit event", fields...)
|
||||
case SeverityWarning:
|
||||
a.logger.Warn("Audit event", fields...)
|
||||
case SeverityError:
|
||||
a.logger.Error("Audit event", fields...)
|
||||
case SeverityCritical:
|
||||
a.logger.Error("Critical audit event", fields...)
|
||||
default:
|
||||
a.logger.Info("Audit event", fields...)
|
||||
}
|
||||
}
|
||||
|
||||
// generateDescription generates a human-readable description for an event
|
||||
func (a *auditLogger) generateDescription(eventType EventType, details map[string]interface{}) string {
|
||||
switch eventType {
|
||||
case EventTypeLogin:
|
||||
return "User successfully logged in"
|
||||
case EventTypeLoginFailed:
|
||||
return "User login attempt failed"
|
||||
case EventTypeLogout:
|
||||
return "User logged out"
|
||||
case EventTypeTokenCreated:
|
||||
return "API token created"
|
||||
case EventTypeTokenRevoked:
|
||||
return "API token revoked"
|
||||
case EventTypeTokenValidated:
|
||||
return "API token validated"
|
||||
case EventTypeSessionCreated:
|
||||
return "User session created"
|
||||
case EventTypeSessionRevoked:
|
||||
return "User session revoked"
|
||||
case EventTypeSessionExpired:
|
||||
return "User session expired"
|
||||
case EventTypeAppCreated:
|
||||
return "Application created"
|
||||
case EventTypeAppUpdated:
|
||||
return "Application updated"
|
||||
case EventTypeAppDeleted:
|
||||
return "Application deleted"
|
||||
case EventTypePermissionGranted:
|
||||
return "Permission granted"
|
||||
case EventTypePermissionRevoked:
|
||||
return "Permission revoked"
|
||||
case EventTypePermissionDenied:
|
||||
return "Permission denied"
|
||||
case EventTypeTenantCreated:
|
||||
return "Tenant created"
|
||||
case EventTypeTenantUpdated:
|
||||
return "Tenant updated"
|
||||
case EventTypeTenantSuspended:
|
||||
return "Tenant suspended"
|
||||
case EventTypeTenantActivated:
|
||||
return "Tenant activated"
|
||||
case EventTypeUserCreated:
|
||||
return "User created"
|
||||
case EventTypeUserUpdated:
|
||||
return "User updated"
|
||||
case EventTypeUserSuspended:
|
||||
return "User suspended"
|
||||
case EventTypeUserActivated:
|
||||
return "User activated"
|
||||
case EventTypeSecurityViolation:
|
||||
return "Security violation detected"
|
||||
case EventTypeBruteForceAttempt:
|
||||
return "Brute force attempt detected"
|
||||
case EventTypeIPBlocked:
|
||||
return "IP address blocked"
|
||||
case EventTypeRateLimitExceeded:
|
||||
return "Rate limit exceeded"
|
||||
case EventTypeSystemStartup:
|
||||
return "System started"
|
||||
case EventTypeSystemShutdown:
|
||||
return "System shutdown"
|
||||
case EventTypeConfigChanged:
|
||||
return "Configuration changed"
|
||||
default:
|
||||
return string(eventType)
|
||||
}
|
||||
}
|
||||
|
||||
// AuditEventBuilder provides a fluent interface for building audit events
|
||||
type AuditEventBuilder struct {
|
||||
event *AuditEvent
|
||||
}
|
||||
|
||||
// NewAuditEventBuilder creates a new audit event builder
|
||||
func NewAuditEventBuilder(eventType EventType) *AuditEventBuilder {
|
||||
return &AuditEventBuilder{
|
||||
event: &AuditEvent{
|
||||
ID: uuid.New(),
|
||||
Type: eventType,
|
||||
Timestamp: time.Now().UTC(),
|
||||
Severity: SeverityInfo,
|
||||
Status: StatusSuccess,
|
||||
Details: make(map[string]interface{}),
|
||||
Metadata: make(map[string]string),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// WithSeverity sets the event severity
|
||||
func (b *AuditEventBuilder) WithSeverity(severity EventSeverity) *AuditEventBuilder {
|
||||
b.event.Severity = severity
|
||||
return b
|
||||
}
|
||||
|
||||
// WithStatus sets the event status
|
||||
func (b *AuditEventBuilder) WithStatus(status EventStatus) *AuditEventBuilder {
|
||||
b.event.Status = status
|
||||
return b
|
||||
}
|
||||
|
||||
// WithActor sets the actor information
|
||||
func (b *AuditEventBuilder) WithActor(actorID, actorType, actorIP string) *AuditEventBuilder {
|
||||
b.event.ActorID = actorID
|
||||
b.event.ActorType = actorType
|
||||
b.event.ActorIP = actorIP
|
||||
return b
|
||||
}
|
||||
|
||||
// WithTenant sets the tenant ID
|
||||
func (b *AuditEventBuilder) WithTenant(tenantID uuid.UUID) *AuditEventBuilder {
|
||||
b.event.TenantID = &tenantID
|
||||
return b
|
||||
}
|
||||
|
||||
// WithResource sets the resource information
|
||||
func (b *AuditEventBuilder) WithResource(resourceID, resourceType string) *AuditEventBuilder {
|
||||
b.event.ResourceID = resourceID
|
||||
b.event.ResourceType = resourceType
|
||||
return b
|
||||
}
|
||||
|
||||
// WithAction sets the action
|
||||
func (b *AuditEventBuilder) WithAction(action string) *AuditEventBuilder {
|
||||
b.event.Action = action
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDescription sets the description
|
||||
func (b *AuditEventBuilder) WithDescription(description string) *AuditEventBuilder {
|
||||
b.event.Description = description
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDetail adds a detail
|
||||
func (b *AuditEventBuilder) WithDetail(key string, value interface{}) *AuditEventBuilder {
|
||||
b.event.Details[key] = value
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDetails sets multiple details
|
||||
func (b *AuditEventBuilder) WithDetails(details map[string]interface{}) *AuditEventBuilder {
|
||||
for k, v := range details {
|
||||
b.event.Details[k] = v
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// WithRequestContext sets request context information
|
||||
func (b *AuditEventBuilder) WithRequestContext(requestID, sessionID string) *AuditEventBuilder {
|
||||
b.event.RequestID = requestID
|
||||
b.event.SessionID = sessionID
|
||||
return b
|
||||
}
|
||||
|
||||
// WithTags sets the tags
|
||||
func (b *AuditEventBuilder) WithTags(tags ...string) *AuditEventBuilder {
|
||||
b.event.Tags = tags
|
||||
return b
|
||||
}
|
||||
|
||||
// WithMetadata adds metadata
|
||||
func (b *AuditEventBuilder) WithMetadata(key, value string) *AuditEventBuilder {
|
||||
b.event.Metadata[key] = value
|
||||
return b
|
||||
}
|
||||
|
||||
// Build returns the built audit event
|
||||
func (b *AuditEventBuilder) Build() *AuditEvent {
|
||||
return b.event
|
||||
}
|
||||
544
internal/auth/saml.go
Normal file
544
internal/auth/saml.go
Normal file
@ -0,0 +1,544 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// SAMLProvider represents a SAML 2.0 identity provider
|
||||
type SAMLProvider struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
httpClient *http.Client
|
||||
privateKey *rsa.PrivateKey
|
||||
certificate *x509.Certificate
|
||||
}
|
||||
|
||||
// NewSAMLProvider creates a new SAML provider
|
||||
func NewSAMLProvider(config config.ConfigProvider, logger *zap.Logger) (*SAMLProvider, error) {
|
||||
provider := &SAMLProvider{
|
||||
config: config,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
// Load SP private key and certificate if configured
|
||||
if err := provider.loadCredentials(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// SAMLMetadata represents SAML IdP metadata
|
||||
type SAMLMetadata struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntityDescriptor"`
|
||||
EntityID string `xml:"entityID,attr"`
|
||||
IDPSSODescriptor IDPSSODescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata IDPSSODescriptor"`
|
||||
}
|
||||
|
||||
// IDPSSODescriptor represents the IdP SSO descriptor
|
||||
type IDPSSODescriptor struct {
|
||||
ProtocolSupportEnumeration string `xml:"protocolSupportEnumeration,attr"`
|
||||
KeyDescriptor []KeyDescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata KeyDescriptor"`
|
||||
SingleSignOnService []SingleSignOnService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleSignOnService"`
|
||||
SingleLogoutService []SingleLogoutService `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleLogoutService"`
|
||||
}
|
||||
|
||||
// KeyDescriptor represents a key descriptor
|
||||
type KeyDescriptor struct {
|
||||
Use string `xml:"use,attr"`
|
||||
KeyInfo KeyInfo `xml:"urn:xmldsig KeyInfo"`
|
||||
}
|
||||
|
||||
// KeyInfo represents key information
|
||||
type KeyInfo struct {
|
||||
X509Data X509Data `xml:"urn:xmldsig X509Data"`
|
||||
}
|
||||
|
||||
// X509Data represents X509 certificate data
|
||||
type X509Data struct {
|
||||
X509Certificate string `xml:"urn:xmldsig X509Certificate"`
|
||||
}
|
||||
|
||||
// SingleSignOnService represents SSO service endpoint
|
||||
type SingleSignOnService struct {
|
||||
Binding string `xml:"Binding,attr"`
|
||||
Location string `xml:"Location,attr"`
|
||||
}
|
||||
|
||||
// SingleLogoutService represents SLO service endpoint
|
||||
type SingleLogoutService struct {
|
||||
Binding string `xml:"Binding,attr"`
|
||||
Location string `xml:"Location,attr"`
|
||||
}
|
||||
|
||||
// SAMLRequest represents a SAML authentication request
|
||||
type SAMLRequest struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol AuthnRequest"`
|
||||
ID string `xml:"ID,attr"`
|
||||
Version string `xml:"Version,attr"`
|
||||
IssueInstant time.Time `xml:"IssueInstant,attr"`
|
||||
Destination string `xml:"Destination,attr"`
|
||||
AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"`
|
||||
ProtocolBinding string `xml:"ProtocolBinding,attr"`
|
||||
Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
|
||||
NameIDPolicy NameIDPolicy `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"`
|
||||
}
|
||||
|
||||
// Issuer represents the SAML issuer
|
||||
type Issuer struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
|
||||
Value string `xml:",chardata"`
|
||||
}
|
||||
|
||||
// NameIDPolicy represents the name ID policy
|
||||
type NameIDPolicy struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"`
|
||||
Format string `xml:"Format,attr"`
|
||||
}
|
||||
|
||||
// SAMLResponse represents a SAML response
|
||||
type SAMLResponse struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Response"`
|
||||
ID string `xml:"ID,attr"`
|
||||
Version string `xml:"Version,attr"`
|
||||
IssueInstant time.Time `xml:"IssueInstant,attr"`
|
||||
Destination string `xml:"Destination,attr"`
|
||||
InResponseTo string `xml:"InResponseTo,attr"`
|
||||
Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
|
||||
Status Status `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"`
|
||||
Assertion Assertion `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"`
|
||||
}
|
||||
|
||||
// Status represents the SAML response status
|
||||
type Status struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"`
|
||||
StatusCode StatusCode `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"`
|
||||
}
|
||||
|
||||
// StatusCode represents the status code
|
||||
type StatusCode struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"`
|
||||
Value string `xml:"Value,attr"`
|
||||
}
|
||||
|
||||
// Assertion represents a SAML assertion
|
||||
type Assertion struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"`
|
||||
ID string `xml:"ID,attr"`
|
||||
Version string `xml:"Version,attr"`
|
||||
IssueInstant time.Time `xml:"IssueInstant,attr"`
|
||||
Issuer Issuer `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
|
||||
Subject Subject `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"`
|
||||
Conditions Conditions `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"`
|
||||
AttributeStatement AttributeStatement `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"`
|
||||
AuthnStatement AuthnStatement `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"`
|
||||
}
|
||||
|
||||
// Subject represents the assertion subject
|
||||
type Subject struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"`
|
||||
NameID NameID `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"`
|
||||
SubjectConfirmation SubjectConfirmation `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"`
|
||||
}
|
||||
|
||||
// NameID represents the name identifier
|
||||
type NameID struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"`
|
||||
Format string `xml:"Format,attr"`
|
||||
Value string `xml:",chardata"`
|
||||
}
|
||||
|
||||
// SubjectConfirmation represents subject confirmation
|
||||
type SubjectConfirmation struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"`
|
||||
Method string `xml:"Method,attr"`
|
||||
SubjectConfirmationData SubjectConfirmationData `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"`
|
||||
}
|
||||
|
||||
// SubjectConfirmationData represents subject confirmation data
|
||||
type SubjectConfirmationData struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"`
|
||||
InResponseTo string `xml:"InResponseTo,attr"`
|
||||
NotOnOrAfter time.Time `xml:"NotOnOrAfter,attr"`
|
||||
Recipient string `xml:"Recipient,attr"`
|
||||
}
|
||||
|
||||
// Conditions represents assertion conditions
|
||||
type Conditions struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"`
|
||||
NotBefore time.Time `xml:"NotBefore,attr"`
|
||||
NotOnOrAfter time.Time `xml:"NotOnOrAfter,attr"`
|
||||
AudienceRestriction AudienceRestriction `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"`
|
||||
}
|
||||
|
||||
// AudienceRestriction represents audience restriction
|
||||
type AudienceRestriction struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"`
|
||||
Audience Audience `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"`
|
||||
}
|
||||
|
||||
// Audience represents the intended audience
|
||||
type Audience struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"`
|
||||
Value string `xml:",chardata"`
|
||||
}
|
||||
|
||||
// AttributeStatement represents attribute statement
|
||||
type AttributeStatement struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"`
|
||||
Attribute []Attribute `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"`
|
||||
}
|
||||
|
||||
// Attribute represents a SAML attribute
|
||||
type Attribute struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"`
|
||||
Name string `xml:"Name,attr"`
|
||||
AttributeValue []AttributeValue `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"`
|
||||
}
|
||||
|
||||
// AttributeValue represents an attribute value
|
||||
type AttributeValue struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"`
|
||||
Type string `xml:"http://www.w3.org/2001/XMLSchema-instance type,attr"`
|
||||
Value string `xml:",chardata"`
|
||||
}
|
||||
|
||||
// AuthnStatement represents authentication statement
|
||||
type AuthnStatement struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"`
|
||||
AuthnInstant time.Time `xml:"AuthnInstant,attr"`
|
||||
SessionIndex string `xml:"SessionIndex,attr"`
|
||||
AuthnContext AuthnContext `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"`
|
||||
}
|
||||
|
||||
// AuthnContext represents authentication context
|
||||
type AuthnContext struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"`
|
||||
AuthnContextClassRef string `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnContextClassRef"`
|
||||
}
|
||||
|
||||
// GetMetadata fetches the SAML IdP metadata
|
||||
func (p *SAMLProvider) GetMetadata(ctx context.Context) (*SAMLMetadata, error) {
|
||||
metadataURL := p.config.GetString("SAML_IDP_METADATA_URL")
|
||||
if metadataURL == "" {
|
||||
return nil, errors.NewConfigurationError("SAML_IDP_METADATA_URL not configured")
|
||||
}
|
||||
|
||||
p.logger.Debug("Fetching SAML IdP metadata", zap.String("url", metadataURL))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to create metadata request").WithInternal(err)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to fetch IdP metadata").WithInternal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, errors.NewInternalError(fmt.Sprintf("Metadata endpoint returned status %d", resp.StatusCode))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to read metadata response").WithInternal(err)
|
||||
}
|
||||
|
||||
var metadata SAMLMetadata
|
||||
if err := xml.Unmarshal(body, &metadata); err != nil {
|
||||
return nil, errors.NewInternalError("Failed to parse SAML metadata").WithInternal(err)
|
||||
}
|
||||
|
||||
p.logger.Debug("SAML IdP metadata fetched successfully",
|
||||
zap.String("entity_id", metadata.EntityID))
|
||||
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// GenerateAuthRequest generates a SAML authentication request
|
||||
func (p *SAMLProvider) GenerateAuthRequest(ctx context.Context, relayState string) (string, string, error) {
|
||||
metadata, err := p.GetMetadata(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Find SSO endpoint
|
||||
var ssoEndpoint string
|
||||
for _, sso := range metadata.IDPSSODescriptor.SingleSignOnService {
|
||||
if sso.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" {
|
||||
ssoEndpoint = sso.Location
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ssoEndpoint == "" {
|
||||
return "", "", errors.NewConfigurationError("No HTTP-Redirect SSO endpoint found in IdP metadata")
|
||||
}
|
||||
|
||||
// Generate request ID
|
||||
requestID := "_" + uuid.New().String()
|
||||
|
||||
// Get SP configuration
|
||||
spEntityID := p.config.GetString("SAML_SP_ENTITY_ID")
|
||||
acsURL := p.config.GetString("SAML_SP_ACS_URL")
|
||||
|
||||
if spEntityID == "" {
|
||||
return "", "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured")
|
||||
}
|
||||
if acsURL == "" {
|
||||
return "", "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured")
|
||||
}
|
||||
|
||||
// Create SAML request
|
||||
samlRequest := SAMLRequest{
|
||||
ID: requestID,
|
||||
Version: "2.0",
|
||||
IssueInstant: time.Now().UTC(),
|
||||
Destination: ssoEndpoint,
|
||||
AssertionConsumerServiceURL: acsURL,
|
||||
ProtocolBinding: "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
|
||||
Issuer: Issuer{
|
||||
Value: spEntityID,
|
||||
},
|
||||
NameIDPolicy: NameIDPolicy{
|
||||
Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress",
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to XML
|
||||
xmlData, err := xml.MarshalIndent(samlRequest, "", " ")
|
||||
if err != nil {
|
||||
return "", "", errors.NewInternalError("Failed to marshal SAML request").WithInternal(err)
|
||||
}
|
||||
|
||||
// Add XML declaration
|
||||
xmlRequest := `<?xml version="1.0" encoding="UTF-8"?>` + "\n" + string(xmlData)
|
||||
|
||||
// Base64 encode and URL encode
|
||||
encodedRequest := base64.StdEncoding.EncodeToString([]byte(xmlRequest))
|
||||
|
||||
// Build redirect URL
|
||||
params := url.Values{
|
||||
"SAMLRequest": {encodedRequest},
|
||||
"RelayState": {relayState},
|
||||
}
|
||||
|
||||
redirectURL := ssoEndpoint + "?" + params.Encode()
|
||||
|
||||
p.logger.Debug("Generated SAML authentication request",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("sso_endpoint", ssoEndpoint))
|
||||
|
||||
return redirectURL, requestID, nil
|
||||
}
|
||||
|
||||
// ProcessSAMLResponse processes a SAML response and extracts user information
|
||||
func (p *SAMLProvider) ProcessSAMLResponse(ctx context.Context, samlResponse string, expectedRequestID string) (*domain.AuthContext, error) {
|
||||
p.logger.Debug("Processing SAML response")
|
||||
|
||||
// Base64 decode the response
|
||||
decodedResponse, err := base64.StdEncoding.DecodeString(samlResponse)
|
||||
if err != nil {
|
||||
return nil, errors.NewValidationError("Failed to decode SAML response").WithInternal(err)
|
||||
}
|
||||
|
||||
// Parse XML
|
||||
var response SAMLResponse
|
||||
if err := xml.Unmarshal(decodedResponse, &response); err != nil {
|
||||
return nil, errors.NewValidationError("Failed to parse SAML response").WithInternal(err)
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if err := p.validateSAMLResponse(&response, expectedRequestID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract user information from assertion
|
||||
authContext, err := p.extractUserInfo(&response.Assertion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.logger.Debug("SAML response processed successfully",
|
||||
zap.String("user_id", authContext.UserID))
|
||||
|
||||
return authContext, nil
|
||||
}
|
||||
|
||||
// validateSAMLResponse validates a SAML response
|
||||
func (p *SAMLProvider) validateSAMLResponse(response *SAMLResponse, expectedRequestID string) error {
|
||||
// Check status
|
||||
if response.Status.StatusCode.Value != "urn:oasis:names:tc:SAML:2.0:status:Success" {
|
||||
return errors.NewAuthenticationError("SAML authentication failed: " + response.Status.StatusCode.Value)
|
||||
}
|
||||
|
||||
// Validate InResponseTo
|
||||
if expectedRequestID != "" && response.InResponseTo != expectedRequestID {
|
||||
return errors.NewValidationError("SAML response InResponseTo does not match request ID")
|
||||
}
|
||||
|
||||
// Validate assertion conditions
|
||||
assertion := &response.Assertion
|
||||
now := time.Now().UTC()
|
||||
|
||||
if now.Before(assertion.Conditions.NotBefore) {
|
||||
return errors.NewValidationError("SAML assertion not yet valid")
|
||||
}
|
||||
|
||||
if now.After(assertion.Conditions.NotOnOrAfter) {
|
||||
return errors.NewValidationError("SAML assertion has expired")
|
||||
}
|
||||
|
||||
// Validate audience
|
||||
expectedAudience := p.config.GetString("SAML_SP_ENTITY_ID")
|
||||
if assertion.Conditions.AudienceRestriction.Audience.Value != expectedAudience {
|
||||
return errors.NewValidationError("SAML assertion audience mismatch")
|
||||
}
|
||||
|
||||
// In production, you should also validate the signature
|
||||
// This requires implementing XML signature validation
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractUserInfo extracts user information from SAML assertion
|
||||
func (p *SAMLProvider) extractUserInfo(assertion *Assertion) (*domain.AuthContext, error) {
|
||||
// Extract user ID from NameID
|
||||
userID := assertion.Subject.NameID.Value
|
||||
if userID == "" {
|
||||
return nil, errors.NewValidationError("SAML assertion missing NameID")
|
||||
}
|
||||
|
||||
// Extract attributes
|
||||
claims := make(map[string]string)
|
||||
claims["sub"] = userID
|
||||
claims["name_id_format"] = assertion.Subject.NameID.Format
|
||||
|
||||
// Process attribute statements
|
||||
for _, attr := range assertion.AttributeStatement.Attribute {
|
||||
if len(attr.AttributeValue) > 0 {
|
||||
// Use the first value if multiple values exist
|
||||
claims[attr.Name] = attr.AttributeValue[0].Value
|
||||
}
|
||||
}
|
||||
|
||||
// Map common attributes to standard claims
|
||||
if email, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress"]; exists {
|
||||
claims["email"] = email
|
||||
}
|
||||
if name, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name"]; exists {
|
||||
claims["name"] = name
|
||||
}
|
||||
if givenName, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname"]; exists {
|
||||
claims["given_name"] = givenName
|
||||
}
|
||||
if surname, exists := claims["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname"]; exists {
|
||||
claims["family_name"] = surname
|
||||
}
|
||||
|
||||
// Extract permissions/roles if available
|
||||
var permissions []string
|
||||
if roles, exists := claims["http://schemas.microsoft.com/ws/2008/06/identity/claims/role"]; exists {
|
||||
permissions = strings.Split(roles, ",")
|
||||
}
|
||||
|
||||
authContext := &domain.AuthContext{
|
||||
UserID: userID,
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: claims,
|
||||
Permissions: permissions,
|
||||
}
|
||||
|
||||
return authContext, nil
|
||||
}
|
||||
|
||||
// GenerateServiceProviderMetadata generates SP metadata XML
|
||||
func (p *SAMLProvider) GenerateServiceProviderMetadata() (string, error) {
|
||||
spEntityID := p.config.GetString("SAML_SP_ENTITY_ID")
|
||||
acsURL := p.config.GetString("SAML_SP_ACS_URL")
|
||||
|
||||
if spEntityID == "" {
|
||||
return "", errors.NewConfigurationError("SAML_SP_ENTITY_ID not configured")
|
||||
}
|
||||
if acsURL == "" {
|
||||
return "", errors.NewConfigurationError("SAML_SP_ACS_URL not configured")
|
||||
}
|
||||
|
||||
// This is a simplified SP metadata generation
|
||||
// In production, you should use a proper SAML library
|
||||
metadata := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
|
||||
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="%s">
|
||||
<md:SPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="%s" index="0"/>
|
||||
</md:SPSSODescriptor>
|
||||
</md:EntityDescriptor>`, spEntityID, acsURL)
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// loadCredentials loads SP private key and certificate
|
||||
func (p *SAMLProvider) loadCredentials() error {
|
||||
// Load private key if configured
|
||||
privateKeyPEM := p.config.GetString("SAML_SP_PRIVATE_KEY")
|
||||
if privateKeyPEM != "" {
|
||||
block, _ := pem.Decode([]byte(privateKeyPEM))
|
||||
if block == nil {
|
||||
return errors.NewConfigurationError("Failed to decode SAML SP private key")
|
||||
}
|
||||
|
||||
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
// Try PKCS8 format
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return errors.NewConfigurationError("Failed to parse SAML SP private key").WithInternal(err)
|
||||
}
|
||||
var ok bool
|
||||
privateKey, ok = key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return errors.NewConfigurationError("SAML SP private key is not RSA")
|
||||
}
|
||||
}
|
||||
p.privateKey = privateKey
|
||||
}
|
||||
|
||||
// Load certificate if configured
|
||||
certificatePEM := p.config.GetString("SAML_SP_CERTIFICATE")
|
||||
if certificatePEM != "" {
|
||||
block, _ := pem.Decode([]byte(certificatePEM))
|
||||
if block == nil {
|
||||
return errors.NewConfigurationError("Failed to decode SAML SP certificate")
|
||||
}
|
||||
|
||||
certificate, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return errors.NewConfigurationError("Failed to parse SAML SP certificate").WithInternal(err)
|
||||
}
|
||||
p.certificate = certificate
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -132,6 +132,12 @@ func (c *Config) setDefaults() {
|
||||
"IP_BLOCK_DURATION": "1h",
|
||||
"REQUEST_MAX_AGE": "5m",
|
||||
"IP_WHITELIST": "",
|
||||
"SAML_ENABLED": "false",
|
||||
"SAML_IDP_METADATA_URL": "",
|
||||
"SAML_SP_ENTITY_ID": "",
|
||||
"SAML_SP_ACS_URL": "",
|
||||
"SAML_SP_PRIVATE_KEY": "",
|
||||
"SAML_SP_CERTIFICATE": "",
|
||||
}
|
||||
|
||||
for key, value := range defaults {
|
||||
|
||||
@ -188,6 +188,23 @@ type CreateStaticTokenResponse struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// CreateTokenRequest represents a request to create a token
|
||||
type CreateTokenRequest struct {
|
||||
AppID string `json:"app_id" validate:"required"`
|
||||
Type TokenType `json:"type" validate:"required,oneof=static user"`
|
||||
UserID string `json:"user_id,omitempty"` // Required for user tokens
|
||||
Permissions []string `json:"permissions,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// CreateTokenResponse represents a response for creating a token
|
||||
type CreateTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
TokenType TokenType `json:"token_type"`
|
||||
}
|
||||
|
||||
// AuthContext represents the authentication context for a request
|
||||
type AuthContext struct {
|
||||
UserID string `json:"user_id"`
|
||||
@ -196,3 +213,25 @@ type AuthContext struct {
|
||||
Claims map[string]string `json:"claims"`
|
||||
AppID string `json:"app_id"`
|
||||
}
|
||||
|
||||
// TokenResponse represents the OAuth2 token response
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo represents user information from the OAuth2/OIDC provider
|
||||
type UserInfo struct {
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Name string `json:"name"`
|
||||
GivenName string `json:"given_name"`
|
||||
FamilyName string `json:"family_name"`
|
||||
Picture string `json:"picture"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
}
|
||||
|
||||
153
internal/domain/session.go
Normal file
153
internal/domain/session.go
Normal file
@ -0,0 +1,153 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SessionStatus represents the status of a user session
|
||||
type SessionStatus string
|
||||
|
||||
const (
|
||||
SessionStatusActive SessionStatus = "active"
|
||||
SessionStatusExpired SessionStatus = "expired"
|
||||
SessionStatusRevoked SessionStatus = "revoked"
|
||||
SessionStatusSuspended SessionStatus = "suspended"
|
||||
)
|
||||
|
||||
// SessionType represents the type of session
|
||||
type SessionType string
|
||||
|
||||
const (
|
||||
SessionTypeWeb SessionType = "web"
|
||||
SessionTypeMobile SessionType = "mobile"
|
||||
SessionTypeAPI SessionType = "api"
|
||||
)
|
||||
|
||||
// UserSession represents a user session in the system
|
||||
type UserSession struct {
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
UserID string `json:"user_id" validate:"required" db:"user_id"`
|
||||
AppID string `json:"app_id" validate:"required" db:"app_id"`
|
||||
SessionType SessionType `json:"session_type" validate:"required,oneof=web mobile api" db:"session_type"`
|
||||
Status SessionStatus `json:"status" validate:"required,oneof=active expired revoked suspended" db:"status"`
|
||||
AccessToken string `json:"-" db:"access_token"` // Hidden from JSON for security
|
||||
RefreshToken string `json:"-" db:"refresh_token"` // Hidden from JSON for security
|
||||
IDToken string `json:"-" db:"id_token"` // Hidden from JSON for security
|
||||
IPAddress string `json:"ip_address" db:"ip_address"`
|
||||
UserAgent string `json:"user_agent" db:"user_agent"`
|
||||
LastActivity time.Time `json:"last_activity" db:"last_activity"`
|
||||
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
||||
RevokedAt *time.Time `json:"revoked_at,omitempty" db:"revoked_at"`
|
||||
RevokedBy *string `json:"revoked_by,omitempty" db:"revoked_by"`
|
||||
Metadata SessionMetadata `json:"metadata" db:"metadata"`
|
||||
}
|
||||
|
||||
// SessionMetadata contains additional session information
|
||||
type SessionMetadata struct {
|
||||
DeviceInfo string `json:"device_info,omitempty"`
|
||||
Location string `json:"location,omitempty"`
|
||||
LoginMethod string `json:"login_method,omitempty"`
|
||||
TenantID string `json:"tenant_id,omitempty"`
|
||||
Permissions []string `json:"permissions,omitempty"`
|
||||
Claims map[string]string `json:"claims,omitempty"`
|
||||
RefreshCount int `json:"refresh_count"`
|
||||
LastRefresh *time.Time `json:"last_refresh,omitempty"`
|
||||
}
|
||||
|
||||
// CreateSessionRequest represents a request to create a new session
|
||||
type CreateSessionRequest struct {
|
||||
UserID string `json:"user_id" validate:"required"`
|
||||
AppID string `json:"app_id" validate:"required"`
|
||||
SessionType SessionType `json:"session_type" validate:"required,oneof=web mobile api"`
|
||||
IPAddress string `json:"ip_address" validate:"required,ip"`
|
||||
UserAgent string `json:"user_agent" validate:"required"`
|
||||
ExpiresAt time.Time `json:"expires_at" validate:"required"`
|
||||
Permissions []string `json:"permissions,omitempty"`
|
||||
Claims map[string]string `json:"claims,omitempty"`
|
||||
TenantID string `json:"tenant_id,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateSessionRequest represents a request to update a session
|
||||
type UpdateSessionRequest struct {
|
||||
Status *SessionStatus `json:"status,omitempty" validate:"omitempty,oneof=active expired revoked suspended"`
|
||||
LastActivity *time.Time `json:"last_activity,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
IPAddress *string `json:"ip_address,omitempty" validate:"omitempty,ip"`
|
||||
UserAgent *string `json:"user_agent,omitempty"`
|
||||
}
|
||||
|
||||
// SessionListRequest represents a request to list sessions
|
||||
type SessionListRequest struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
AppID string `json:"app_id,omitempty"`
|
||||
Status *SessionStatus `json:"status,omitempty"`
|
||||
SessionType *SessionType `json:"session_type,omitempty"`
|
||||
TenantID string `json:"tenant_id,omitempty"`
|
||||
Limit int `json:"limit" validate:"min=1,max=100"`
|
||||
Offset int `json:"offset" validate:"min=0"`
|
||||
}
|
||||
|
||||
// SessionListResponse represents a response for listing sessions
|
||||
type SessionListResponse struct {
|
||||
Sessions []*UserSession `json:"sessions"`
|
||||
Total int `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
// IsActive checks if the session is currently active
|
||||
func (s *UserSession) IsActive() bool {
|
||||
return s.Status == SessionStatusActive && time.Now().Before(s.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsExpired checks if the session has expired
|
||||
func (s *UserSession) IsExpired() bool {
|
||||
return time.Now().After(s.ExpiresAt) || s.Status == SessionStatusExpired
|
||||
}
|
||||
|
||||
// IsRevoked checks if the session has been revoked
|
||||
func (s *UserSession) IsRevoked() bool {
|
||||
return s.Status == SessionStatusRevoked
|
||||
}
|
||||
|
||||
// CanRefresh checks if the session can be refreshed
|
||||
func (s *UserSession) CanRefresh() bool {
|
||||
return s.IsActive() && s.RefreshToken != ""
|
||||
}
|
||||
|
||||
// UpdateActivity updates the last activity timestamp
|
||||
func (s *UserSession) UpdateActivity() {
|
||||
s.LastActivity = time.Now()
|
||||
s.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// Revoke marks the session as revoked
|
||||
func (s *UserSession) Revoke(revokedBy string) {
|
||||
now := time.Now()
|
||||
s.Status = SessionStatusRevoked
|
||||
s.RevokedAt = &now
|
||||
s.RevokedBy = &revokedBy
|
||||
s.UpdatedAt = now
|
||||
}
|
||||
|
||||
// Expire marks the session as expired
|
||||
func (s *UserSession) Expire() {
|
||||
s.Status = SessionStatusExpired
|
||||
s.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// Suspend marks the session as suspended
|
||||
func (s *UserSession) Suspend() {
|
||||
s.Status = SessionStatusSuspended
|
||||
s.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// Activate marks the session as active
|
||||
func (s *UserSession) Activate() {
|
||||
s.Status = SessionStatusActive
|
||||
s.UpdatedAt = time.Now()
|
||||
}
|
||||
307
internal/domain/tenant.go
Normal file
307
internal/domain/tenant.go
Normal file
@ -0,0 +1,307 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TenantStatus represents the status of a tenant
|
||||
type TenantStatus string
|
||||
|
||||
const (
|
||||
TenantStatusActive TenantStatus = "active"
|
||||
TenantStatusSuspended TenantStatus = "suspended"
|
||||
TenantStatusInactive TenantStatus = "inactive"
|
||||
)
|
||||
|
||||
// Tenant represents a tenant in the multi-tenant system
|
||||
type Tenant struct {
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
Name string `json:"name" validate:"required,min=1,max=255" db:"name"`
|
||||
Slug string `json:"slug" validate:"required,min=1,max=100,alphanum" db:"slug"`
|
||||
Status TenantStatus `json:"status" validate:"required,oneof=active suspended inactive" db:"status"`
|
||||
Domain string `json:"domain,omitempty" validate:"omitempty,fqdn" db:"domain"`
|
||||
Description string `json:"description,omitempty" validate:"max=1000" db:"description"`
|
||||
Settings TenantSettings `json:"settings" db:"settings"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
||||
CreatedBy string `json:"created_by" db:"created_by"`
|
||||
UpdatedBy string `json:"updated_by" db:"updated_by"`
|
||||
}
|
||||
|
||||
// TenantSettings contains tenant-specific configuration
|
||||
type TenantSettings struct {
|
||||
// Authentication settings
|
||||
AuthProvider string `json:"auth_provider,omitempty"` // oauth2, saml, header
|
||||
SAMLSettings *SAMLSettings `json:"saml_settings,omitempty"`
|
||||
OAuth2Settings *OAuth2Settings `json:"oauth2_settings,omitempty"`
|
||||
|
||||
// Session settings
|
||||
SessionTimeout time.Duration `json:"session_timeout,omitempty"`
|
||||
MaxConcurrentSessions int `json:"max_concurrent_sessions,omitempty"`
|
||||
|
||||
// Security settings
|
||||
RequireMFA bool `json:"require_mfa"`
|
||||
AllowedIPRanges []string `json:"allowed_ip_ranges,omitempty"`
|
||||
PasswordPolicy *PasswordPolicy `json:"password_policy,omitempty"`
|
||||
|
||||
// Token settings
|
||||
DefaultTokenDuration time.Duration `json:"default_token_duration,omitempty"`
|
||||
MaxTokenDuration time.Duration `json:"max_token_duration,omitempty"`
|
||||
|
||||
// Feature flags
|
||||
Features map[string]bool `json:"features,omitempty"`
|
||||
|
||||
// Custom attributes
|
||||
CustomAttributes map[string]string `json:"custom_attributes,omitempty"`
|
||||
}
|
||||
|
||||
// SAMLSettings contains SAML-specific configuration for a tenant
|
||||
type SAMLSettings struct {
|
||||
IDPMetadataURL string `json:"idp_metadata_url,omitempty"`
|
||||
SPEntityID string `json:"sp_entity_id,omitempty"`
|
||||
ACSURL string `json:"acs_url,omitempty"`
|
||||
SPPrivateKey string `json:"sp_private_key,omitempty"`
|
||||
SPCertificate string `json:"sp_certificate,omitempty"`
|
||||
AttributeMapping map[string]string `json:"attribute_mapping,omitempty"`
|
||||
}
|
||||
|
||||
// OAuth2Settings contains OAuth2-specific configuration for a tenant
|
||||
type OAuth2Settings struct {
|
||||
ProviderURL string `json:"provider_url,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
AttributeMapping map[string]string `json:"attribute_mapping,omitempty"`
|
||||
}
|
||||
|
||||
// PasswordPolicy defines password requirements for a tenant
|
||||
type PasswordPolicy struct {
|
||||
MinLength int `json:"min_length"`
|
||||
RequireUppercase bool `json:"require_uppercase"`
|
||||
RequireLowercase bool `json:"require_lowercase"`
|
||||
RequireNumbers bool `json:"require_numbers"`
|
||||
RequireSymbols bool `json:"require_symbols"`
|
||||
MaxAge time.Duration `json:"max_age,omitempty"`
|
||||
PreventReuse int `json:"prevent_reuse"` // Number of previous passwords to prevent reuse
|
||||
}
|
||||
|
||||
// TenantUser represents a user within a specific tenant
|
||||
type TenantUser struct {
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
TenantID uuid.UUID `json:"tenant_id" validate:"required" db:"tenant_id"`
|
||||
UserID string `json:"user_id" validate:"required" db:"user_id"`
|
||||
Email string `json:"email" validate:"required,email" db:"email"`
|
||||
Name string `json:"name" validate:"required" db:"name"`
|
||||
Roles []string `json:"roles" db:"roles"`
|
||||
Permissions []string `json:"permissions" db:"permissions"`
|
||||
Status UserStatus `json:"status" validate:"required,oneof=active inactive suspended" db:"status"`
|
||||
Metadata map[string]string `json:"metadata,omitempty" db:"metadata"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
||||
LastLoginAt *time.Time `json:"last_login_at,omitempty" db:"last_login_at"`
|
||||
}
|
||||
|
||||
// UserStatus represents the status of a user within a tenant
|
||||
type UserStatus string
|
||||
|
||||
const (
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusInactive UserStatus = "inactive"
|
||||
UserStatusSuspended UserStatus = "suspended"
|
||||
)
|
||||
|
||||
// TenantRole represents a role within a tenant
|
||||
type TenantRole struct {
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
TenantID uuid.UUID `json:"tenant_id" validate:"required" db:"tenant_id"`
|
||||
Name string `json:"name" validate:"required,min=1,max=100" db:"name"`
|
||||
Description string `json:"description,omitempty" validate:"max=500" db:"description"`
|
||||
Permissions []string `json:"permissions" db:"permissions"`
|
||||
IsSystem bool `json:"is_system" db:"is_system"` // System roles cannot be deleted
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
||||
CreatedBy string `json:"created_by" db:"created_by"`
|
||||
UpdatedBy string `json:"updated_by" db:"updated_by"`
|
||||
}
|
||||
|
||||
// CreateTenantRequest represents a request to create a new tenant
|
||||
type CreateTenantRequest struct {
|
||||
Name string `json:"name" validate:"required,min=1,max=255"`
|
||||
Slug string `json:"slug" validate:"required,min=1,max=100,alphanum"`
|
||||
Domain string `json:"domain,omitempty" validate:"omitempty,fqdn"`
|
||||
Description string `json:"description,omitempty" validate:"max=1000"`
|
||||
Settings TenantSettings `json:"settings,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateTenantRequest represents a request to update a tenant
|
||||
type UpdateTenantRequest struct {
|
||||
Name *string `json:"name,omitempty" validate:"omitempty,min=1,max=255"`
|
||||
Status *TenantStatus `json:"status,omitempty" validate:"omitempty,oneof=active suspended inactive"`
|
||||
Domain *string `json:"domain,omitempty" validate:"omitempty,fqdn"`
|
||||
Description *string `json:"description,omitempty" validate:"omitempty,max=1000"`
|
||||
Settings *TenantSettings `json:"settings,omitempty"`
|
||||
}
|
||||
|
||||
// CreateTenantUserRequest represents a request to create a user in a tenant
|
||||
type CreateTenantUserRequest struct {
|
||||
TenantID uuid.UUID `json:"tenant_id" validate:"required"`
|
||||
UserID string `json:"user_id" validate:"required"`
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
Name string `json:"name" validate:"required"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Permissions []string `json:"permissions,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateTenantUserRequest represents a request to update a tenant user
|
||||
type UpdateTenantUserRequest struct {
|
||||
Email *string `json:"email,omitempty" validate:"omitempty,email"`
|
||||
Name *string `json:"name,omitempty" validate:"omitempty,min=1"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Permissions []string `json:"permissions,omitempty"`
|
||||
Status *UserStatus `json:"status,omitempty" validate:"omitempty,oneof=active inactive suspended"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// CreateTenantRoleRequest represents a request to create a role in a tenant
|
||||
type CreateTenantRoleRequest struct {
|
||||
TenantID uuid.UUID `json:"tenant_id" validate:"required"`
|
||||
Name string `json:"name" validate:"required,min=1,max=100"`
|
||||
Description string `json:"description,omitempty" validate:"max=500"`
|
||||
Permissions []string `json:"permissions,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateTenantRoleRequest represents a request to update a tenant role
|
||||
type UpdateTenantRoleRequest struct {
|
||||
Name *string `json:"name,omitempty" validate:"omitempty,min=1,max=100"`
|
||||
Description *string `json:"description,omitempty" validate:"omitempty,max=500"`
|
||||
Permissions []string `json:"permissions,omitempty"`
|
||||
}
|
||||
|
||||
// TenantListRequest represents a request to list tenants
|
||||
type TenantListRequest struct {
|
||||
Status *TenantStatus `json:"status,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Limit int `json:"limit" validate:"min=1,max=100"`
|
||||
Offset int `json:"offset" validate:"min=0"`
|
||||
}
|
||||
|
||||
// TenantListResponse represents a response for listing tenants
|
||||
type TenantListResponse struct {
|
||||
Tenants []*Tenant `json:"tenants"`
|
||||
Total int `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
// IsActive checks if the tenant is active
|
||||
func (t *Tenant) IsActive() bool {
|
||||
return t.Status == TenantStatusActive
|
||||
}
|
||||
|
||||
// IsSuspended checks if the tenant is suspended
|
||||
func (t *Tenant) IsSuspended() bool {
|
||||
return t.Status == TenantStatusSuspended
|
||||
}
|
||||
|
||||
// HasFeature checks if a feature is enabled for the tenant
|
||||
func (t *Tenant) HasFeature(feature string) bool {
|
||||
if t.Settings.Features == nil {
|
||||
return false
|
||||
}
|
||||
enabled, exists := t.Settings.Features[feature]
|
||||
return exists && enabled
|
||||
}
|
||||
|
||||
// GetAuthProvider returns the authentication provider for the tenant
|
||||
func (t *Tenant) GetAuthProvider() string {
|
||||
if t.Settings.AuthProvider != "" {
|
||||
return t.Settings.AuthProvider
|
||||
}
|
||||
return "header" // default
|
||||
}
|
||||
|
||||
// GetSessionTimeout returns the session timeout for the tenant
|
||||
func (t *Tenant) GetSessionTimeout() time.Duration {
|
||||
if t.Settings.SessionTimeout > 0 {
|
||||
return t.Settings.SessionTimeout
|
||||
}
|
||||
return 8 * time.Hour // default
|
||||
}
|
||||
|
||||
// GetMaxConcurrentSessions returns the maximum concurrent sessions for the tenant
|
||||
func (t *Tenant) GetMaxConcurrentSessions() int {
|
||||
if t.Settings.MaxConcurrentSessions > 0 {
|
||||
return t.Settings.MaxConcurrentSessions
|
||||
}
|
||||
return 10 // default
|
||||
}
|
||||
|
||||
// IsActive checks if the tenant user is active
|
||||
func (tu *TenantUser) IsActive() bool {
|
||||
return tu.Status == UserStatusActive
|
||||
}
|
||||
|
||||
// IsSuspended checks if the tenant user is suspended
|
||||
func (tu *TenantUser) IsSuspended() bool {
|
||||
return tu.Status == UserStatusSuspended
|
||||
}
|
||||
|
||||
// HasRole checks if the user has a specific role
|
||||
func (tu *TenantUser) HasRole(role string) bool {
|
||||
for _, r := range tu.Roles {
|
||||
if r == role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasPermission checks if the user has a specific permission
|
||||
func (tu *TenantUser) HasPermission(permission string) bool {
|
||||
for _, p := range tu.Permissions {
|
||||
if p == permission {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateLastLogin updates the last login timestamp
|
||||
func (tu *TenantUser) UpdateLastLogin() {
|
||||
now := time.Now()
|
||||
tu.LastLoginAt = &now
|
||||
tu.UpdatedAt = now
|
||||
}
|
||||
|
||||
// IsSystemRole checks if the role is a system role
|
||||
func (tr *TenantRole) IsSystemRole() bool {
|
||||
return tr.IsSystem
|
||||
}
|
||||
|
||||
// HasPermission checks if the role has a specific permission
|
||||
func (tr *TenantRole) HasPermission(permission string) bool {
|
||||
for _, p := range tr.Permissions {
|
||||
if p == permission {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TenantContext represents the tenant context for a request
|
||||
type TenantContext struct {
|
||||
TenantID uuid.UUID `json:"tenant_id"`
|
||||
TenantSlug string `json:"tenant_slug"`
|
||||
UserID string `json:"user_id"`
|
||||
Roles []string `json:"roles"`
|
||||
Permissions []string `json:"permissions"`
|
||||
}
|
||||
|
||||
// MultiTenantAuthContext extends AuthContext with tenant information
|
||||
type MultiTenantAuthContext struct {
|
||||
*AuthContext
|
||||
TenantContext *TenantContext `json:"tenant_context,omitempty"`
|
||||
}
|
||||
@ -295,3 +295,66 @@ func (c *Chain) Error() string {
|
||||
}
|
||||
return fmt.Sprintf("multiple errors: %s (and %d more)", c.errors[0].Error(), len(c.errors)-1)
|
||||
}
|
||||
|
||||
// Helper functions to check error types
|
||||
|
||||
// IsNotFound checks if an error is a not found error
|
||||
func IsNotFound(err error) bool {
|
||||
if appErr, ok := err.(*AppError); ok {
|
||||
return appErr.Code == ErrNotFound || appErr.Code == ErrApplicationNotFound ||
|
||||
appErr.Code == ErrTokenNotFound || appErr.Code == ErrPermissionNotFound
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsValidationError checks if an error is a validation error
|
||||
func IsValidationError(err error) bool {
|
||||
if appErr, ok := err.(*AppError); ok {
|
||||
return appErr.Code == ErrValidationFailed || appErr.Code == ErrInvalidInput ||
|
||||
appErr.Code == ErrMissingField || appErr.Code == ErrInvalidFormat
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsAuthenticationError checks if an error is an authentication error
|
||||
func IsAuthenticationError(err error) bool {
|
||||
if appErr, ok := err.(*AppError); ok {
|
||||
return appErr.Code == ErrUnauthorized || appErr.Code == ErrInvalidToken ||
|
||||
appErr.Code == ErrTokenExpired || appErr.Code == ErrInvalidCredentials
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsAuthorizationError checks if an error is an authorization error
|
||||
func IsAuthorizationError(err error) bool {
|
||||
if appErr, ok := err.(*AppError); ok {
|
||||
return appErr.Code == ErrForbidden || appErr.Code == ErrInsufficientPermissions
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsConflictError checks if an error is a conflict error
|
||||
func IsConflictError(err error) bool {
|
||||
if appErr, ok := err.(*AppError); ok {
|
||||
return appErr.Code == ErrAlreadyExists || appErr.Code == ErrConflict
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsInternalError checks if an error is an internal server error
|
||||
func IsInternalError(err error) bool {
|
||||
if appErr, ok := err.(*AppError); ok {
|
||||
return appErr.Code == ErrInternal || appErr.Code == ErrDatabase ||
|
||||
appErr.Code == ErrExternal || appErr.Code == ErrTokenCreationFailed
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsConfigurationError checks if an error is a configuration error
|
||||
func IsConfigurationError(err error) bool {
|
||||
if appErr, ok := err.(*AppError); ok {
|
||||
// Configuration errors are typically mapped to internal errors
|
||||
return appErr.Code == ErrInternal && appErr.Message != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
352
internal/handlers/saml.go
Normal file
352
internal/handlers/saml.go
Normal file
@ -0,0 +1,352 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
"github.com/kms/api-key-service/internal/services"
|
||||
)
|
||||
|
||||
// SAMLHandler handles SAML authentication endpoints
|
||||
type SAMLHandler struct {
|
||||
samlProvider *auth.SAMLProvider
|
||||
sessionService services.SessionService
|
||||
authService services.AuthenticationService
|
||||
tokenService services.TokenService
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewSAMLHandler creates a new SAML handler
|
||||
func NewSAMLHandler(
|
||||
config config.ConfigProvider,
|
||||
sessionService services.SessionService,
|
||||
authService services.AuthenticationService,
|
||||
tokenService services.TokenService,
|
||||
logger *zap.Logger,
|
||||
) (*SAMLHandler, error) {
|
||||
samlProvider, err := auth.NewSAMLProvider(config, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SAMLHandler{
|
||||
samlProvider: samlProvider,
|
||||
sessionService: sessionService,
|
||||
authService: authService,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RegisterRoutes registers SAML routes
|
||||
func (h *SAMLHandler) RegisterRoutes(router *mux.Router) {
|
||||
// SAML endpoints
|
||||
router.HandleFunc("/auth/saml/login", h.InitiateSAMLLogin).Methods("GET")
|
||||
router.HandleFunc("/auth/saml/acs", h.HandleSAMLResponse).Methods("POST")
|
||||
router.HandleFunc("/auth/saml/metadata", h.GetServiceProviderMetadata).Methods("GET")
|
||||
router.HandleFunc("/auth/saml/slo", h.HandleSingleLogout).Methods("GET", "POST")
|
||||
}
|
||||
|
||||
// InitiateSAMLLogin initiates SAML authentication
|
||||
func (h *SAMLHandler) InitiateSAMLLogin(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.config.GetBool("SAML_ENABLED") {
|
||||
h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
|
||||
return
|
||||
}
|
||||
|
||||
// Get query parameters
|
||||
appID := r.URL.Query().Get("app_id")
|
||||
redirectURL := r.URL.Query().Get("redirect_url")
|
||||
|
||||
if appID == "" {
|
||||
h.writeErrorResponse(w, errors.NewValidationError("app_id parameter is required"))
|
||||
return
|
||||
}
|
||||
|
||||
// Generate relay state with app_id and redirect_url
|
||||
relayState := appID
|
||||
if redirectURL != "" {
|
||||
relayState += "|" + redirectURL
|
||||
}
|
||||
|
||||
h.logger.Debug("Initiating SAML login",
|
||||
zap.String("app_id", appID),
|
||||
zap.String("redirect_url", redirectURL))
|
||||
|
||||
// Generate SAML authentication request
|
||||
authURL, requestID, err := h.samlProvider.GenerateAuthRequest(r.Context(), relayState)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate SAML auth request", zap.Error(err))
|
||||
h.writeErrorResponse(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Store request ID in session/cache for validation
|
||||
// In production, you should store this securely
|
||||
h.logger.Debug("Generated SAML auth request",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("auth_url", authURL))
|
||||
|
||||
// Redirect to IdP
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// HandleSAMLResponse handles SAML assertion consumer service (ACS)
|
||||
func (h *SAMLHandler) HandleSAMLResponse(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.config.GetBool("SAML_ENABLED") {
|
||||
h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("Handling SAML response")
|
||||
|
||||
// Parse form data
|
||||
if err := r.ParseForm(); err != nil {
|
||||
h.writeErrorResponse(w, errors.NewValidationError("Failed to parse form data").WithInternal(err))
|
||||
return
|
||||
}
|
||||
|
||||
samlResponse := r.FormValue("SAMLResponse")
|
||||
relayState := r.FormValue("RelayState")
|
||||
|
||||
if samlResponse == "" {
|
||||
h.writeErrorResponse(w, errors.NewValidationError("SAMLResponse is required"))
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("Processing SAML response", zap.String("relay_state", relayState))
|
||||
|
||||
// Process SAML response
|
||||
// In production, you should retrieve and validate the original request ID
|
||||
authContext, err := h.samlProvider.ProcessSAMLResponse(r.Context(), samlResponse, "")
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to process SAML response", zap.Error(err))
|
||||
h.writeErrorResponse(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse relay state to get app_id and redirect_url
|
||||
appID, redirectURL := h.parseRelayState(relayState)
|
||||
if appID == "" {
|
||||
h.writeErrorResponse(w, errors.NewValidationError("Invalid relay state: missing app_id"))
|
||||
return
|
||||
}
|
||||
|
||||
// Create user session
|
||||
sessionReq := &domain.CreateSessionRequest{
|
||||
UserID: authContext.UserID,
|
||||
AppID: appID,
|
||||
SessionType: domain.SessionTypeWeb,
|
||||
IPAddress: h.getClientIP(r),
|
||||
UserAgent: r.UserAgent(),
|
||||
ExpiresAt: time.Now().Add(8 * time.Hour), // 8 hour session
|
||||
Permissions: authContext.Permissions,
|
||||
Claims: authContext.Claims,
|
||||
}
|
||||
|
||||
session, err := h.sessionService.CreateSession(r.Context(), sessionReq)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to create session", zap.Error(err))
|
||||
h.writeErrorResponse(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate JWT token for the session using the existing token service
|
||||
userToken := &domain.UserToken{
|
||||
AppID: appID,
|
||||
UserID: authContext.UserID,
|
||||
Permissions: authContext.Permissions,
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: session.ExpiresAt,
|
||||
MaxValidAt: session.ExpiresAt,
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: authContext.Claims,
|
||||
}
|
||||
|
||||
tokenString, err := h.authService.GenerateJWTToken(r.Context(), userToken)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to create JWT token", zap.Error(err))
|
||||
h.writeErrorResponse(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("SAML authentication successful",
|
||||
zap.String("user_id", authContext.UserID),
|
||||
zap.String("session_id", session.ID.String()))
|
||||
|
||||
// If redirect URL is provided, redirect with token
|
||||
if redirectURL != "" {
|
||||
// Add token as query parameter or fragment
|
||||
redirectURL += "?token=" + tokenString
|
||||
http.Redirect(w, r, redirectURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, return JSON response
|
||||
response := map[string]interface{}{
|
||||
"success": true,
|
||||
"token": tokenString,
|
||||
"user": map[string]interface{}{
|
||||
"id": authContext.UserID,
|
||||
"email": authContext.Claims["email"],
|
||||
"name": authContext.Claims["name"],
|
||||
},
|
||||
"session_id": session.ID.String(),
|
||||
"expires_at": session.ExpiresAt,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// GetServiceProviderMetadata returns SP metadata XML
|
||||
func (h *SAMLHandler) GetServiceProviderMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.config.GetBool("SAML_ENABLED") {
|
||||
h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("Generating SP metadata")
|
||||
|
||||
metadata, err := h.samlProvider.GenerateServiceProviderMetadata()
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate SP metadata", zap.Error(err))
|
||||
h.writeErrorResponse(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.Write([]byte(metadata))
|
||||
}
|
||||
|
||||
// HandleSingleLogout handles SAML single logout
|
||||
func (h *SAMLHandler) HandleSingleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.config.GetBool("SAML_ENABLED") {
|
||||
h.writeErrorResponse(w, errors.NewConfigurationError("SAML authentication is not enabled"))
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("Handling SAML single logout")
|
||||
|
||||
// Get session ID from query parameter or form
|
||||
sessionID := r.URL.Query().Get("session_id")
|
||||
if sessionID == "" && r.Method == "POST" {
|
||||
r.ParseForm()
|
||||
sessionID = r.FormValue("session_id")
|
||||
}
|
||||
|
||||
if sessionID != "" {
|
||||
// Revoke specific session
|
||||
h.logger.Debug("Revoking session", zap.String("session_id", sessionID))
|
||||
// Implementation would depend on how you store session IDs
|
||||
// For now, we'll just log it
|
||||
}
|
||||
|
||||
// In a full implementation, you would:
|
||||
// 1. Parse the SAML LogoutRequest
|
||||
// 2. Validate the request
|
||||
// 3. Revoke the user's sessions
|
||||
// 4. Generate a LogoutResponse
|
||||
// 5. Redirect back to the IdP
|
||||
|
||||
// For now, return a simple success response
|
||||
response := map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Logout successful",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// parseRelayState parses the relay state to extract app_id and redirect_url
|
||||
func (h *SAMLHandler) parseRelayState(relayState string) (appID, redirectURL string) {
|
||||
if relayState == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// RelayState format: "app_id|redirect_url" or just "app_id"
|
||||
parts := []string{relayState}
|
||||
if len(relayState) > 0 && relayState[0] != '|' {
|
||||
// Split on first pipe character
|
||||
for i, char := range relayState {
|
||||
if char == '|' {
|
||||
parts = []string{relayState[:i], relayState[i+1:]}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
appID = parts[0]
|
||||
if len(parts) > 1 {
|
||||
redirectURL = parts[1]
|
||||
}
|
||||
|
||||
return appID, redirectURL
|
||||
}
|
||||
|
||||
// getClientIP extracts the client IP address from the request
|
||||
func (h *SAMLHandler) getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header first
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP if multiple are present
|
||||
if idx := len(xff); idx > 0 {
|
||||
for i, char := range xff {
|
||||
if char == ',' {
|
||||
return xff[:i]
|
||||
}
|
||||
}
|
||||
return xff
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
// writeErrorResponse writes an error response
|
||||
func (h *SAMLHandler) writeErrorResponse(w http.ResponseWriter, err error) {
|
||||
var statusCode int
|
||||
var errorCode string
|
||||
|
||||
switch {
|
||||
case errors.IsValidationError(err):
|
||||
statusCode = http.StatusBadRequest
|
||||
errorCode = "VALIDATION_ERROR"
|
||||
case errors.IsAuthenticationError(err):
|
||||
statusCode = http.StatusUnauthorized
|
||||
errorCode = "AUTHENTICATION_ERROR"
|
||||
case errors.IsConfigurationError(err):
|
||||
statusCode = http.StatusServiceUnavailable
|
||||
errorCode = "CONFIGURATION_ERROR"
|
||||
default:
|
||||
statusCode = http.StatusInternalServerError
|
||||
errorCode = "INTERNAL_ERROR"
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"success": false,
|
||||
"error": map[string]interface{}{
|
||||
"code": errorCode,
|
||||
"message": err.Error(),
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
@ -2,7 +2,6 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
@ -14,7 +13,6 @@ import (
|
||||
|
||||
"github.com/kms/api-key-service/internal/cache"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// SecurityMiddleware provides various security features
|
||||
@ -408,8 +406,6 @@ func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool {
|
||||
|
||||
// GetSecurityMetrics returns security-related metrics
|
||||
func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} {
|
||||
ctx := context.Background()
|
||||
|
||||
// This is a simplified version - in production you'd want more comprehensive metrics
|
||||
metrics := map[string]interface{}{
|
||||
"active_rate_limiters": len(s.rateLimiters),
|
||||
|
||||
@ -104,6 +104,57 @@ type GrantedPermissionRepository interface {
|
||||
HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error)
|
||||
}
|
||||
|
||||
// SessionRepository defines the interface for user session data operations
|
||||
type SessionRepository interface {
|
||||
// Create creates a new user session
|
||||
Create(ctx context.Context, session *domain.UserSession) error
|
||||
|
||||
// GetByID retrieves a session by its ID
|
||||
GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error)
|
||||
|
||||
// GetByUserID retrieves all sessions for a user
|
||||
GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error)
|
||||
|
||||
// GetByUserAndApp retrieves sessions for a specific user and application
|
||||
GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error)
|
||||
|
||||
// GetActiveByUserID retrieves all active sessions for a user
|
||||
GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error)
|
||||
|
||||
// List retrieves sessions with filtering and pagination
|
||||
List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error)
|
||||
|
||||
// Update updates an existing session
|
||||
Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error
|
||||
|
||||
// UpdateActivity updates the last activity timestamp for a session
|
||||
UpdateActivity(ctx context.Context, sessionID uuid.UUID) error
|
||||
|
||||
// Revoke revokes a session
|
||||
Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error
|
||||
|
||||
// RevokeAllByUser revokes all sessions for a user
|
||||
RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error
|
||||
|
||||
// RevokeAllByUserAndApp revokes all sessions for a user and application
|
||||
RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error
|
||||
|
||||
// ExpireOldSessions marks expired sessions as expired
|
||||
ExpireOldSessions(ctx context.Context) (int, error)
|
||||
|
||||
// DeleteExpiredSessions removes expired sessions older than the specified duration
|
||||
DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error)
|
||||
|
||||
// Exists checks if a session exists
|
||||
Exists(ctx context.Context, sessionID uuid.UUID) (bool, error)
|
||||
|
||||
// GetSessionCount returns the total number of sessions for a user
|
||||
GetSessionCount(ctx context.Context, userID string) (int, error)
|
||||
|
||||
// GetActiveSessionCount returns the number of active sessions for a user
|
||||
GetActiveSessionCount(ctx context.Context, userID string) (int, error)
|
||||
}
|
||||
|
||||
// DatabaseProvider defines the interface for database operations
|
||||
type DatabaseProvider interface {
|
||||
// GetDB returns the underlying database connection
|
||||
|
||||
624
internal/repository/postgres/session_repository.go
Normal file
624
internal/repository/postgres/session_repository.go
Normal file
@ -0,0 +1,624 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
)
|
||||
|
||||
// sessionRepository implements the SessionRepository interface
|
||||
type sessionRepository struct {
|
||||
db *sqlx.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewSessionRepository creates a new session repository
|
||||
func NewSessionRepository(db *sqlx.DB, logger *zap.Logger) repository.SessionRepository {
|
||||
return &sessionRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create creates a new user session
|
||||
func (r *sessionRepository) Create(ctx context.Context, session *domain.UserSession) error {
|
||||
r.logger.Debug("Creating new session",
|
||||
zap.String("user_id", session.UserID),
|
||||
zap.String("app_id", session.AppID),
|
||||
zap.String("session_type", string(session.SessionType)))
|
||||
|
||||
// Generate ID if not provided
|
||||
if session.ID == uuid.Nil {
|
||||
session.ID = uuid.New()
|
||||
}
|
||||
|
||||
// Set timestamps
|
||||
now := time.Now()
|
||||
session.CreatedAt = now
|
||||
session.UpdatedAt = now
|
||||
session.LastActivity = now
|
||||
|
||||
// Serialize metadata
|
||||
metadataJSON, err := json.Marshal(session.Metadata)
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to serialize session metadata").WithInternal(err)
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO user_sessions (
|
||||
id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at, metadata
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
|
||||
)`
|
||||
|
||||
_, err = r.db.ExecContext(ctx, query,
|
||||
session.ID,
|
||||
session.UserID,
|
||||
session.AppID,
|
||||
session.SessionType,
|
||||
session.Status,
|
||||
session.AccessToken,
|
||||
session.RefreshToken,
|
||||
session.IDToken,
|
||||
session.IPAddress,
|
||||
session.UserAgent,
|
||||
session.LastActivity,
|
||||
session.ExpiresAt,
|
||||
session.CreatedAt,
|
||||
session.UpdatedAt,
|
||||
metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to create session", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to create session").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a session by its ID
|
||||
func (r *sessionRepository) GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting session by ID", zap.String("session_id", sessionID.String()))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE id = $1`
|
||||
|
||||
var session domain.UserSession
|
||||
var metadataJSON []byte
|
||||
var revokedAt sql.NullTime
|
||||
var revokedBy sql.NullString
|
||||
|
||||
err := r.db.QueryRowContext(ctx, query, sessionID).Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.AppID,
|
||||
&session.SessionType,
|
||||
&session.Status,
|
||||
&session.AccessToken,
|
||||
&session.RefreshToken,
|
||||
&session.IDToken,
|
||||
&session.IPAddress,
|
||||
&session.UserAgent,
|
||||
&session.LastActivity,
|
||||
&session.ExpiresAt,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
&revokedAt,
|
||||
&revokedBy,
|
||||
&metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
r.logger.Error("Failed to get session by ID", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to retrieve session").WithInternal(err)
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if revokedAt.Valid {
|
||||
session.RevokedAt = &revokedAt.Time
|
||||
}
|
||||
if revokedBy.Valid {
|
||||
session.RevokedBy = &revokedBy.String
|
||||
}
|
||||
|
||||
// Deserialize metadata
|
||||
if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
|
||||
r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
|
||||
session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
|
||||
}
|
||||
|
||||
r.logger.Debug("Session retrieved successfully", zap.String("session_id", sessionID.String()))
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// GetByUserID retrieves all sessions for a user
|
||||
func (r *sessionRepository) GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting sessions by user ID", zap.String("user_id", userID))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC`
|
||||
|
||||
return r.scanSessions(ctx, query, userID)
|
||||
}
|
||||
|
||||
// GetByUserAndApp retrieves sessions for a specific user and application
|
||||
func (r *sessionRepository) GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting sessions by user and app",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1 AND app_id = $2
|
||||
ORDER BY created_at DESC`
|
||||
|
||||
return r.scanSessions(ctx, query, userID, appID)
|
||||
}
|
||||
|
||||
// GetActiveByUserID retrieves all active sessions for a user
|
||||
func (r *sessionRepository) GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting active sessions by user ID", zap.String("user_id", userID))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1 AND status = $2 AND expires_at > NOW()
|
||||
ORDER BY last_activity DESC`
|
||||
|
||||
return r.scanSessions(ctx, query, userID, domain.SessionStatusActive)
|
||||
}
|
||||
|
||||
// List retrieves sessions with filtering and pagination
|
||||
func (r *sessionRepository) List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
|
||||
r.logger.Debug("Listing sessions with filters",
|
||||
zap.String("user_id", req.UserID),
|
||||
zap.String("app_id", req.AppID),
|
||||
zap.Int("limit", req.Limit),
|
||||
zap.Int("offset", req.Offset))
|
||||
|
||||
// Build WHERE clause dynamically
|
||||
whereClause := "WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
argIndex := 1
|
||||
|
||||
if req.UserID != "" {
|
||||
whereClause += fmt.Sprintf(" AND user_id = $%d", argIndex)
|
||||
args = append(args, req.UserID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.AppID != "" {
|
||||
whereClause += fmt.Sprintf(" AND app_id = $%d", argIndex)
|
||||
args = append(args, req.AppID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
whereClause += fmt.Sprintf(" AND status = $%d", argIndex)
|
||||
args = append(args, *req.Status)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.SessionType != nil {
|
||||
whereClause += fmt.Sprintf(" AND session_type = $%d", argIndex)
|
||||
args = append(args, *req.SessionType)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.TenantID != "" {
|
||||
whereClause += fmt.Sprintf(" AND metadata->>'tenant_id' = $%d", argIndex)
|
||||
args = append(args, req.TenantID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
// Get total count
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM user_sessions %s", whereClause)
|
||||
var total int
|
||||
err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to get session count", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to count sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
// Get sessions with pagination
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
%s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d`, whereClause, argIndex, argIndex+1)
|
||||
|
||||
args = append(args, req.Limit, req.Offset)
|
||||
sessions, err := r.scanSessions(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &domain.SessionListResponse{
|
||||
Sessions: sessions,
|
||||
Total: total,
|
||||
Limit: req.Limit,
|
||||
Offset: req.Offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Update updates an existing session
|
||||
func (r *sessionRepository) Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
|
||||
r.logger.Debug("Updating session", zap.String("session_id", sessionID.String()))
|
||||
|
||||
// Build UPDATE clause dynamically
|
||||
setParts := []string{"updated_at = NOW()"}
|
||||
args := []interface{}{}
|
||||
argIndex := 1
|
||||
|
||||
if updates.Status != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("status = $%d", argIndex))
|
||||
args = append(args, *updates.Status)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.LastActivity != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("last_activity = $%d", argIndex))
|
||||
args = append(args, *updates.LastActivity)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.ExpiresAt != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("expires_at = $%d", argIndex))
|
||||
args = append(args, *updates.ExpiresAt)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.IPAddress != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("ip_address = $%d", argIndex))
|
||||
args = append(args, *updates.IPAddress)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.UserAgent != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("user_agent = $%d", argIndex))
|
||||
args = append(args, *updates.UserAgent)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if len(setParts) == 1 {
|
||||
return errors.NewValidationError("No fields to update")
|
||||
}
|
||||
|
||||
// Build the complete query
|
||||
setClause := fmt.Sprintf("%s", setParts[0])
|
||||
for i := 1; i < len(setParts); i++ {
|
||||
setClause += fmt.Sprintf(", %s", setParts[i])
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE user_sessions SET %s WHERE id = $%d", setClause, argIndex)
|
||||
args = append(args, sessionID)
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to update session", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to update session").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
|
||||
r.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateActivity updates the last activity timestamp for a session
|
||||
func (r *sessionRepository) UpdateActivity(ctx context.Context, sessionID uuid.UUID) error {
|
||||
r.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))
|
||||
|
||||
query := `UPDATE user_sessions SET last_activity = NOW(), updated_at = NOW() WHERE id = $1`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, sessionID)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to update session activity", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to update session activity").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Revoke revokes a session
|
||||
func (r *sessionRepository) Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
|
||||
r.logger.Debug("Revoking session",
|
||||
zap.String("session_id", sessionID.String()),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
|
||||
WHERE id = $3`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, sessionID)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to revoke session", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to revoke session").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
|
||||
r.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllByUser revokes all sessions for a user
|
||||
func (r *sessionRepository) RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error {
|
||||
r.logger.Debug("Revoking all sessions for user",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
|
||||
WHERE user_id = $3 AND status = $4`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, domain.SessionStatusActive)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to revoke user sessions", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to revoke user sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("User sessions revoked",
|
||||
zap.String("user_id", userID),
|
||||
zap.Int64("sessions_revoked", rowsAffected))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllByUserAndApp revokes all sessions for a user and application
|
||||
func (r *sessionRepository) RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error {
|
||||
r.logger.Debug("Revoking all sessions for user and app",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
|
||||
WHERE user_id = $3 AND app_id = $4 AND status = $5`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, appID, domain.SessionStatusActive)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to revoke user app sessions", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to revoke user app sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("User app sessions revoked",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID),
|
||||
zap.Int64("sessions_revoked", rowsAffected))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExpireOldSessions marks expired sessions as expired
|
||||
func (r *sessionRepository) ExpireOldSessions(ctx context.Context) (int, error) {
|
||||
r.logger.Debug("Expiring old sessions")
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, updated_at = NOW()
|
||||
WHERE expires_at < NOW() AND status = $2`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, domain.SessionStatusActive)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to expire old sessions", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to expire old sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("Old sessions expired", zap.Int64("sessions_expired", rowsAffected))
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
// DeleteExpiredSessions removes expired sessions older than the specified duration
|
||||
func (r *sessionRepository) DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error) {
|
||||
r.logger.Debug("Deleting expired sessions", zap.Duration("older_than", olderThan))
|
||||
|
||||
cutoffTime := time.Now().Add(-olderThan)
|
||||
query := `DELETE FROM user_sessions WHERE status = $1 AND updated_at < $2`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, cutoffTime)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to delete expired sessions", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to delete expired sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("Expired sessions deleted", zap.Int64("sessions_deleted", rowsAffected))
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
// Exists checks if a session exists
|
||||
func (r *sessionRepository) Exists(ctx context.Context, sessionID uuid.UUID) (bool, error) {
|
||||
r.logger.Debug("Checking if session exists", zap.String("session_id", sessionID.String()))
|
||||
|
||||
query := `SELECT EXISTS(SELECT 1 FROM user_sessions WHERE id = $1)`
|
||||
|
||||
var exists bool
|
||||
err := r.db.QueryRowContext(ctx, query, sessionID).Scan(&exists)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to check session existence", zap.Error(err))
|
||||
return false, errors.NewInternalError("Failed to check session existence").WithInternal(err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// GetSessionCount returns the total number of sessions for a user
|
||||
func (r *sessionRepository) GetSessionCount(ctx context.Context, userID string) (int, error) {
|
||||
r.logger.Debug("Getting session count for user", zap.String("user_id", userID))
|
||||
|
||||
query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1`
|
||||
|
||||
var count int
|
||||
err := r.db.QueryRowContext(ctx, query, userID).Scan(&count)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to get session count", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to get session count").WithInternal(err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetActiveSessionCount returns the number of active sessions for a user
|
||||
func (r *sessionRepository) GetActiveSessionCount(ctx context.Context, userID string) (int, error) {
|
||||
r.logger.Debug("Getting active session count for user", zap.String("user_id", userID))
|
||||
|
||||
query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1 AND status = $2 AND expires_at > NOW()`
|
||||
|
||||
var count int
|
||||
err := r.db.QueryRowContext(ctx, query, userID, domain.SessionStatusActive).Scan(&count)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to get active session count", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to get active session count").WithInternal(err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// scanSessions is a helper method to scan multiple sessions from query results
|
||||
func (r *sessionRepository) scanSessions(ctx context.Context, query string, args ...interface{}) ([]*domain.UserSession, error) {
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to execute session query", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to retrieve sessions").WithInternal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []*domain.UserSession
|
||||
for rows.Next() {
|
||||
var session domain.UserSession
|
||||
var metadataJSON []byte
|
||||
var revokedAt sql.NullTime
|
||||
var revokedBy sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.AppID,
|
||||
&session.SessionType,
|
||||
&session.Status,
|
||||
&session.AccessToken,
|
||||
&session.RefreshToken,
|
||||
&session.IDToken,
|
||||
&session.IPAddress,
|
||||
&session.UserAgent,
|
||||
&session.LastActivity,
|
||||
&session.ExpiresAt,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
&revokedAt,
|
||||
&revokedBy,
|
||||
&metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to scan session row", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to scan session data").WithInternal(err)
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if revokedAt.Valid {
|
||||
session.RevokedAt = &revokedAt.Time
|
||||
}
|
||||
if revokedBy.Valid {
|
||||
session.RevokedBy = &revokedBy.String
|
||||
}
|
||||
|
||||
// Deserialize metadata
|
||||
if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
|
||||
r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
|
||||
session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
|
||||
}
|
||||
|
||||
sessions = append(sessions, &session)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
r.logger.Error("Error iterating session rows", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to iterate session results").WithInternal(err)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
@ -67,3 +67,54 @@ type AuthenticationService interface {
|
||||
// RefreshJWTToken refreshes an existing JWT token
|
||||
RefreshJWTToken(ctx context.Context, tokenString string, newExpiration time.Time) (string, error)
|
||||
}
|
||||
|
||||
// SessionService defines the interface for session management business logic
|
||||
type SessionService interface {
|
||||
// CreateSession creates a new user session
|
||||
CreateSession(ctx context.Context, req *domain.CreateSessionRequest) (*domain.UserSession, error)
|
||||
|
||||
// GetSession retrieves a session by its ID
|
||||
GetSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error)
|
||||
|
||||
// GetUserSessions retrieves all sessions for a user
|
||||
GetUserSessions(ctx context.Context, userID string) ([]*domain.UserSession, error)
|
||||
|
||||
// GetUserAppSessions retrieves sessions for a specific user and application
|
||||
GetUserAppSessions(ctx context.Context, userID, appID string) ([]*domain.UserSession, error)
|
||||
|
||||
// GetActiveSessions retrieves all active sessions for a user
|
||||
GetActiveSessions(ctx context.Context, userID string) ([]*domain.UserSession, error)
|
||||
|
||||
// ListSessions retrieves sessions with filtering and pagination
|
||||
ListSessions(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error)
|
||||
|
||||
// UpdateSession updates an existing session
|
||||
UpdateSession(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error
|
||||
|
||||
// UpdateSessionActivity updates the last activity timestamp for a session
|
||||
UpdateSessionActivity(ctx context.Context, sessionID uuid.UUID) error
|
||||
|
||||
// RevokeSession revokes a specific session
|
||||
RevokeSession(ctx context.Context, sessionID uuid.UUID, revokedBy string) error
|
||||
|
||||
// RevokeUserSessions revokes all sessions for a user
|
||||
RevokeUserSessions(ctx context.Context, userID string, revokedBy string) error
|
||||
|
||||
// RevokeUserAppSessions revokes all sessions for a user and application
|
||||
RevokeUserAppSessions(ctx context.Context, userID, appID string, revokedBy string) error
|
||||
|
||||
// ValidateSession validates if a session is active and valid
|
||||
ValidateSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error)
|
||||
|
||||
// RefreshSession refreshes a session's expiration time
|
||||
RefreshSession(ctx context.Context, sessionID uuid.UUID, newExpiration time.Time) error
|
||||
|
||||
// CleanupExpiredSessions marks expired sessions as expired and optionally deletes old ones
|
||||
CleanupExpiredSessions(ctx context.Context, deleteOlderThan *time.Duration) (expired int, deleted int, err error)
|
||||
|
||||
// GetSessionStats returns session statistics for a user
|
||||
GetSessionStats(ctx context.Context, userID string) (total int, active int, err error)
|
||||
|
||||
// CreateOAuth2Session creates a session from OAuth2 authentication flow
|
||||
CreateOAuth2Session(ctx context.Context, userID, appID string, tokenResponse *domain.TokenResponse, userInfo *domain.UserInfo, sessionType domain.SessionType, ipAddress, userAgent string) (*domain.UserSession, error)
|
||||
}
|
||||
|
||||
414
internal/services/session_service.go
Normal file
414
internal/services/session_service.go
Normal file
@ -0,0 +1,414 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
)
|
||||
|
||||
// sessionService implements the SessionService interface
|
||||
type sessionService struct {
|
||||
sessionRepo repository.SessionRepository
|
||||
appRepo repository.ApplicationRepository
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewSessionService creates a new session service
|
||||
func NewSessionService(
|
||||
sessionRepo repository.SessionRepository,
|
||||
appRepo repository.ApplicationRepository,
|
||||
config config.ConfigProvider,
|
||||
logger *zap.Logger,
|
||||
) SessionService {
|
||||
return &sessionService{
|
||||
sessionRepo: sessionRepo,
|
||||
appRepo: appRepo,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSession creates a new user session
|
||||
func (s *sessionService) CreateSession(ctx context.Context, req *domain.CreateSessionRequest) (*domain.UserSession, error) {
|
||||
s.logger.Debug("Creating new session",
|
||||
zap.String("user_id", req.UserID),
|
||||
zap.String("app_id", req.AppID),
|
||||
zap.String("session_type", string(req.SessionType)))
|
||||
|
||||
// Validate application exists
|
||||
app, err := s.appRepo.GetByID(ctx, req.AppID)
|
||||
if err != nil {
|
||||
if errors.IsNotFound(err) {
|
||||
return nil, errors.NewValidationError("Application not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if application supports user tokens
|
||||
supportsUser := false
|
||||
for _, appType := range app.Type {
|
||||
if appType == domain.ApplicationTypeUser {
|
||||
supportsUser = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !supportsUser {
|
||||
return nil, errors.NewValidationError("Application does not support user sessions")
|
||||
}
|
||||
|
||||
// Create session object
|
||||
session := &domain.UserSession{
|
||||
ID: uuid.New(),
|
||||
UserID: req.UserID,
|
||||
AppID: req.AppID,
|
||||
SessionType: req.SessionType,
|
||||
Status: domain.SessionStatusActive,
|
||||
IPAddress: req.IPAddress,
|
||||
UserAgent: req.UserAgent,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
Metadata: domain.SessionMetadata{
|
||||
TenantID: req.TenantID,
|
||||
Permissions: req.Permissions,
|
||||
Claims: req.Claims,
|
||||
LoginMethod: "oauth2",
|
||||
},
|
||||
}
|
||||
|
||||
// Create session in repository
|
||||
if err := s.sessionRepo.Create(ctx, session); err != nil {
|
||||
s.logger.Error("Failed to create session", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by its ID
|
||||
func (s *sessionService) GetSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
|
||||
s.logger.Debug("Getting session", zap.String("session_id", sessionID.String()))
|
||||
|
||||
session, err := s.sessionRepo.GetByID(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// GetUserSessions retrieves all sessions for a user
|
||||
func (s *sessionService) GetUserSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) {
|
||||
s.logger.Debug("Getting user sessions", zap.String("user_id", userID))
|
||||
|
||||
sessions, err := s.sessionRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// GetUserAppSessions retrieves sessions for a specific user and application
|
||||
func (s *sessionService) GetUserAppSessions(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
|
||||
s.logger.Debug("Getting user app sessions",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID))
|
||||
|
||||
sessions, err := s.sessionRepo.GetByUserAndApp(ctx, userID, appID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// GetActiveSessions retrieves all active sessions for a user
|
||||
func (s *sessionService) GetActiveSessions(ctx context.Context, userID string) ([]*domain.UserSession, error) {
|
||||
s.logger.Debug("Getting active sessions", zap.String("user_id", userID))
|
||||
|
||||
sessions, err := s.sessionRepo.GetActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// ListSessions retrieves sessions with filtering and pagination
|
||||
func (s *sessionService) ListSessions(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
|
||||
s.logger.Debug("Listing sessions",
|
||||
zap.String("user_id", req.UserID),
|
||||
zap.String("app_id", req.AppID),
|
||||
zap.Int("limit", req.Limit),
|
||||
zap.Int("offset", req.Offset))
|
||||
|
||||
// Set default pagination if not provided
|
||||
if req.Limit <= 0 {
|
||||
req.Limit = 50
|
||||
}
|
||||
if req.Limit > 100 {
|
||||
req.Limit = 100
|
||||
}
|
||||
|
||||
response, err := s.sessionRepo.List(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// UpdateSession updates an existing session
|
||||
func (s *sessionService) UpdateSession(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
|
||||
s.logger.Debug("Updating session", zap.String("session_id", sessionID.String()))
|
||||
|
||||
// Validate session exists
|
||||
_, err := s.sessionRepo.GetByID(ctx, sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update session
|
||||
if err := s.sessionRepo.Update(ctx, sessionID, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSessionActivity updates the last activity timestamp for a session
|
||||
func (s *sessionService) UpdateSessionActivity(ctx context.Context, sessionID uuid.UUID) error {
|
||||
s.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))
|
||||
|
||||
if err := s.sessionRepo.UpdateActivity(ctx, sessionID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeSession revokes a specific session
|
||||
func (s *sessionService) RevokeSession(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
|
||||
s.logger.Debug("Revoking session",
|
||||
zap.String("session_id", sessionID.String()),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
// Validate session exists and is active
|
||||
session, err := s.sessionRepo.GetByID(ctx, sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if session.Status != domain.SessionStatusActive {
|
||||
return errors.NewValidationError("Session is not active")
|
||||
}
|
||||
|
||||
// Revoke session
|
||||
if err := s.sessionRepo.Revoke(ctx, sessionID, revokedBy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeUserSessions revokes all sessions for a user
|
||||
func (s *sessionService) RevokeUserSessions(ctx context.Context, userID string, revokedBy string) error {
|
||||
s.logger.Debug("Revoking user sessions",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
if err := s.sessionRepo.RevokeAllByUser(ctx, userID, revokedBy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Debug("User sessions revoked successfully", zap.String("user_id", userID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeUserAppSessions revokes all sessions for a user and application
|
||||
func (s *sessionService) RevokeUserAppSessions(ctx context.Context, userID, appID string, revokedBy string) error {
|
||||
s.logger.Debug("Revoking user app sessions",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
if err := s.sessionRepo.RevokeAllByUserAndApp(ctx, userID, appID, revokedBy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Debug("User app sessions revoked successfully",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSession validates if a session is active and valid
|
||||
func (s *sessionService) ValidateSession(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
|
||||
s.logger.Debug("Validating session", zap.String("session_id", sessionID.String()))
|
||||
|
||||
session, err := s.sessionRepo.GetByID(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if session is active
|
||||
if !session.IsActive() {
|
||||
if session.IsExpired() {
|
||||
return nil, errors.NewAuthenticationError("Session has expired")
|
||||
}
|
||||
if session.IsRevoked() {
|
||||
return nil, errors.NewAuthenticationError("Session has been revoked")
|
||||
}
|
||||
return nil, errors.NewAuthenticationError("Session is not active")
|
||||
}
|
||||
|
||||
// Update last activity
|
||||
if err := s.sessionRepo.UpdateActivity(ctx, sessionID); err != nil {
|
||||
s.logger.Warn("Failed to update session activity", zap.Error(err))
|
||||
// Don't fail validation if we can't update activity
|
||||
}
|
||||
|
||||
s.logger.Debug("Session validated successfully", zap.String("session_id", sessionID.String()))
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// RefreshSession refreshes a session's expiration time
|
||||
func (s *sessionService) RefreshSession(ctx context.Context, sessionID uuid.UUID, newExpiration time.Time) error {
|
||||
s.logger.Debug("Refreshing session",
|
||||
zap.String("session_id", sessionID.String()),
|
||||
zap.Time("new_expiration", newExpiration))
|
||||
|
||||
// Validate session exists and is active
|
||||
session, err := s.sessionRepo.GetByID(ctx, sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !session.IsActive() {
|
||||
return errors.NewValidationError("Cannot refresh inactive session")
|
||||
}
|
||||
|
||||
// Update expiration
|
||||
updates := &domain.UpdateSessionRequest{
|
||||
ExpiresAt: &newExpiration,
|
||||
}
|
||||
|
||||
if err := s.sessionRepo.Update(ctx, sessionID, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Debug("Session refreshed successfully", zap.String("session_id", sessionID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupExpiredSessions marks expired sessions as expired and optionally deletes old ones
|
||||
func (s *sessionService) CleanupExpiredSessions(ctx context.Context, deleteOlderThan *time.Duration) (expired int, deleted int, err error) {
|
||||
s.logger.Debug("Cleaning up expired sessions")
|
||||
|
||||
// Mark expired sessions
|
||||
expired, err = s.sessionRepo.ExpireOldSessions(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to expire old sessions", zap.Error(err))
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Delete old expired sessions if requested
|
||||
if deleteOlderThan != nil {
|
||||
deleted, err = s.sessionRepo.DeleteExpiredSessions(ctx, *deleteOlderThan)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete expired sessions", zap.Error(err))
|
||||
return expired, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Debug("Session cleanup completed",
|
||||
zap.Int("expired", expired),
|
||||
zap.Int("deleted", deleted))
|
||||
|
||||
return expired, deleted, nil
|
||||
}
|
||||
|
||||
// GetSessionStats returns session statistics for a user
|
||||
func (s *sessionService) GetSessionStats(ctx context.Context, userID string) (total int, active int, err error) {
|
||||
s.logger.Debug("Getting session stats", zap.String("user_id", userID))
|
||||
|
||||
total, err = s.sessionRepo.GetSessionCount(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
active, err = s.sessionRepo.GetActiveSessionCount(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return total, active, nil
|
||||
}
|
||||
|
||||
// CreateOAuth2Session creates a session from OAuth2 authentication flow
|
||||
func (s *sessionService) CreateOAuth2Session(ctx context.Context, userID, appID string, tokenResponse *domain.TokenResponse, userInfo *domain.UserInfo, sessionType domain.SessionType, ipAddress, userAgent string) (*domain.UserSession, error) {
|
||||
s.logger.Debug("Creating OAuth2 session",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID),
|
||||
zap.String("session_type", string(sessionType)))
|
||||
|
||||
// Validate application exists
|
||||
app, err := s.appRepo.GetByID(ctx, appID)
|
||||
if err != nil {
|
||||
if errors.IsNotFound(err) {
|
||||
return nil, errors.NewValidationError("Application not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Calculate expiration based on token response
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second)
|
||||
|
||||
// Use application's max token duration if shorter
|
||||
maxExpiration := time.Now().Add(app.MaxTokenDuration)
|
||||
if expiresAt.After(maxExpiration) {
|
||||
expiresAt = maxExpiration
|
||||
}
|
||||
|
||||
// Create session object
|
||||
session := &domain.UserSession{
|
||||
ID: uuid.New(),
|
||||
UserID: userID,
|
||||
AppID: appID,
|
||||
SessionType: sessionType,
|
||||
Status: domain.SessionStatusActive,
|
||||
AccessToken: tokenResponse.AccessToken, // In production, encrypt this
|
||||
RefreshToken: tokenResponse.RefreshToken, // In production, encrypt this
|
||||
IDToken: tokenResponse.IDToken, // In production, encrypt this
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
ExpiresAt: expiresAt,
|
||||
Metadata: domain.SessionMetadata{
|
||||
LoginMethod: "oauth2",
|
||||
Claims: map[string]string{
|
||||
"sub": userInfo.Sub,
|
||||
"email": userInfo.Email,
|
||||
"name": userInfo.Name,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create session in repository
|
||||
if err := s.sessionRepo.Create(ctx, session); err != nil {
|
||||
s.logger.Error("Failed to create OAuth2 session", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.logger.Debug("OAuth2 session created successfully", zap.String("session_id", session.ID.String()))
|
||||
return session, nil
|
||||
}
|
||||
14
migrations/002_user_sessions.down.sql
Normal file
14
migrations/002_user_sessions.down.sql
Normal file
@ -0,0 +1,14 @@
|
||||
-- Drop user_sessions table and related objects
|
||||
DROP INDEX IF EXISTS idx_user_sessions_tenant;
|
||||
DROP INDEX IF EXISTS idx_user_sessions_metadata;
|
||||
DROP INDEX IF EXISTS idx_user_sessions_active_expires;
|
||||
DROP INDEX IF EXISTS idx_user_sessions_active;
|
||||
DROP INDEX IF EXISTS idx_user_sessions_created_at;
|
||||
DROP INDEX IF EXISTS idx_user_sessions_last_activity;
|
||||
DROP INDEX IF EXISTS idx_user_sessions_expires_at;
|
||||
DROP INDEX IF EXISTS idx_user_sessions_status;
|
||||
DROP INDEX IF EXISTS idx_user_sessions_user_app;
|
||||
DROP INDEX IF EXISTS idx_user_sessions_app_id;
|
||||
DROP INDEX IF EXISTS idx_user_sessions_user_id;
|
||||
|
||||
DROP TABLE IF EXISTS user_sessions;
|
||||
60
migrations/002_user_sessions.up.sql
Normal file
60
migrations/002_user_sessions.up.sql
Normal file
@ -0,0 +1,60 @@
|
||||
-- Create user_sessions table for session management
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
app_id VARCHAR(255) NOT NULL,
|
||||
session_type VARCHAR(20) NOT NULL CHECK (session_type IN ('web', 'mobile', 'api')),
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'active' CHECK (status IN ('active', 'expired', 'revoked', 'suspended')),
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
id_token TEXT,
|
||||
ip_address INET,
|
||||
user_agent TEXT,
|
||||
last_activity TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
revoked_at TIMESTAMP WITH TIME ZONE,
|
||||
revoked_by VARCHAR(255),
|
||||
metadata JSONB DEFAULT '{}',
|
||||
|
||||
-- Foreign key constraint to applications table
|
||||
CONSTRAINT fk_user_sessions_app_id FOREIGN KEY (app_id) REFERENCES applications(app_id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Create indexes for performance
|
||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_user_id ON user_sessions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_app_id ON user_sessions(app_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_user_app ON user_sessions(user_id, app_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_status ON user_sessions(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_expires_at ON user_sessions(expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_last_activity ON user_sessions(last_activity);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_created_at ON user_sessions(created_at);
|
||||
|
||||
-- Create partial indexes for active sessions (most common queries)
|
||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_active ON user_sessions(user_id, app_id) WHERE status = 'active';
|
||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_active_expires ON user_sessions(user_id, expires_at) WHERE status = 'active';
|
||||
|
||||
-- Create GIN index for metadata JSONB queries
|
||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_metadata ON user_sessions USING GIN (metadata);
|
||||
|
||||
-- Create index for tenant-based queries (if using multi-tenancy)
|
||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_tenant ON user_sessions((metadata->>'tenant_id')) WHERE metadata->>'tenant_id' IS NOT NULL;
|
||||
|
||||
-- Add comments for documentation
|
||||
COMMENT ON TABLE user_sessions IS 'Stores user session information for authentication and session management';
|
||||
COMMENT ON COLUMN user_sessions.id IS 'Unique session identifier';
|
||||
COMMENT ON COLUMN user_sessions.user_id IS 'User identifier (email, username, or external ID)';
|
||||
COMMENT ON COLUMN user_sessions.app_id IS 'Application identifier this session belongs to';
|
||||
COMMENT ON COLUMN user_sessions.session_type IS 'Type of session: web, mobile, or api';
|
||||
COMMENT ON COLUMN user_sessions.status IS 'Current session status: active, expired, revoked, or suspended';
|
||||
COMMENT ON COLUMN user_sessions.access_token IS 'OAuth2/OIDC access token (encrypted/hashed)';
|
||||
COMMENT ON COLUMN user_sessions.refresh_token IS 'OAuth2/OIDC refresh token (encrypted/hashed)';
|
||||
COMMENT ON COLUMN user_sessions.id_token IS 'OIDC ID token (encrypted/hashed)';
|
||||
COMMENT ON COLUMN user_sessions.ip_address IS 'IP address of the client when session was created';
|
||||
COMMENT ON COLUMN user_sessions.user_agent IS 'User agent string of the client';
|
||||
COMMENT ON COLUMN user_sessions.last_activity IS 'Timestamp of last session activity';
|
||||
COMMENT ON COLUMN user_sessions.expires_at IS 'When the session expires';
|
||||
COMMENT ON COLUMN user_sessions.revoked_at IS 'When the session was revoked (if applicable)';
|
||||
COMMENT ON COLUMN user_sessions.revoked_by IS 'Who revoked the session (user ID or system)';
|
||||
COMMENT ON COLUMN user_sessions.metadata IS 'Additional session metadata (device info, location, etc.)';
|
||||
532
test/saml_test.go
Normal file
532
test/saml_test.go
Normal file
@ -0,0 +1,532 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
)
|
||||
|
||||
// mockSAMLMetadata returns a mock SAML IdP metadata XML
|
||||
func mockSAMLMetadata() string {
|
||||
return `<?xml version="1.0" encoding="UTF-8"?>
|
||||
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="https://idp.example.com">
|
||||
<md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<md:KeyDescriptor use="signing">
|
||||
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
|
||||
<ds:X509Data>
|
||||
<ds:X509Certificate>MIICertificateData</ds:X509Certificate>
|
||||
</ds:X509Data>
|
||||
</ds:KeyInfo>
|
||||
</md:KeyDescriptor>
|
||||
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
|
||||
<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/slo"/>
|
||||
</md:IDPSSODescriptor>
|
||||
</md:EntityDescriptor>`
|
||||
}
|
||||
|
||||
// mockSAMLResponse returns a mock SAML response XML with current timestamps
|
||||
func mockSAMLResponse() string {
|
||||
now := time.Now().UTC()
|
||||
issueInstant := now.Format(time.RFC3339)
|
||||
notBefore := now.Add(-5 * time.Minute).Format(time.RFC3339)
|
||||
notOnOrAfter := now.Add(60 * time.Minute).Format(time.RFC3339)
|
||||
|
||||
return fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
|
||||
<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||
xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
|
||||
ID="_response_id" Version="2.0" IssueInstant="%s"
|
||||
Destination="https://sp.example.com/acs" InResponseTo="_request_id">
|
||||
<saml:Issuer>https://idp.example.com</saml:Issuer>
|
||||
<samlp:Status>
|
||||
<samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
|
||||
</samlp:Status>
|
||||
<saml:Assertion ID="_assertion_id" Version="2.0" IssueInstant="%s">
|
||||
<saml:Issuer>https://idp.example.com</saml:Issuer>
|
||||
<saml:Subject>
|
||||
<saml:NameID Format="urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress">user@example.com</saml:NameID>
|
||||
<saml:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
|
||||
<saml:SubjectConfirmationData InResponseTo="_request_id" NotOnOrAfter="%s" Recipient="https://sp.example.com/acs"/>
|
||||
</saml:SubjectConfirmation>
|
||||
</saml:Subject>
|
||||
<saml:Conditions NotBefore="%s" NotOnOrAfter="%s">
|
||||
<saml:AudienceRestriction>
|
||||
<saml:Audience>https://sp.example.com</saml:Audience>
|
||||
</saml:AudienceRestriction>
|
||||
</saml:Conditions>
|
||||
<saml:AttributeStatement>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress">
|
||||
<saml:AttributeValue>user@example.com</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name">
|
||||
<saml:AttributeValue>Test User</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname">
|
||||
<saml:AttributeValue>Test</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname">
|
||||
<saml:AttributeValue>User</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.microsoft.com/ws/2008/06/identity/claims/role">
|
||||
<saml:AttributeValue>admin,user</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
</saml:AttributeStatement>
|
||||
<saml:AuthnStatement AuthnInstant="%s" SessionIndex="_session_index">
|
||||
<saml:AuthnContext>
|
||||
<saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</saml:AuthnContextClassRef>
|
||||
</saml:AuthnContext>
|
||||
</saml:AuthnStatement>
|
||||
</saml:Assertion>
|
||||
</samlp:Response>`, issueInstant, issueInstant, notOnOrAfter, notBefore, notOnOrAfter, issueInstant)
|
||||
}
|
||||
|
||||
func TestSAMLProvider_GetMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadataURL string
|
||||
serverResponse string
|
||||
serverStatus int
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "successful metadata fetch",
|
||||
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
||||
serverResponse: mockSAMLMetadata(),
|
||||
serverStatus: http.StatusOK,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing metadata URL",
|
||||
metadataURL: "",
|
||||
expectError: true,
|
||||
errorContains: "SAML_IDP_METADATA_URL not configured",
|
||||
},
|
||||
{
|
||||
name: "server error",
|
||||
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
||||
serverStatus: http.StatusInternalServerError,
|
||||
expectError: true,
|
||||
errorContains: "returned status 500",
|
||||
},
|
||||
{
|
||||
name: "invalid XML",
|
||||
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
||||
serverResponse: "invalid xml",
|
||||
serverStatus: http.StatusOK,
|
||||
expectError: true,
|
||||
errorContains: "Failed to parse SAML metadata",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock HTTP server
|
||||
var server *httptest.Server
|
||||
if tt.metadataURL != "" && tt.serverStatus > 0 {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tt.serverStatus)
|
||||
if tt.serverResponse != "" {
|
||||
w.Write([]byte(tt.serverResponse))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
tt.metadataURL = server.URL
|
||||
}
|
||||
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_IDP_METADATA_URL"] = tt.metadataURL
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test GetMetadata
|
||||
ctx := context.Background()
|
||||
metadata, err := provider.GetMetadata(ctx)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Nil(t, metadata)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, metadata)
|
||||
assert.Equal(t, "https://idp.example.com", metadata.EntityID)
|
||||
assert.NotEmpty(t, metadata.IDPSSODescriptor.SingleSignOnService)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAMLProvider_GenerateAuthRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spEntityID string
|
||||
acsURL string
|
||||
relayState string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "successful auth request generation",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
relayState: "test-relay-state",
|
||||
},
|
||||
{
|
||||
name: "missing SP entity ID",
|
||||
spEntityID: "",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ENTITY_ID not configured",
|
||||
},
|
||||
{
|
||||
name: "missing ACS URL",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ACS_URL not configured",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock HTTP server for metadata
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockSAMLMetadata()))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
||||
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test GenerateAuthRequest
|
||||
ctx := context.Background()
|
||||
authURL, requestID, err := provider.GenerateAuthRequest(ctx, tt.relayState)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Empty(t, authURL)
|
||||
assert.Empty(t, requestID)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, authURL)
|
||||
assert.NotEmpty(t, requestID)
|
||||
assert.Contains(t, authURL, "https://idp.example.com/sso")
|
||||
assert.Contains(t, authURL, "SAMLRequest=")
|
||||
if tt.relayState != "" {
|
||||
assert.Contains(t, authURL, "RelayState="+tt.relayState)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAMLProvider_ProcessSAMLResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
samlResponse string
|
||||
expectedRequestID string
|
||||
spEntityID string
|
||||
expectError bool
|
||||
errorContains string
|
||||
expectedUserID string
|
||||
expectedEmail string
|
||||
expectedName string
|
||||
expectedRoles []string
|
||||
}{
|
||||
{
|
||||
name: "successful SAML response processing",
|
||||
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
|
||||
expectedRequestID: "_request_id",
|
||||
spEntityID: "https://sp.example.com",
|
||||
expectedUserID: "user@example.com",
|
||||
expectedEmail: "user@example.com",
|
||||
expectedName: "Test User",
|
||||
expectedRoles: []string{"admin", "user"},
|
||||
},
|
||||
{
|
||||
name: "invalid base64 encoding",
|
||||
samlResponse: "invalid-base64",
|
||||
expectError: true,
|
||||
errorContains: "Failed to decode SAML response",
|
||||
},
|
||||
{
|
||||
name: "invalid XML",
|
||||
samlResponse: base64.StdEncoding.EncodeToString([]byte("invalid xml")),
|
||||
expectError: true,
|
||||
errorContains: "Failed to parse SAML response",
|
||||
},
|
||||
{
|
||||
name: "audience mismatch",
|
||||
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
|
||||
spEntityID: "https://wrong-sp.example.com",
|
||||
expectError: true,
|
||||
errorContains: "audience mismatch",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test ProcessSAMLResponse
|
||||
ctx := context.Background()
|
||||
authContext, err := provider.ProcessSAMLResponse(ctx, tt.samlResponse, tt.expectedRequestID)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Nil(t, authContext)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, authContext)
|
||||
assert.Equal(t, tt.expectedUserID, authContext.UserID)
|
||||
assert.Equal(t, domain.TokenTypeUser, authContext.TokenType)
|
||||
|
||||
// Check claims
|
||||
if tt.expectedEmail != "" {
|
||||
assert.Equal(t, tt.expectedEmail, authContext.Claims["email"])
|
||||
}
|
||||
if tt.expectedName != "" {
|
||||
assert.Equal(t, tt.expectedName, authContext.Claims["name"])
|
||||
}
|
||||
|
||||
// Check permissions/roles
|
||||
if len(tt.expectedRoles) > 0 {
|
||||
assert.Equal(t, tt.expectedRoles, authContext.Permissions)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAMLProvider_GenerateServiceProviderMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spEntityID string
|
||||
acsURL string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "successful SP metadata generation",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
},
|
||||
{
|
||||
name: "missing SP entity ID",
|
||||
spEntityID: "",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ENTITY_ID not configured",
|
||||
},
|
||||
{
|
||||
name: "missing ACS URL",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ACS_URL not configured",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
||||
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test GenerateServiceProviderMetadata
|
||||
metadata, err := provider.GenerateServiceProviderMetadata()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Empty(t, metadata)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, metadata)
|
||||
assert.Contains(t, metadata, tt.spEntityID)
|
||||
assert.Contains(t, metadata, tt.acsURL)
|
||||
assert.Contains(t, metadata, "EntityDescriptor")
|
||||
assert.Contains(t, metadata, "SPSSODescriptor")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for SAML operations
|
||||
func BenchmarkSAMLProvider_ProcessSAMLResponse(b *testing.B) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(b)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(b, err)
|
||||
|
||||
// Prepare SAML response
|
||||
samlResponse := base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse()))
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSAMLProvider_GenerateAuthRequest(b *testing.B) {
|
||||
// Create mock HTTP server for metadata
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockSAMLMetadata()))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
||||
cfg.values["SAML_SP_ACS_URL"] = "https://sp.example.com/acs"
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(b)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(b, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, err := provider.GenerateAuthRequest(ctx, "test-relay-state")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
func TestSAMLResponseValidation(t *testing.T) {
|
||||
// Test various SAML response validation scenarios
|
||||
tests := []struct {
|
||||
name string
|
||||
modifyXML func(string) string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "expired assertion",
|
||||
modifyXML: func(xml string) string {
|
||||
// Replace all NotOnOrAfter timestamps with past time
|
||||
pastTime := "2020-01-01T13:00:00Z"
|
||||
re := regexp.MustCompile(`NotOnOrAfter="[^"]*"`)
|
||||
return re.ReplaceAllString(xml, `NotOnOrAfter="`+pastTime+`"`)
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "assertion has expired",
|
||||
},
|
||||
{
|
||||
name: "assertion not yet valid",
|
||||
modifyXML: func(xml string) string {
|
||||
// Replace all NotBefore timestamps with future time
|
||||
futureTime := "2030-01-01T11:55:00Z"
|
||||
re := regexp.MustCompile(`NotBefore="[^"]*"`)
|
||||
return re.ReplaceAllString(xml, `NotBefore="`+futureTime+`"`)
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "assertion not yet valid",
|
||||
},
|
||||
{
|
||||
name: "failed status",
|
||||
modifyXML: func(xml string) string {
|
||||
return strings.ReplaceAll(xml,
|
||||
"urn:oasis:names:tc:SAML:2.0:status:Success",
|
||||
"urn:oasis:names:tc:SAML:2.0:status:AuthnFailed")
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "SAML authentication failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Modify SAML response
|
||||
modifiedXML := tt.modifyXML(mockSAMLResponse())
|
||||
samlResponse := base64.StdEncoding.EncodeToString([]byte(modifiedXML))
|
||||
|
||||
// Test ProcessSAMLResponse
|
||||
ctx := context.Background()
|
||||
authContext, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Nil(t, authContext)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, authContext)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
705
test/token_repository_test.go
Normal file
705
test/token_repository_test.go
Normal file
@ -0,0 +1,705 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
"github.com/kms/api-key-service/internal/repository/postgres"
|
||||
)
|
||||
|
||||
// SQLMockDatabaseProvider implements repository.DatabaseProvider for SQL testing
|
||||
type SQLMockDatabaseProvider struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (m *SQLMockDatabaseProvider) GetDB() interface{} {
|
||||
return m.db
|
||||
}
|
||||
|
||||
func (m *SQLMockDatabaseProvider) Ping(ctx context.Context) error {
|
||||
return m.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (m *SQLMockDatabaseProvider) Close() error {
|
||||
return m.db.Close()
|
||||
}
|
||||
|
||||
func (m *SQLMockDatabaseProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) {
|
||||
tx, err := m.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SQLMockTransactionProvider{tx: tx}, nil
|
||||
}
|
||||
|
||||
func (m *SQLMockDatabaseProvider) Migrate(ctx context.Context, migrationPath string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SQLMockTransactionProvider implements repository.TransactionProvider for SQL testing
|
||||
type SQLMockTransactionProvider struct {
|
||||
tx *sql.Tx
|
||||
}
|
||||
|
||||
func (m *SQLMockTransactionProvider) Commit() error {
|
||||
return m.tx.Commit()
|
||||
}
|
||||
|
||||
func (m *SQLMockTransactionProvider) Rollback() error {
|
||||
return m.tx.Rollback()
|
||||
}
|
||||
|
||||
func (m *SQLMockTransactionProvider) GetTx() interface{} {
|
||||
return m.tx
|
||||
}
|
||||
|
||||
func setupTokenRepositoryTest(t *testing.T) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDB := &SQLMockDatabaseProvider{db: db}
|
||||
repo := postgres.NewStaticTokenRepository(mockDB)
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
}
|
||||
|
||||
return repo.(*postgres.StaticTokenRepository), mock, cleanup
|
||||
}
|
||||
|
||||
func setupTokenRepositoryTestBenchmark(b *testing.B) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
mockDB := &SQLMockDatabaseProvider{db: db}
|
||||
repo := postgres.NewStaticTokenRepository(mockDB)
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
}
|
||||
|
||||
return repo.(*postgres.StaticTokenRepository), mock, cleanup
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_Create(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token *domain.StaticToken
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful creation",
|
||||
token: &domain.StaticToken{
|
||||
ID: uuid.New(),
|
||||
AppID: "test-app",
|
||||
Owner: domain.Owner{
|
||||
Type: domain.OwnerTypeIndividual,
|
||||
Name: "test-user",
|
||||
Owner: "test-owner",
|
||||
},
|
||||
KeyHash: "test-hash",
|
||||
Type: "hmac",
|
||||
},
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(`INSERT INTO static_tokens`).
|
||||
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
token: &domain.StaticToken{
|
||||
ID: uuid.New(),
|
||||
AppID: "test-app",
|
||||
Owner: domain.Owner{
|
||||
Type: domain.OwnerTypeIndividual,
|
||||
Name: "test-user",
|
||||
Owner: "test-owner",
|
||||
},
|
||||
KeyHash: "test-hash",
|
||||
Type: "hmac",
|
||||
},
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(`INSERT INTO static_tokens`).
|
||||
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to create static token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
err := repo.Create(ctx, tt.token)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, tt.token.CreatedAt)
|
||||
assert.NotZero(t, tt.token.UpdatedAt)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_GetByID(t *testing.T) {
|
||||
tokenID := uuid.New()
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenID uuid.UUID
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
expectedToken *domain.StaticToken
|
||||
}{
|
||||
{
|
||||
name: "successful retrieval",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
tokenID, "test-app", "individual", "test-user", "test-owner",
|
||||
"test-hash", "user", now, now,
|
||||
)
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedToken: &domain.StaticToken{
|
||||
ID: tokenID,
|
||||
AppID: "test-app",
|
||||
Owner: domain.Owner{
|
||||
Type: domain.OwnerTypeIndividual,
|
||||
Name: "test-user",
|
||||
Owner: "test-owner",
|
||||
},
|
||||
KeyHash: "test-hash",
|
||||
Type: string(domain.TokenTypeUser),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token not found",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "not found",
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to get static token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
token, err := repo.GetByID(ctx, tt.tokenID)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
assert.Nil(t, token)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, token)
|
||||
assert.Equal(t, tt.expectedToken.ID, token.ID)
|
||||
assert.Equal(t, tt.expectedToken.AppID, token.AppID)
|
||||
assert.Equal(t, tt.expectedToken.Owner, token.Owner)
|
||||
assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
|
||||
assert.Equal(t, tt.expectedToken.Type, token.Type)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_GetByKeyHash(t *testing.T) {
|
||||
tokenID := uuid.New()
|
||||
now := time.Now()
|
||||
keyHash := "test-hash"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyHash string
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
expectedToken *domain.StaticToken
|
||||
}{
|
||||
{
|
||||
name: "successful retrieval",
|
||||
keyHash: keyHash,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
tokenID, "test-app", "individual", "test-user", "test-owner",
|
||||
keyHash, "user", now, now,
|
||||
)
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
|
||||
WithArgs(keyHash).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedToken: &domain.StaticToken{
|
||||
ID: tokenID,
|
||||
AppID: "test-app",
|
||||
Owner: domain.Owner{
|
||||
Type: domain.OwnerTypeIndividual,
|
||||
Name: "test-user",
|
||||
Owner: "test-owner",
|
||||
},
|
||||
KeyHash: keyHash,
|
||||
Type: string(domain.TokenTypeUser),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token not found",
|
||||
keyHash: keyHash,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
|
||||
WithArgs(keyHash).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "not found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
token, err := repo.GetByKeyHash(ctx, tt.keyHash)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
assert.Nil(t, token)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, token)
|
||||
assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_GetByAppID(t *testing.T) {
|
||||
tokenID1 := uuid.New()
|
||||
tokenID2 := uuid.New()
|
||||
now := time.Now()
|
||||
appID := "test-app"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
appID string
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "successful retrieval with multiple tokens",
|
||||
appID: appID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
tokenID1, appID, "user", "test-user1", "test-owner1",
|
||||
"test-hash1", "user", now, now,
|
||||
).AddRow(
|
||||
tokenID2, appID, "user", "test-user2", "test-owner2",
|
||||
"test-hash2", "user", now, now,
|
||||
)
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
|
||||
WithArgs(appID).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "no tokens found",
|
||||
appID: appID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
})
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
|
||||
WithArgs(appID).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
appID: appID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
|
||||
WithArgs(appID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to query static tokens",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
tokens, err := repo.GetByAppID(ctx, tt.appID)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
assert.Nil(t, tokens)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, tokens, tt.expectedCount)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_List(t *testing.T) {
|
||||
tokenID := uuid.New()
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
limit int
|
||||
offset int
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "successful list with pagination",
|
||||
limit: 10,
|
||||
offset: 0,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
tokenID, "test-app", "user", "test-user", "test-owner",
|
||||
"test-hash", "user", now, now,
|
||||
)
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
|
||||
WithArgs(10, 0).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
limit: 10,
|
||||
offset: 0,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
|
||||
WithArgs(10, 0).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to query static tokens",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
tokens, err := repo.List(ctx, tt.limit, tt.offset)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
assert.Nil(t, tokens)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, tokens, tt.expectedCount)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_Delete(t *testing.T) {
|
||||
tokenID := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenID uuid.UUID
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful deletion",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "token not found",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "not found",
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to delete static token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
err := repo.Delete(ctx, tt.tokenID)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_Exists(t *testing.T) {
|
||||
tokenID := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenID uuid.UUID
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
expectedExists bool
|
||||
}{
|
||||
{
|
||||
name: "token exists",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{"exists"}).AddRow(1)
|
||||
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedExists: true,
|
||||
},
|
||||
{
|
||||
name: "token does not exist",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedExists: false,
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to check static token existence",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
exists, err := repo.Exists(ctx, tt.tokenID)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedExists, exists)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for repository operations
|
||||
func BenchmarkStaticTokenRepository_Create(b *testing.B) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
|
||||
defer cleanup()
|
||||
|
||||
token := &domain.StaticToken{
|
||||
ID: uuid.New(),
|
||||
AppID: "test-app",
|
||||
Owner: domain.Owner{
|
||||
Type: domain.OwnerTypeIndividual,
|
||||
Name: "test-user",
|
||||
Owner: "test-owner",
|
||||
},
|
||||
KeyHash: "test-hash",
|
||||
Type: string(domain.TokenTypeUser),
|
||||
}
|
||||
|
||||
// Setup mock expectations for all iterations
|
||||
for i := 0; i < b.N; i++ {
|
||||
mock.ExpectExec(`INSERT INTO static_tokens`).
|
||||
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "user", sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
token.ID = uuid.New() // Generate new ID for each iteration
|
||||
err := repo.Create(ctx, token)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStaticTokenRepository_GetByID(b *testing.B) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
|
||||
defer cleanup()
|
||||
|
||||
tokenID := uuid.New()
|
||||
now := time.Now()
|
||||
|
||||
// Setup mock expectations for all iterations
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
tokenID, "test-app", "user", "test-user", "test-owner",
|
||||
"test-hash", "user", now, now,
|
||||
)
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnRows(rows)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := repo.GetByID(ctx, tokenID)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user