package postgres

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/lib/pq"

	"github.com/kms/api-key-service/internal/audit"
	"github.com/kms/api-key-service/internal/repository"
)

// auditEventColumns is the column list shared by every SELECT against
// audit_events. Its order must match the Scan destinations in scanAuditEvent.
const auditEventColumns = `id, type, severity, status, timestamp,
	actor_id, actor_type, actor_ip, user_agent,
	tenant_id, resource_id, resource_type,
	action, description, details,
	request_id, session_id, tags, metadata`

// AuditRepository implements the AuditRepository interface for PostgreSQL
type AuditRepository struct {
	db repository.DatabaseProvider
}

// NewAuditRepository creates a new PostgreSQL audit repository
func NewAuditRepository(db repository.DatabaseProvider) repository.AuditRepository {
	return &AuditRepository{db: db}
}

// sqlDB returns the *sql.DB exposed by the configured DatabaseProvider.
// It reports an error instead of panicking when the provider hands back a
// different (or nil) handle.
func (r *AuditRepository) sqlDB() (*sql.DB, error) {
	handle := r.db.GetDB()
	db, ok := handle.(*sql.DB)
	if !ok || db == nil {
		return nil, fmt.Errorf("audit repository requires a non-nil *sql.DB, got %T", handle)
	}
	return db, nil
}

// clampAuditLimit normalizes a page size: non-positive values become 100 and
// values above 1000 are capped at 1000.
func clampAuditLimit(limit int) int {
	switch {
	case limit <= 0:
		return 100
	case limit > 1000:
		return 1000
	default:
		return limit
	}
}

// Create stores a new audit event.
//
// If event.ID is the zero UUID a new random ID is assigned, and if
// event.Timestamp is zero it is set to the current UTC time; both assignments
// are written into *event so the caller can observe them. Empty string fields
// are stored as NULL, and nil Details/Metadata are stored as "{}".
func (r *AuditRepository) Create(ctx context.Context, event *audit.AuditEvent) error {
	if event == nil {
		return errors.New("failed to create audit event: event is nil")
	}

	const query = `
		INSERT INTO audit_events (
			id, type, severity, status, timestamp,
			actor_id, actor_type, actor_ip, user_agent,
			tenant_id, resource_id, resource_type,
			action, description, details,
			request_id, session_id, tags, metadata
		) VALUES (
			$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
			$11, $12, $13, $14, $15, $16, $17, $18, $19
		)`

	db, err := r.sqlDB()
	if err != nil {
		return err
	}

	// Ensure event has an ID and timestamp.
	if event.ID == uuid.Nil {
		event.ID = uuid.New()
	}
	if event.Timestamp.IsZero() {
		event.Timestamp = time.Now().UTC()
	}

	detailsJSON := []byte("{}")
	if event.Details != nil {
		if detailsJSON, err = json.Marshal(event.Details); err != nil {
			return fmt.Errorf("failed to marshal event details: %w", err)
		}
	}

	metadataJSON := []byte("{}")
	if event.Metadata != nil {
		if metadataJSON, err = json.Marshal(event.Metadata); err != nil {
			return fmt.Errorf("failed to marshal event metadata: %w", err)
		}
	}

	// optional maps "" to NULL so optional text columns stay NULL rather than
	// holding empty strings.
	optional := func(s string) *string {
		if s == "" {
			return nil
		}
		return &s
	}

	_, err = db.ExecContext(ctx, query,
		event.ID,
		string(event.Type),
		string(event.Severity),
		string(event.Status),
		event.Timestamp,
		optional(event.ActorID),
		optional(event.ActorType),
		optional(event.ActorIP),
		optional(event.UserAgent),
		event.TenantID, // nil pointer is written as NULL
		optional(event.ResourceID),
		optional(event.ResourceType),
		event.Action,
		event.Description,
		string(detailsJSON),
		optional(event.RequestID),
		optional(event.SessionID),
		pq.Array(event.Tags),
		string(metadataJSON),
	)
	if err != nil {
		return fmt.Errorf("failed to create audit event: %w", err)
	}
	return nil
}

