From 19364fcc76760a2822a3fbb092c8e1600803a8d4 Mon Sep 17 00:00:00 2001 From: Ryan Copley Date: Mon, 25 Aug 2025 21:28:14 -0400 Subject: [PATCH] - --- CLAUDE.md | 4 +- cmd/server/main.go | 3 +- internal/handlers/application.go | 25 +- internal/repository/interfaces.go | 31 + .../repository/postgres/audit_repository.go | 742 ++++++++++++++++++ internal/services/application_service.go | 79 +- migrations/004_add_audit_events.down.sql | 27 + migrations/004_add_audit_events.up.sql | 102 +++ test/e2e_test.sh | 1 + test/integration_test.go | 5 +- test/mock_repositories.go | 202 +++++ 11 files changed, 1208 insertions(+), 13 deletions(-) create mode 100644 internal/repository/postgres/audit_repository.go create mode 100644 migrations/004_add_audit_events.down.sql create mode 100644 migrations/004_add_audit_events.up.sql diff --git a/CLAUDE.md b/CLAUDE.md index 12e37c7..e3d19ce 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -146,8 +146,8 @@ The project uses podman-compose for all testing environments and database operat ### End-to-End Testing ```bash -# Start test environment with podman-compose -podman-compose up -d +# Start test environment with podman-compose, guaranteeing that it updates with --build +podman-compose up -d --build # Wait for services to be ready sleep 10 diff --git a/cmd/server/main.go b/cmd/server/main.go index 8d331a2..6b75575 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -63,9 +63,10 @@ func main() { tokenRepo := postgres.NewStaticTokenRepository(db) permRepo := postgres.NewPermissionRepository(db) grantRepo := postgres.NewGrantedPermissionRepository(db) + auditRepo := postgres.NewAuditRepository(db) // Initialize services - appService := services.NewApplicationService(appRepo, logger) + appService := services.NewApplicationService(appRepo, auditRepo, logger) tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, cfg.GetString("INTERNAL_HMAC_KEY"), cfg, logger) authService := services.NewAuthenticationService(cfg, logger, permRepo) diff --git a/internal/handlers/application.go b/internal/handlers/application.go index 8e2640f..7b4d1c4 100644 --- a/internal/handlers/application.go +++ b/internal/handlers/application.go @@ -55,8 +55,29 @@ func (h *ApplicationHandler) Create(c *gin.Context) { return } - // Validate input - validationErrors := h.validator.ValidateApplicationRequest(req.AppID, req.AppLink, req.CallbackURL, []string{}) + // Validate input (skip permissions validation for application creation) + var validationErrors []validation.ValidationError + + // Validate app ID + if result := h.validator.ValidateAppID(req.AppID); !result.Valid { + validationErrors = append(validationErrors, result.Errors...) + } + + // Validate app link URL + if result := h.validator.ValidateURL(req.AppLink, "app_link"); !result.Valid { + validationErrors = append(validationErrors, result.Errors...) + } + + // Validate callback URL + if result := h.validator.ValidateURL(req.CallbackURL, "callback_url"); !result.Valid { + validationErrors = append(validationErrors, result.Errors...) + } + + // Validate token prefix if provided + if result := h.validator.ValidateTokenPrefix(req.TokenPrefix); !result.Valid { + validationErrors = append(validationErrors, result.Errors...) + } + if len(validationErrors) > 0 { h.logger.Warn("Application validation failed", zap.String("user_id", userID), diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index 5642b32..f4253a0 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -5,6 +5,7 @@ import ( "time" "github.com/google/uuid" + "github.com/kms/api-key-service/internal/audit" "github.com/kms/api-key-service/internal/domain" ) @@ -319,3 +320,33 @@ type MetricsProvider interface { // RecordDuration records the duration of an operation RecordDuration(ctx context.Context, name string, duration time.Duration, labels map[string]string) } + +// AuditRepository defines the interface for audit event storage operations +type AuditRepository interface { + // Create stores a new audit event + Create(ctx context.Context, event *audit.AuditEvent) error + + // Query retrieves audit events based on filter criteria + Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) + + // GetStats returns aggregated statistics for audit events + GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) + + // DeleteOldEvents removes audit events older than the specified time + DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) + + // GetByID retrieves a specific audit event by its ID + GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error) + + // GetByRequestID retrieves all audit events for a specific request + GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error) + + // GetBySession retrieves all audit events for a specific session + GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error) + + // GetByActor retrieves audit events for a specific actor + GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error) + + // GetByResource retrieves audit events for a specific resource + GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error) +} diff --git a/internal/repository/postgres/audit_repository.go b/internal/repository/postgres/audit_repository.go new file mode 100644 index 0000000..5b61380 --- /dev/null +++ b/internal/repository/postgres/audit_repository.go @@ -0,0 +1,742 @@ +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/lib/pq" + + "github.com/kms/api-key-service/internal/audit" + "github.com/kms/api-key-service/internal/repository" +) + +// AuditRepository implements the AuditRepository interface for PostgreSQL +type AuditRepository struct { + db repository.DatabaseProvider +} + +// NewAuditRepository creates a new PostgreSQL audit repository +func NewAuditRepository(db repository.DatabaseProvider) repository.AuditRepository { + return &AuditRepository{db: db} +} + +// Create stores a new audit event +func (r *AuditRepository) Create(ctx context.Context, event *audit.AuditEvent) error { + query := ` + INSERT INTO audit_events ( + id, type, severity, status, timestamp, + actor_id, actor_type, actor_ip, user_agent, tenant_id, + resource_id, resource_type, action, description, details, + request_id, session_id, tags, metadata + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, + $11, $12, $13, $14, $15, $16, $17, $18, $19 + ) + ` + + db := r.db.GetDB().(*sql.DB) + + // Ensure event has an ID and timestamp + if event.ID == uuid.Nil { + event.ID = uuid.New() + } + if event.Timestamp.IsZero() { + event.Timestamp = time.Now().UTC() + } + + // Convert details to JSON + var detailsJSON []byte + var err error + if event.Details != nil { + detailsJSON, err = json.Marshal(event.Details) + if err != nil { + return fmt.Errorf("failed to marshal event details: %w", err) + } + } else { + detailsJSON = []byte("{}") + } + + // Convert metadata to JSON + var metadataJSON []byte + if event.Metadata != nil { + metadataJSON, err = json.Marshal(event.Metadata) + if err != nil { + return fmt.Errorf("failed to marshal event metadata: %w", err) + } + } else { + metadataJSON = []byte("{}") + } + + // Handle nullable fields + var actorID, actorType, actorIP, userAgent *string + var tenantID *uuid.UUID + var resourceID, resourceType *string + var requestID, sessionID *string + + if event.ActorID != "" { + actorID = &event.ActorID + } + if event.ActorType != "" { + actorType = &event.ActorType + } + if event.ActorIP != "" { + actorIP = &event.ActorIP + } + if event.UserAgent != "" { + userAgent = &event.UserAgent + } + if event.TenantID != nil { + tenantID = event.TenantID + } + if event.ResourceID != "" { + resourceID = &event.ResourceID + } + if event.ResourceType != "" { + resourceType = &event.ResourceType + } + if event.RequestID != "" { + requestID = &event.RequestID + } + if event.SessionID != "" { + sessionID = &event.SessionID + } + + _, err = db.ExecContext(ctx, query, + event.ID, + string(event.Type), + string(event.Severity), + string(event.Status), + event.Timestamp, + actorID, + actorType, + actorIP, + userAgent, + tenantID, + resourceID, + resourceType, + event.Action, + event.Description, + string(detailsJSON), + requestID, + sessionID, + pq.Array(event.Tags), + string(metadataJSON), + ) + + if err != nil { + return fmt.Errorf("failed to create audit event: %w", err) + } + + return nil +} + +// Query retrieves audit events based on filter criteria +func (r *AuditRepository) Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) { + // Build dynamic query with filters + var conditions []string + var args []interface{} + argIndex := 1 + + baseQuery := ` + SELECT id, type, severity, status, timestamp, + actor_id, actor_type, actor_ip, user_agent, tenant_id, + resource_id, resource_type, action, description, details, + request_id, session_id, tags, metadata + FROM audit_events + ` + + // Add filters + if len(filter.EventTypes) > 0 { + conditions = append(conditions, fmt.Sprintf("type = ANY($%d)", argIndex)) + typeStrings := make([]string, len(filter.EventTypes)) + for i, t := range filter.EventTypes { + typeStrings[i] = string(t) + } + args = append(args, pq.Array(typeStrings)) + argIndex++ + } + + if len(filter.Severities) > 0 { + conditions = append(conditions, fmt.Sprintf("severity = ANY($%d)", argIndex)) + severityStrings := make([]string, len(filter.Severities)) + for i, s := range filter.Severities { + severityStrings[i] = string(s) + } + args = append(args, pq.Array(severityStrings)) + argIndex++ + } + + if len(filter.Statuses) > 0 { + conditions = append(conditions, fmt.Sprintf("status = ANY($%d)", argIndex)) + statusStrings := make([]string, len(filter.Statuses)) + for i, s := range filter.Statuses { + statusStrings[i] = string(s) + } + args = append(args, pq.Array(statusStrings)) + argIndex++ + } + + if filter.ActorID != "" { + conditions = append(conditions, fmt.Sprintf("actor_id = $%d", argIndex)) + args = append(args, filter.ActorID) + argIndex++ + } + + if filter.ActorType != "" { + conditions = append(conditions, fmt.Sprintf("actor_type = $%d", argIndex)) + args = append(args, filter.ActorType) + argIndex++ + } + + if filter.TenantID != nil { + conditions = append(conditions, fmt.Sprintf("tenant_id = $%d", argIndex)) + args = append(args, *filter.TenantID) + argIndex++ + } + + if filter.ResourceID != "" { + conditions = append(conditions, fmt.Sprintf("resource_id = $%d", argIndex)) + args = append(args, filter.ResourceID) + argIndex++ + } + + if filter.ResourceType != "" { + conditions = append(conditions, fmt.Sprintf("resource_type = $%d", argIndex)) + args = append(args, filter.ResourceType) + argIndex++ + } + + if filter.StartTime != nil { + conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex)) + args = append(args, *filter.StartTime) + argIndex++ + } + + if filter.EndTime != nil { + conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex)) + args = append(args, *filter.EndTime) + argIndex++ + } + + if len(filter.Tags) > 0 { + conditions = append(conditions, fmt.Sprintf("tags && $%d", argIndex)) + args = append(args, pq.Array(filter.Tags)) + argIndex++ + } + + // Build WHERE clause + if len(conditions) > 0 { + baseQuery += " WHERE " + strings.Join(conditions, " AND ") + } + + // Add ORDER BY + orderBy := "timestamp" + if filter.OrderBy != "" { + switch filter.OrderBy { + case "timestamp", "type", "severity", "status": + orderBy = filter.OrderBy + } + } + + direction := "DESC" + if !filter.OrderDesc { + direction = "ASC" + } + + baseQuery += fmt.Sprintf(" ORDER BY %s %s", orderBy, direction) + + // Add pagination + if filter.Limit <= 0 { + filter.Limit = 100 + } + if filter.Limit > 1000 { + filter.Limit = 1000 + } + + baseQuery += fmt.Sprintf(" LIMIT $%d", argIndex) + args = append(args, filter.Limit) + argIndex++ + + if filter.Offset > 0 { + baseQuery += fmt.Sprintf(" OFFSET $%d", argIndex) + args = append(args, filter.Offset) + } + + db := r.db.GetDB().(*sql.DB) + rows, err := db.QueryContext(ctx, baseQuery, args...) + if err != nil { + return nil, fmt.Errorf("failed to query audit events: %w", err) + } + defer rows.Close() + + var events []*audit.AuditEvent + for rows.Next() { + event, err := r.scanAuditEvent(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan audit event: %w", err) + } + events = append(events, event) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating audit events: %w", err) + } + + return events, nil +} + +// GetStats returns aggregated statistics for audit events +func (r *AuditRepository) GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) { + stats := &audit.AuditStats{ + ByType: make(map[audit.EventType]int), + BySeverity: make(map[audit.EventSeverity]int), + ByStatus: make(map[audit.EventStatus]int), + } + + // Build base conditions + var conditions []string + var args []interface{} + argIndex := 1 + + if len(filter.EventTypes) > 0 { + conditions = append(conditions, fmt.Sprintf("type = ANY($%d)", argIndex)) + typeStrings := make([]string, len(filter.EventTypes)) + for i, t := range filter.EventTypes { + typeStrings[i] = string(t) + } + args = append(args, pq.Array(typeStrings)) + argIndex++ + } + + if filter.TenantID != nil { + conditions = append(conditions, fmt.Sprintf("tenant_id = $%d", argIndex)) + args = append(args, *filter.TenantID) + argIndex++ + } + + if filter.StartTime != nil { + conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex)) + args = append(args, *filter.StartTime) + argIndex++ + } + + if filter.EndTime != nil { + conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex)) + args = append(args, *filter.EndTime) + argIndex++ + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + db := r.db.GetDB().(*sql.DB) + + // Get total count + totalQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause) + err := db.QueryRowContext(ctx, totalQuery, args...).Scan(&stats.TotalEvents) + if err != nil { + return nil, fmt.Errorf("failed to get total event count: %w", err) + } + + // Get stats by type + typeQuery := fmt.Sprintf(` + SELECT type, COUNT(*) + FROM audit_events %s + GROUP BY type + ORDER BY COUNT(*) DESC + `, whereClause) + + rows, err := db.QueryContext(ctx, typeQuery, args...) + if err != nil { + return nil, fmt.Errorf("failed to get type stats: %w", err) + } + defer rows.Close() + + for rows.Next() { + var eventType string + var count int + if err := rows.Scan(&eventType, &count); err != nil { + return nil, fmt.Errorf("failed to scan type stats: %w", err) + } + stats.ByType[audit.EventType(eventType)] = count + } + + // Get stats by severity + severityQuery := fmt.Sprintf(` + SELECT severity, COUNT(*) + FROM audit_events %s + GROUP BY severity + ORDER BY COUNT(*) DESC + `, whereClause) + + rows, err = db.QueryContext(ctx, severityQuery, args...) + if err != nil { + return nil, fmt.Errorf("failed to get severity stats: %w", err) + } + defer rows.Close() + + for rows.Next() { + var severity string + var count int + if err := rows.Scan(&severity, &count); err != nil { + return nil, fmt.Errorf("failed to scan severity stats: %w", err) + } + stats.BySeverity[audit.EventSeverity(severity)] = count + } + + // Get stats by status + statusQuery := fmt.Sprintf(` + SELECT status, COUNT(*) + FROM audit_events %s + GROUP BY status + ORDER BY COUNT(*) DESC + `, whereClause) + + rows, err = db.QueryContext(ctx, statusQuery, args...) + if err != nil { + return nil, fmt.Errorf("failed to get status stats: %w", err) + } + defer rows.Close() + + for rows.Next() { + var status string + var count int + if err := rows.Scan(&status, &count); err != nil { + return nil, fmt.Errorf("failed to scan status stats: %w", err) + } + stats.ByStatus[audit.EventStatus(status)] = count + } + + // Get time-based stats if requested + if filter.GroupBy != "" { + stats.ByTime = make(map[string]int) + + var timeFormat string + switch filter.GroupBy { + case "hour": + timeFormat = "YYYY-MM-DD HH24:00" + case "day": + timeFormat = "YYYY-MM-DD" + default: + timeFormat = "YYYY-MM-DD" + } + + timeQuery := fmt.Sprintf(` + SELECT TO_CHAR(timestamp, '%s') as time_group, COUNT(*) + FROM audit_events %s + GROUP BY time_group + ORDER BY time_group DESC + `, timeFormat, whereClause) + + rows, err = db.QueryContext(ctx, timeQuery, args...) + if err != nil { + return nil, fmt.Errorf("failed to get time stats: %w", err) + } + defer rows.Close() + + for rows.Next() { + var timeGroup string + var count int + if err := rows.Scan(&timeGroup, &count); err != nil { + return nil, fmt.Errorf("failed to scan time stats: %w", err) + } + stats.ByTime[timeGroup] = count + } + } + + return stats, nil +} + +// DeleteOldEvents removes audit events older than the specified time +func (r *AuditRepository) DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) { + query := `DELETE FROM audit_events WHERE timestamp < $1` + + db := r.db.GetDB().(*sql.DB) + result, err := db.ExecContext(ctx, query, olderThan) + if err != nil { + return 0, fmt.Errorf("failed to delete old audit events: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return 0, fmt.Errorf("failed to get rows affected: %w", err) + } + + return int(rowsAffected), nil +} + +// GetByID retrieves a specific audit event by its ID +func (r *AuditRepository) GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error) { + query := ` + SELECT id, type, severity, status, timestamp, + actor_id, actor_type, actor_ip, user_agent, tenant_id, + resource_id, resource_type, action, description, details, + request_id, session_id, tags, metadata + FROM audit_events + WHERE id = $1 + ` + + db := r.db.GetDB().(*sql.DB) + row := db.QueryRowContext(ctx, query, eventID) + + event, err := r.scanAuditEvent(row) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("audit event with ID '%s' not found", eventID) + } + return nil, fmt.Errorf("failed to get audit event: %w", err) + } + + return event, nil +} + +// GetByRequestID retrieves all audit events for a specific request +func (r *AuditRepository) GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error) { + query := ` + SELECT id, type, severity, status, timestamp, + actor_id, actor_type, actor_ip, user_agent, tenant_id, + resource_id, resource_type, action, description, details, + request_id, session_id, tags, metadata + FROM audit_events + WHERE request_id = $1 + ORDER BY timestamp ASC + ` + + db := r.db.GetDB().(*sql.DB) + rows, err := db.QueryContext(ctx, query, requestID) + if err != nil { + return nil, fmt.Errorf("failed to query audit events by request ID: %w", err) + } + defer rows.Close() + + var events []*audit.AuditEvent + for rows.Next() { + event, err := r.scanAuditEvent(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan audit event: %w", err) + } + events = append(events, event) + } + + return events, nil +} + +// GetBySession retrieves all audit events for a specific session +func (r *AuditRepository) GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error) { + query := ` + SELECT id, type, severity, status, timestamp, + actor_id, actor_type, actor_ip, user_agent, tenant_id, + resource_id, resource_type, action, description, details, + request_id, session_id, tags, metadata + FROM audit_events + WHERE session_id = $1 + ORDER BY timestamp ASC + ` + + db := r.db.GetDB().(*sql.DB) + rows, err := db.QueryContext(ctx, query, sessionID) + if err != nil { + return nil, fmt.Errorf("failed to query audit events by session ID: %w", err) + } + defer rows.Close() + + var events []*audit.AuditEvent + for rows.Next() { + event, err := r.scanAuditEvent(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan audit event: %w", err) + } + events = append(events, event) + } + + return events, nil +} + +// GetByActor retrieves audit events for a specific actor +func (r *AuditRepository) GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error) { + if limit <= 0 { + limit = 100 + } + if limit > 1000 { + limit = 1000 + } + + query := ` + SELECT id, type, severity, status, timestamp, + actor_id, actor_type, actor_ip, user_agent, tenant_id, + resource_id, resource_type, action, description, details, + request_id, session_id, tags, metadata + FROM audit_events + WHERE actor_id = $1 + ORDER BY timestamp DESC + LIMIT $2 OFFSET $3 + ` + + db := r.db.GetDB().(*sql.DB) + rows, err := db.QueryContext(ctx, query, actorID, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to query audit events by actor: %w", err) + } + defer rows.Close() + + var events []*audit.AuditEvent + for rows.Next() { + event, err := r.scanAuditEvent(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan audit event: %w", err) + } + events = append(events, event) + } + + return events, nil +} + +// GetByResource retrieves audit events for a specific resource +func (r *AuditRepository) GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error) { + if limit <= 0 { + limit = 100 + } + if limit > 1000 { + limit = 1000 + } + + query := ` + SELECT id, type, severity, status, timestamp, + actor_id, actor_type, actor_ip, user_agent, tenant_id, + resource_id, resource_type, action, description, details, + request_id, session_id, tags, metadata + FROM audit_events + WHERE resource_type = $1 AND resource_id = $2 + ORDER BY timestamp DESC + LIMIT $3 OFFSET $4 + ` + + db := r.db.GetDB().(*sql.DB) + rows, err := db.QueryContext(ctx, query, resourceType, resourceID, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to query audit events by resource: %w", err) + } + defer rows.Close() + + var events []*audit.AuditEvent + for rows.Next() { + event, err := r.scanAuditEvent(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan audit event: %w", err) + } + events = append(events, event) + } + + return events, nil +} + +// scanAuditEvent scans a database row into an AuditEvent struct +func (r *AuditRepository) scanAuditEvent(row interface{}) (*audit.AuditEvent, error) { + event := &audit.AuditEvent{} + + var typeStr, severityStr, statusStr string + var actorID, actorType, actorIP, userAgent sql.NullString + var tenantID *uuid.UUID + var resourceID, resourceType sql.NullString + var detailsJSON, metadataJSON string + var requestID, sessionID sql.NullString + var tags pq.StringArray + + var scanner interface { + Scan(dest ...interface{}) error + } + + switch v := row.(type) { + case *sql.Row: + scanner = v + case *sql.Rows: + scanner = v + default: + return nil, fmt.Errorf("invalid row type") + } + + err := scanner.Scan( + &event.ID, + &typeStr, + &severityStr, + &statusStr, + &event.Timestamp, + &actorID, + &actorType, + &actorIP, + &userAgent, + &tenantID, + &resourceID, + &resourceType, + &event.Action, + &event.Description, + &detailsJSON, + &requestID, + &sessionID, + &tags, + &metadataJSON, + ) + + if err != nil { + return nil, err + } + + // Convert string enums to types + event.Type = audit.EventType(typeStr) + event.Severity = audit.EventSeverity(severityStr) + event.Status = audit.EventStatus(statusStr) + + // Handle nullable fields + if actorID.Valid { + event.ActorID = actorID.String + } + if actorType.Valid { + event.ActorType = actorType.String + } + if actorIP.Valid { + event.ActorIP = actorIP.String + } + if userAgent.Valid { + event.UserAgent = userAgent.String + } + if tenantID != nil { + event.TenantID = tenantID + } + if resourceID.Valid { + event.ResourceID = resourceID.String + } + if resourceType.Valid { + event.ResourceType = resourceType.String + } + if requestID.Valid { + event.RequestID = requestID.String + } + if sessionID.Valid { + event.SessionID = sessionID.String + } + + // Convert tags + event.Tags = []string(tags) + + // Parse JSON fields + if detailsJSON != "" { + if err := json.Unmarshal([]byte(detailsJSON), &event.Details); err != nil { + return nil, fmt.Errorf("failed to unmarshal details JSON: %w", err) + } + } + + if metadataJSON != "" { + if err := json.Unmarshal([]byte(metadataJSON), &event.Metadata); err != nil { + return nil, fmt.Errorf("failed to unmarshal metadata JSON: %w", err) + } + } + + return event, nil +} \ No newline at end of file diff --git a/internal/services/application_service.go b/internal/services/application_service.go index 355dd1f..0540aab 100644 --- a/internal/services/application_service.go +++ b/internal/services/application_service.go @@ -5,30 +5,61 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "time" "github.com/go-playground/validator/v10" "go.uber.org/zap" + "github.com/kms/api-key-service/internal/audit" "github.com/kms/api-key-service/internal/domain" "github.com/kms/api-key-service/internal/repository" ) // applicationService implements the ApplicationService interface type applicationService struct { - appRepo repository.ApplicationRepository - logger *zap.Logger - validator *validator.Validate + appRepo repository.ApplicationRepository + auditRepo repository.AuditRepository + auditLogger audit.AuditLogger + logger *zap.Logger + validator *validator.Validate } // NewApplicationService creates a new application service -func NewApplicationService(appRepo repository.ApplicationRepository, logger *zap.Logger) ApplicationService { +func NewApplicationService(appRepo repository.ApplicationRepository, auditRepo repository.AuditRepository, logger *zap.Logger) ApplicationService { + // Create audit logger with audit package's repository interface + auditRepoImpl := &auditRepositoryAdapter{repo: auditRepo} + auditLogger := audit.NewAuditLogger(nil, logger, auditRepoImpl) // config can be nil for now + return &applicationService{ - appRepo: appRepo, - logger: logger, - validator: validator.New(), + appRepo: appRepo, + auditRepo: auditRepo, + auditLogger: auditLogger, + logger: logger, + validator: validator.New(), } } +// auditRepositoryAdapter adapts repository.AuditRepository to audit.AuditRepository +type auditRepositoryAdapter struct { + repo repository.AuditRepository +} + +func (a *auditRepositoryAdapter) Create(ctx context.Context, event *audit.AuditEvent) error { + return a.repo.Create(ctx, event) +} + +func (a *auditRepositoryAdapter) Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) { + return a.repo.Query(ctx, filter) +} + +func (a *auditRepositoryAdapter) GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) { + return a.repo.GetStats(ctx, filter) +} + +func (a *auditRepositoryAdapter) DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) { + return a.repo.DeleteOldEvents(ctx, olderThan) +} + // Create creates a new application func (s *applicationService) Create(ctx context.Context, req *domain.CreateApplicationRequest, userID string) (*domain.Application, error) { s.logger.Info("Creating application", zap.String("app_id", req.AppID), zap.String("user_id", userID)) @@ -75,9 +106,43 @@ func (s *applicationService) Create(ctx context.Context, req *domain.CreateAppli if err := s.appRepo.Create(ctx, app); err != nil { s.logger.Error("Failed to create application", zap.Error(err), zap.String("app_id", req.AppID)) + + // Log audit event for failed creation + s.auditLogger.LogEvent(ctx, audit.NewAuditEventBuilder(audit.EventTypeAppCreated). + WithSeverity(audit.SeverityError). + WithStatus(audit.StatusFailure). + WithActor(userID, "user", ""). + WithResource(req.AppID, "application"). + WithAction("create"). + WithDescription(fmt.Sprintf("Failed to create application %s", req.AppID)). + WithDetails(map[string]interface{}{ + "error": err.Error(), + "app_id": req.AppID, + "user_id": userID, + }). + Build()) + return nil, fmt.Errorf("failed to create application: %w", err) } + // Log successful creation + s.auditLogger.LogEvent(ctx, audit.NewAuditEventBuilder(audit.EventTypeAppCreated). + WithSeverity(audit.SeverityInfo). + WithStatus(audit.StatusSuccess). + WithActor(userID, "user", ""). + WithResource(app.AppID, "application"). + WithAction("create"). + WithDescription(fmt.Sprintf("Created application %s", app.AppID)). + WithDetails(map[string]interface{}{ + "app_id": app.AppID, + "app_link": app.AppLink, + "type": app.Type, + "user_id": userID, + "owner_name": app.Owner.Name, + "owner_type": app.Owner.Type, + }). + Build()) + s.logger.Info("Application created successfully", zap.String("app_id", app.AppID)) return app, nil } diff --git a/migrations/004_add_audit_events.down.sql b/migrations/004_add_audit_events.down.sql new file mode 100644 index 0000000..c625ad0 --- /dev/null +++ b/migrations/004_add_audit_events.down.sql @@ -0,0 +1,27 @@ +-- Migration: 004_add_audit_events (down) +-- Remove audit_events table and related objects + +-- Drop the cleanup function +DROP FUNCTION IF EXISTS cleanup_old_audit_events(INTEGER); + +-- Drop indexes first (they will be dropped automatically with the table, but explicit for clarity) +DROP INDEX IF EXISTS idx_audit_events_timestamp; +DROP INDEX IF EXISTS idx_audit_events_type; +DROP INDEX IF EXISTS idx_audit_events_severity; +DROP INDEX IF EXISTS idx_audit_events_status; +DROP INDEX IF EXISTS idx_audit_events_actor_id; +DROP INDEX IF EXISTS idx_audit_events_actor_type; +DROP INDEX IF EXISTS idx_audit_events_tenant_id; +DROP INDEX IF EXISTS idx_audit_events_resource; +DROP INDEX IF EXISTS idx_audit_events_request_id; +DROP INDEX IF EXISTS idx_audit_events_session_id; +DROP INDEX IF EXISTS idx_audit_events_details; +DROP INDEX IF EXISTS idx_audit_events_metadata; +DROP INDEX IF EXISTS idx_audit_events_tags; +DROP INDEX IF EXISTS idx_audit_events_actor_timestamp; +DROP INDEX IF EXISTS idx_audit_events_type_timestamp; +DROP INDEX IF EXISTS idx_audit_events_tenant_timestamp; +DROP INDEX IF EXISTS idx_audit_events_resource_timestamp; + +-- Drop the audit_events table +DROP TABLE IF EXISTS audit_events; \ No newline at end of file diff --git a/migrations/004_add_audit_events.up.sql b/migrations/004_add_audit_events.up.sql new file mode 100644 index 0000000..0fcd5ca --- /dev/null +++ b/migrations/004_add_audit_events.up.sql @@ -0,0 +1,102 @@ +-- Migration: 004_add_audit_events +-- Add audit_events table for comprehensive audit logging + +-- Create audit_events table +CREATE TABLE IF NOT EXISTS audit_events ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + type VARCHAR(50) NOT NULL, + severity VARCHAR(20) NOT NULL CHECK (severity IN ('info', 'warning', 'error', 'critical')), + status VARCHAR(20) NOT NULL CHECK (status IN ('success', 'failure', 'pending')), + timestamp TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + + -- Actor information + actor_id VARCHAR(255), + actor_type VARCHAR(50) CHECK (actor_type IN ('user', 'system', 'service')), + actor_ip INET, + user_agent TEXT, + + -- Tenant information (for multi-tenancy support) + tenant_id UUID, + + -- Resource information + resource_id VARCHAR(255), + resource_type VARCHAR(100), + + -- Event details + action VARCHAR(100) NOT NULL, + description TEXT NOT NULL, + details JSONB DEFAULT '{}', + + -- Request context + request_id VARCHAR(100), + session_id VARCHAR(255), + + -- Additional metadata + tags TEXT[], + metadata JSONB DEFAULT '{}' +); + +-- Create indexes for efficient querying +CREATE INDEX IF NOT EXISTS idx_audit_events_timestamp ON audit_events(timestamp DESC); +CREATE INDEX IF NOT EXISTS idx_audit_events_type ON audit_events(type); +CREATE INDEX IF NOT EXISTS idx_audit_events_severity ON audit_events(severity); +CREATE INDEX IF NOT EXISTS idx_audit_events_status ON audit_events(status); +CREATE INDEX IF NOT EXISTS idx_audit_events_actor_id ON audit_events(actor_id); +CREATE INDEX IF NOT EXISTS idx_audit_events_actor_type ON audit_events(actor_type); +CREATE INDEX IF NOT EXISTS idx_audit_events_tenant_id ON audit_events(tenant_id) WHERE tenant_id IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_audit_events_resource ON audit_events(resource_type, resource_id) WHERE resource_id IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_audit_events_request_id ON audit_events(request_id) WHERE request_id IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_audit_events_session_id ON audit_events(session_id) WHERE session_id IS NOT NULL; + +-- GIN indexes for JSONB columns +CREATE INDEX IF NOT EXISTS idx_audit_events_details ON audit_events USING GIN (details); +CREATE INDEX IF NOT EXISTS idx_audit_events_metadata ON audit_events USING GIN (metadata); + +-- GIN index for tags array +CREATE INDEX IF NOT EXISTS idx_audit_events_tags ON audit_events USING GIN (tags); + +-- Composite indexes for common query patterns +CREATE INDEX IF NOT EXISTS idx_audit_events_actor_timestamp ON audit_events(actor_id, timestamp DESC) WHERE actor_id IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_audit_events_type_timestamp ON audit_events(type, timestamp DESC); +CREATE INDEX IF NOT EXISTS idx_audit_events_tenant_timestamp ON audit_events(tenant_id, timestamp DESC) WHERE tenant_id IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_audit_events_resource_timestamp ON audit_events(resource_type, resource_id, timestamp DESC) WHERE resource_id IS NOT NULL; + +-- Add comments for documentation +COMMENT ON TABLE audit_events IS 'Comprehensive audit log for all system events and user actions'; +COMMENT ON COLUMN audit_events.id IS 'Unique event identifier'; +COMMENT ON COLUMN audit_events.type IS 'Event type (e.g., auth.login, app.created)'; +COMMENT ON COLUMN audit_events.severity IS 'Event severity level: info, warning, error, critical'; +COMMENT ON COLUMN audit_events.status IS 'Event status: success, failure, pending'; +COMMENT ON COLUMN audit_events.timestamp IS 'When the event occurred'; +COMMENT ON COLUMN audit_events.actor_id IS 'ID of the user/system that triggered the event'; +COMMENT ON COLUMN audit_events.actor_type IS 'Type of actor: user, system, service'; +COMMENT ON COLUMN audit_events.actor_ip IS 'IP address of the actor'; +COMMENT ON COLUMN audit_events.user_agent IS 'User agent string (for HTTP requests)'; +COMMENT ON COLUMN audit_events.tenant_id IS 'Tenant ID for multi-tenant environments'; +COMMENT ON COLUMN audit_events.resource_id IS 'ID of the resource being acted upon'; +COMMENT ON COLUMN audit_events.resource_type IS 'Type of resource (e.g., application, token)'; +COMMENT ON COLUMN audit_events.action IS 'Action performed'; +COMMENT ON COLUMN audit_events.description IS 'Human-readable description of the event'; +COMMENT ON COLUMN audit_events.details IS 'Additional structured details as JSON'; +COMMENT ON COLUMN audit_events.request_id IS 'Request ID for tracing'; +COMMENT ON COLUMN audit_events.session_id IS 'Session ID for user session tracking'; +COMMENT ON COLUMN audit_events.tags IS 'Array of tags for categorization'; +COMMENT ON COLUMN audit_events.metadata IS 'Additional metadata as JSON'; + +-- Create a function to automatically clean up old audit events (optional) +CREATE OR REPLACE FUNCTION cleanup_old_audit_events(retention_days INTEGER DEFAULT 365) +RETURNS INTEGER AS $$ +DECLARE + deleted_count INTEGER; +BEGIN + -- Delete audit events older than retention period + DELETE FROM audit_events + WHERE timestamp < NOW() - (retention_days || ' days')::INTERVAL; + + GET DIAGNOSTICS deleted_count = ROW_COUNT; + + RETURN deleted_count; +END; +$$ LANGUAGE plpgsql; + +COMMENT ON FUNCTION cleanup_old_audit_events(INTEGER) IS 'Function to clean up audit events older than specified days (default: 365 days)'; \ No newline at end of file diff --git a/test/e2e_test.sh b/test/e2e_test.sh index c062e8b..bd7cca7 100755 --- a/test/e2e_test.sh +++ b/test/e2e_test.sh @@ -189,6 +189,7 @@ test_application_endpoints() { "app_link": "https://example.com/test-app", "type": ["static"], "callback_url": "https://example.com/callback", + "token_prefix": "TEST", "token_renewal_duration": 604800000000000, "max_token_duration": 2592000000000000, "owner": { diff --git a/test/integration_test.go b/test/integration_test.go index 6a8f773..1726ccc 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -96,8 +96,11 @@ func (suite *IntegrationTestSuite) setupServer() { // Create a no-op logger for tests logger := zap.NewNop() + // Initialize repositories + auditRepo := NewMockAuditRepository() + // Initialize services - appService := services.NewApplicationService(appRepo, logger) + appService := services.NewApplicationService(appRepo, auditRepo, logger) tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, suite.cfg.GetString("INTERNAL_HMAC_KEY"), suite.cfg, logger) authService := services.NewAuthenticationService(suite.cfg, logger, permRepo) diff --git a/test/mock_repositories.go b/test/mock_repositories.go index feef2ba..5735fc8 100644 --- a/test/mock_repositories.go +++ b/test/mock_repositories.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/uuid" + "github.com/kms/api-key-service/internal/audit" "github.com/kms/api-key-service/internal/domain" "github.com/kms/api-key-service/internal/repository" ) @@ -612,3 +613,204 @@ func (m *MockGrantedPermissionRepository) HasAnyPermission(ctx context.Context, return result, nil } + +// MockAuditRepository implements AuditRepository for testing +type MockAuditRepository struct { + mu sync.RWMutex + events []*audit.AuditEvent +} + +func NewMockAuditRepository() repository.AuditRepository { + return &MockAuditRepository{ + events: make([]*audit.AuditEvent, 0), + } +} + +func (m *MockAuditRepository) Create(ctx context.Context, event *audit.AuditEvent) error { + m.mu.Lock() + defer m.mu.Unlock() + + if event.ID == uuid.Nil { + event.ID = uuid.New() + } + if event.Timestamp.IsZero() { + event.Timestamp = time.Now().UTC() + } + + m.events = append(m.events, event) + return nil +} + +func (m *MockAuditRepository) Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + var result []*audit.AuditEvent + for _, event := range m.events { + // Simple filtering logic for testing + if len(filter.EventTypes) > 0 { + found := false + for _, t := range filter.EventTypes { + if event.Type == t { + found = true + break + } + } + if !found { + continue + } + } + + if filter.ActorID != "" && event.ActorID != filter.ActorID { + continue + } + + if filter.ResourceID != "" && event.ResourceID != filter.ResourceID { + continue + } + + if filter.ResourceType != "" && event.ResourceType != filter.ResourceType { + continue + } + + result = append(result, event) + } + + // Apply pagination + if filter.Offset >= len(result) { + return []*audit.AuditEvent{}, nil + } + + end := filter.Offset + filter.Limit + if end > len(result) { + end = len(result) + } + + return result[filter.Offset:end], nil +} + +func (m *MockAuditRepository) GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + stats := &audit.AuditStats{ + TotalEvents: len(m.events), + ByType: make(map[audit.EventType]int), + BySeverity: make(map[audit.EventSeverity]int), + ByStatus: make(map[audit.EventStatus]int), + } + + for _, event := range m.events { + stats.ByType[event.Type]++ + stats.BySeverity[event.Severity]++ + stats.ByStatus[event.Status]++ + } + + return stats, nil +} + +func (m *MockAuditRepository) DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + + var kept []*audit.AuditEvent + deleted := 0 + + for _, event := range m.events { + if event.Timestamp.Before(olderThan) { + deleted++ + } else { + kept = append(kept, event) + } + } + + m.events = kept + return deleted, nil +} + +func (m *MockAuditRepository) GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, event := range m.events { + if event.ID == eventID { + return event, nil + } + } + + return nil, fmt.Errorf("audit event with ID '%s' not found", eventID) +} + +func (m *MockAuditRepository) GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + var result []*audit.AuditEvent + for _, event := range m.events { + if event.RequestID == requestID { + result = append(result, event) + } + } + + return result, nil +} + +func (m *MockAuditRepository) GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + var result []*audit.AuditEvent + for _, event := range m.events { + if event.SessionID == sessionID { + result = append(result, event) + } + } + + return result, nil +} + +func (m *MockAuditRepository) GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + var matching []*audit.AuditEvent + for _, event := range m.events { + if event.ActorID == actorID { + matching = append(matching, event) + } + } + + if offset >= len(matching) { + return []*audit.AuditEvent{}, nil + } + + end := offset + limit + if end > len(matching) { + end = len(matching) + } + + return matching[offset:end], nil +} + +func (m *MockAuditRepository) GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + var matching []*audit.AuditEvent + for _, event := range m.events { + if event.ResourceType == resourceType && event.ResourceID == resourceID { + matching = append(matching, event) + } + } + + if offset >= len(matching) { + return []*audit.AuditEvent{}, nil + } + + end := offset + limit + if end > len(matching) { + end = len(matching) + } + + return matching[offset:end], nil +}