// Query retrieves audit events matching every non-empty field of filter
// (a nil filter matches everything).
//
// Results are ordered by filter.OrderBy when it is one of timestamp, type,
// severity or status (timestamp otherwise), ascending unless OrderDesc is set.
// The page size is filter.Limit clamped to [1, 1000] (100 when unset); the
// caller's filter is not modified. Offset is applied only when positive.
func (r *AuditRepository) Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) {
	if filter == nil {
		filter = &audit.AuditFilter{}
	}

	var conditions []string
	var args []interface{}
	// addCondition binds value as the next positional parameter and records a
	// condition whose single %d verb becomes that parameter's number.
	addCondition := func(format string, value interface{}) {
		args = append(args, value)
		conditions = append(conditions, fmt.Sprintf(format, len(args)))
	}

	if len(filter.EventTypes) > 0 {
		types := make([]string, len(filter.EventTypes))
		for i, t := range filter.EventTypes {
			types[i] = string(t)
		}
		addCondition("type = ANY($%d)", pq.Array(types))
	}
	if len(filter.Severities) > 0 {
		severities := make([]string, len(filter.Severities))
		for i, s := range filter.Severities {
			severities[i] = string(s)
		}
		addCondition("severity = ANY($%d)", pq.Array(severities))
	}
	if len(filter.Statuses) > 0 {
		statuses := make([]string, len(filter.Statuses))
		for i, s := range filter.Statuses {
			statuses[i] = string(s)
		}
		addCondition("status = ANY($%d)", pq.Array(statuses))
	}
	if filter.ActorID != "" {
		addCondition("actor_id = $%d", filter.ActorID)
	}
	if filter.ActorType != "" {
		addCondition("actor_type = $%d", filter.ActorType)
	}
	if filter.TenantID != nil {
		addCondition("tenant_id = $%d", *filter.TenantID)
	}
	if filter.ResourceID != "" {
		addCondition("resource_id = $%d", filter.ResourceID)
	}
	if filter.ResourceType != "" {
		addCondition("resource_type = $%d", filter.ResourceType)
	}
	if filter.StartTime != nil {
		addCondition("timestamp >= $%d", *filter.StartTime)
	}
	if filter.EndTime != nil {
		addCondition("timestamp <= $%d", *filter.EndTime)
	}
	if len(filter.Tags) > 0 {
		// && is array overlap: the event matches if it carries any listed tag.
		addCondition("tags && $%d", pq.Array(filter.Tags))
	}

	query := "SELECT " + auditEventColumns + " FROM audit_events"
	if len(conditions) > 0 {
		query += " WHERE " + strings.Join(conditions, " AND ")
	}

	// ORDER BY cannot be parameterized, so only whitelisted column names are
	// interpolated into the SQL; anything else falls back to timestamp.
	orderBy := "timestamp"
	switch filter.OrderBy {
	case "timestamp", "type", "severity", "status":
		orderBy = filter.OrderBy
	}
	direction := "ASC"
	if filter.OrderDesc {
		direction = "DESC"
	}
	query += fmt.Sprintf(" ORDER BY %s %s", orderBy, direction)

	args = append(args, clampAuditLimit(filter.Limit))
	query += fmt.Sprintf(" LIMIT $%d", len(args))
	if filter.Offset > 0 {
		args = append(args, filter.Offset)
		query += fmt.Sprintf(" OFFSET $%d", len(args))
	}

	return r.queryEvents(ctx, "failed to query audit events", query, args...)
}

// GetStats returns aggregated statistics for audit events matching filter
// (a nil filter matches everything): the total count plus counts per type,
// severity and status. When filter.GroupBy is set, ByTime is also filled,
// bucketed per hour for "hour" and per day for any other value.
func (r *AuditRepository) GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) {
	if filter == nil {
		filter = &audit.AuditStatsFilter{}
	}

	var conditions []string
	var args []interface{}
	addCondition := func(format string, value interface{}) {
		args = append(args, value)
		conditions = append(conditions, fmt.Sprintf(format, len(args)))
	}

	if len(filter.EventTypes) > 0 {
		types := make([]string, len(filter.EventTypes))
		for i, t := range filter.EventTypes {
			types[i] = string(t)
		}
		addCondition("type = ANY($%d)", pq.Array(types))
	}
	if filter.TenantID != nil {
		addCondition("tenant_id = $%d", *filter.TenantID)
	}
	if filter.StartTime != nil {
		addCondition("timestamp >= $%d", *filter.StartTime)
	}
	if filter.EndTime != nil {
		addCondition("timestamp <= $%d", *filter.EndTime)
	}

	whereClause := ""
	if len(conditions) > 0 {
		whereClause = "WHERE " + strings.Join(conditions, " AND ")
	}

	db, err := r.sqlDB()
	if err != nil {
		return nil, err
	}

	stats := &audit.AuditStats{
		ByType:     make(map[audit.EventType]int),
		BySeverity: make(map[audit.EventSeverity]int),
		ByStatus:   make(map[audit.EventStatus]int),
	}

	totalQuery := "SELECT COUNT(*) FROM audit_events " + whereClause
	if err := db.QueryRowContext(ctx, totalQuery, args...).Scan(&stats.TotalEvents); err != nil {
		return nil, fmt.Errorf("failed to get total event count: %w", err)
	}

	// groupedBy builds a per-column count query; column is always one of the
	// literals passed below, never caller input.
	groupedBy := func(column string) string {
		return fmt.Sprintf(
			"SELECT %[1]s, COUNT(*) FROM audit_events %[2]s GROUP BY %[1]s ORDER BY COUNT(*) DESC",
			column, whereClause)
	}

	if err := r.countGrouped(ctx, db, "type", groupedBy("type"), args, func(key string, count int) {
		stats.ByType[audit.EventType(key)] = count
	}); err != nil {
		return nil, err
	}
	if err := r.countGrouped(ctx, db, "severity", groupedBy("severity"), args, func(key string, count int) {
		stats.BySeverity[audit.EventSeverity(key)] = count
	}); err != nil {
		return nil, err
	}
	if err := r.countGrouped(ctx, db, "status", groupedBy("status"), args, func(key string, count int) {
		stats.ByStatus[audit.EventStatus(key)] = count
	}); err != nil {
		return nil, err
	}

	if filter.GroupBy != "" {
		// The TO_CHAR pattern comes from this fixed choice, so interpolating it
		// into the SQL is safe.
		timeFormat := "YYYY-MM-DD"
		if filter.GroupBy == "hour" {
			timeFormat = "YYYY-MM-DD HH24:00"
		}
		stats.ByTime = make(map[string]int)
		timeQuery := fmt.Sprintf(
			"SELECT TO_CHAR(timestamp, '%s') AS time_group, COUNT(*) FROM audit_events %s GROUP BY time_group ORDER BY time_group DESC",
			timeFormat, whereClause)
		if err := r.countGrouped(ctx, db, "time", timeQuery, args, func(key string, count int) {
			stats.ByTime[key] = count
		}); err != nil {
			return nil, err
		}
	}

	return stats, nil
}

// countGrouped runs a two-column (key, COUNT(*)) query and hands each row to
// add. label names the grouping in error messages. The result set is closed
// before returning, and iteration errors from rows.Err are reported.
func (r *AuditRepository) countGrouped(ctx context.Context, db *sql.DB, label, query string, args []interface{}, add func(key string, count int)) error {
	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return fmt.Errorf("failed to get %s stats: %w", label, err)
	}
	defer rows.Close()

	for rows.Next() {
		var key string
		var count int
		if err := rows.Scan(&key, &count); err != nil {
			return fmt.Errorf("failed to scan %s stats: %w", label, err)
		}
		add(key, count)
	}
	if err := rows.Err(); err != nil {
		return fmt.Errorf("error iterating %s stats: %w", label, err)
	}
	return nil
}

// DeleteOldEvents removes audit events with a timestamp strictly before
// olderThan and returns the number of rows deleted.
func (r *AuditRepository) DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) {
	db, err := r.sqlDB()
	if err != nil {
		return 0, err
	}

	result, err := db.ExecContext(ctx, `DELETE FROM audit_events WHERE timestamp < $1`, olderThan)
	if err != nil {
		return 0, fmt.Errorf("failed to delete old audit events: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("failed to get rows affected: %w", err)
	}
	return int(rowsAffected), nil
}

// GetByID retrieves a specific audit event by its ID. A missing row yields a
// "not found" error.
func (r *AuditRepository) GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error) {
	const query = "SELECT " + auditEventColumns + " FROM audit_events WHERE id = $1"

	db, err := r.sqlDB()
	if err != nil {
		return nil, err
	}

	event, err := r.scanAuditEvent(db.QueryRowContext(ctx, query, eventID))
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("audit event with ID '%s' not found", eventID)
		}
		return nil, fmt.Errorf("failed to get audit event: %w", err)
	}
	return event, nil
}

// GetByRequestID retrieves all audit events for a specific request, oldest
// first.
func (r *AuditRepository) GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error) {
	const query = "SELECT " + auditEventColumns +
		" FROM audit_events WHERE request_id = $1 ORDER BY timestamp ASC"
	return r.queryEvents(ctx, "failed to query audit events by request ID", query, requestID)
}

// GetBySession retrieves all audit events for a specific session, oldest
// first.
func (r *AuditRepository) GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error) {
	const query = "SELECT " + auditEventColumns +
		" FROM audit_events WHERE session_id = $1 ORDER BY timestamp ASC"
	return r.queryEvents(ctx, "failed to query audit events by session ID", query, sessionID)
}

// GetByActor retrieves audit events for a specific actor, newest first.
// limit is clamped to [1, 1000] (100 when non-positive); a negative offset is
// treated as 0.
func (r *AuditRepository) GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error) {
	const query = "SELECT " + auditEventColumns +
		" FROM audit_events WHERE actor_id = $1 ORDER BY timestamp DESC LIMIT $2 OFFSET $3"
	if offset < 0 {
		offset = 0 // Postgres rejects negative OFFSET values
	}
	return r.queryEvents(ctx, "failed to query audit events by actor", query,
		actorID, clampAuditLimit(limit), offset)
}

// GetByResource retrieves audit events for a specific resource, newest first.
// limit is clamped to [1, 1000] (100 when non-positive); a negative offset is
// treated as 0.
func (r *AuditRepository) GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error) {
	const query = "SELECT " + auditEventColumns +
		" FROM audit_events WHERE resource_type = $1 AND resource_id = $2 ORDER BY timestamp DESC LIMIT $3 OFFSET $4"
	if offset < 0 {
		offset = 0 // Postgres rejects negative OFFSET values
	}
	return r.queryEvents(ctx, "failed to query audit events by resource", query,
		resourceType, resourceID, clampAuditLimit(limit), offset)
}

// queryEvents runs a SELECT of auditEventColumns and scans every row.
// queryErrMsg prefixes the error returned when the query itself fails, so each
// caller keeps its own message. Errors surfaced by rows.Err after iteration
// are returned instead of silently yielding a truncated result.
func (r *AuditRepository) queryEvents(ctx context.Context, queryErrMsg, query string, args ...interface{}) ([]*audit.AuditEvent, error) {
	db, err := r.sqlDB()
	if err != nil {
		return nil, err
	}

	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", queryErrMsg, err)
	}
	defer rows.Close()

	var events []*audit.AuditEvent
	for rows.Next() {
		event, err := r.scanAuditEvent(rows)
		if err != nil {
			return nil, fmt.Errorf("failed to scan audit event: %w", err)
		}
		events = append(events, event)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating audit events: %w", err)
	}
	return events, nil
}

// scanAuditEvent scans one row selected with auditEventColumns into an
// AuditEvent. row may be a *sql.Row, a *sql.Rows, or any value with a
// compatible Scan method; anything else yields an "invalid row type" error.
// Scan errors (including sql.ErrNoRows) are returned unwrapped. NULL text
// columns become empty strings, and NULL or empty JSON columns leave
// Details/Metadata unset.
func (r *AuditRepository) scanAuditEvent(row interface{}) (*audit.AuditEvent, error) {
	scanner, ok := row.(interface {
		Scan(dest ...interface{}) error
	})
	if !ok {
		return nil, fmt.Errorf("invalid row type")
	}

	event := &audit.AuditEvent{}
	var typeStr, severityStr, statusStr string
	var actorID, actorType, actorIP, userAgent sql.NullString
	var resourceID, resourceType, requestID, sessionID sql.NullString
	var detailsJSON, metadataJSON sql.NullString
	var tenantID *uuid.UUID
	var tags pq.StringArray

	err := scanner.Scan(
		&event.ID,
		&typeStr,
		&severityStr,
		&statusStr,
		&event.Timestamp,
		&actorID,
		&actorType,
		&actorIP,
		&userAgent,
		&tenantID,
		&resourceID,
		&resourceType,
		&event.Action,
		&event.Description,
		&detailsJSON,
		&requestID,
		&sessionID,
		&tags,
		&metadataJSON,
	)
	if err != nil {
		return nil, err
	}

	event.Type = audit.EventType(typeStr)
	event.Severity = audit.EventSeverity(severityStr)
	event.Status = audit.EventStatus(statusStr)

	// sql.NullString.String is "" for NULL, so these assignments map NULL to
	// the empty string without separate Valid checks.
	event.ActorID = actorID.String
	event.ActorType = actorType.String
	event.ActorIP = actorIP.String
	event.UserAgent = userAgent.String
	event.ResourceID = resourceID.String
	event.ResourceType = resourceType.String
	event.RequestID = requestID.String
	event.SessionID = sessionID.String
	event.TenantID = tenantID // nil when the column is NULL
	event.Tags = []string(tags)

	if detailsJSON.Valid && detailsJSON.String != "" {
		if err := json.Unmarshal([]byte(detailsJSON.String), &event.Details); err != nil {
			return nil, fmt.Errorf("failed to unmarshal details JSON: %w", err)
		}
	}
	if metadataJSON.Valid && metadataJSON.String != "" {
		if err := json.Unmarshal([]byte(metadataJSON.String), &event.Metadata); err != nil {
			return nil, fmt.Errorf("failed to unmarshal metadata JSON: %w", err)
		}
	}

	return event, nil
}