Files
skybridge/kms/internal/repository/postgres/audit_repository.go
2025-08-26 19:16:41 -04:00

742 lines
18 KiB
Go

package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/kms/api-key-service/internal/audit"
"github.com/kms/api-key-service/internal/repository"
)
// AuditRepository is the PostgreSQL-backed implementation of
// repository.AuditRepository. Every method obtains its *sql.DB from the
// wrapped DatabaseProvider at call time.
type AuditRepository struct {
	db repository.DatabaseProvider
}

// NewAuditRepository returns a repository.AuditRepository that stores and
// reads audit events through the given database provider.
func NewAuditRepository(db repository.DatabaseProvider) repository.AuditRepository {
	repo := &AuditRepository{db: db}
	return repo
}
// Create inserts event into the audit_events table.
//
// Before inserting, a zero event.ID is replaced with a new random UUID and a
// zero event.Timestamp with the current UTC time; both assignments are made
// on the caller's struct. Nil Details/Metadata are stored as the JSON text
// "{}", and the optional string fields (actor, user agent, resource, request,
// session) are stored as NULL when empty.
func (r *AuditRepository) Create(ctx context.Context, event *audit.AuditEvent) error {
	const insertSQL = `
		INSERT INTO audit_events (
			id, type, severity, status, timestamp,
			actor_id, actor_type, actor_ip, user_agent, tenant_id,
			resource_id, resource_type, action, description, details,
			request_id, session_id, tags, metadata
		) VALUES (
			$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
			$11, $12, $13, $14, $15, $16, $17, $18, $19
		)
	`

	db := r.db.GetDB().(*sql.DB)

	if event.ID == uuid.Nil {
		event.ID = uuid.New()
	}
	if event.Timestamp.IsZero() {
		event.Timestamp = time.Now().UTC()
	}

	details := []byte("{}")
	if event.Details != nil {
		encoded, err := json.Marshal(event.Details)
		if err != nil {
			return fmt.Errorf("failed to marshal event details: %w", err)
		}
		details = encoded
	}

	metadata := []byte("{}")
	if event.Metadata != nil {
		encoded, err := json.Marshal(event.Metadata)
		if err != nil {
			return fmt.Errorf("failed to marshal event metadata: %w", err)
		}
		metadata = encoded
	}

	// optional turns an empty field into a nil *string (bound as SQL NULL);
	// otherwise it hands the driver a pointer to the event's own field.
	optional := func(field *string) *string {
		if *field == "" {
			return nil
		}
		return field
	}

	if _, err := db.ExecContext(ctx, insertSQL,
		event.ID,
		string(event.Type),
		string(event.Severity),
		string(event.Status),
		event.Timestamp,
		optional(&event.ActorID),
		optional(&event.ActorType),
		optional(&event.ActorIP),
		optional(&event.UserAgent),
		event.TenantID, // nil pointer is bound as NULL
		optional(&event.ResourceID),
		optional(&event.ResourceType),
		event.Action,
		event.Description,
		string(details),
		optional(&event.RequestID),
		optional(&event.SessionID),
		pq.Array(event.Tags),
		string(metadata),
	); err != nil {
		return fmt.Errorf("failed to create audit event: %w", err)
	}
	return nil
}
// Query returns the audit events matching every criterion set on filter.
//
// Unset criteria (empty slices, empty strings, nil pointers) are ignored and
// the set criteria are ANDed together. Results are ordered by filter.OrderBy
// when it is one of timestamp, type, severity or status (timestamp
// otherwise), ascending unless filter.OrderDesc is true. filter.Limit is
// normalized in place: non-positive becomes 100 and values above 1000 are
// capped at 1000. filter.Offset is applied only when positive. All values
// are bound as parameters; only whitelisted column names are formatted into
// the SQL text.
func (r *AuditRepository) Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) {
	var (
		clauses []string
		params  []interface{}
	)
	// placeholder records v as the next bound parameter and returns its
	// PostgreSQL positional marker ($1, $2, ...).
	placeholder := func(v interface{}) string {
		params = append(params, v)
		return fmt.Sprintf("$%d", len(params))
	}

	if len(filter.EventTypes) > 0 {
		names := make([]string, 0, len(filter.EventTypes))
		for _, t := range filter.EventTypes {
			names = append(names, string(t))
		}
		clauses = append(clauses, "type = ANY("+placeholder(pq.Array(names))+")")
	}
	if len(filter.Severities) > 0 {
		names := make([]string, 0, len(filter.Severities))
		for _, s := range filter.Severities {
			names = append(names, string(s))
		}
		clauses = append(clauses, "severity = ANY("+placeholder(pq.Array(names))+")")
	}
	if len(filter.Statuses) > 0 {
		names := make([]string, 0, len(filter.Statuses))
		for _, s := range filter.Statuses {
			names = append(names, string(s))
		}
		clauses = append(clauses, "status = ANY("+placeholder(pq.Array(names))+")")
	}
	if filter.ActorID != "" {
		clauses = append(clauses, "actor_id = "+placeholder(filter.ActorID))
	}
	if filter.ActorType != "" {
		clauses = append(clauses, "actor_type = "+placeholder(filter.ActorType))
	}
	if filter.TenantID != nil {
		clauses = append(clauses, "tenant_id = "+placeholder(*filter.TenantID))
	}
	if filter.ResourceID != "" {
		clauses = append(clauses, "resource_id = "+placeholder(filter.ResourceID))
	}
	if filter.ResourceType != "" {
		clauses = append(clauses, "resource_type = "+placeholder(filter.ResourceType))
	}
	if filter.StartTime != nil {
		clauses = append(clauses, "timestamp >= "+placeholder(*filter.StartTime))
	}
	if filter.EndTime != nil {
		clauses = append(clauses, "timestamp <= "+placeholder(*filter.EndTime))
	}
	if len(filter.Tags) > 0 {
		// && is the PostgreSQL array-overlap operator: any shared tag matches.
		clauses = append(clauses, "tags && "+placeholder(pq.Array(filter.Tags)))
	}

	sqlText := `
		SELECT id, type, severity, status, timestamp,
		       actor_id, actor_type, actor_ip, user_agent, tenant_id,
		       resource_id, resource_type, action, description, details,
		       request_id, session_id, tags, metadata
		FROM audit_events
	`
	if len(clauses) > 0 {
		sqlText += " WHERE " + strings.Join(clauses, " AND ")
	}

	// Only whitelisted column names may be formatted into the SQL text.
	orderColumn := "timestamp"
	switch filter.OrderBy {
	case "timestamp", "type", "severity", "status":
		orderColumn = filter.OrderBy
	}
	orderDir := "ASC"
	if filter.OrderDesc {
		orderDir = "DESC"
	}
	sqlText += " ORDER BY " + orderColumn + " " + orderDir

	switch {
	case filter.Limit <= 0:
		filter.Limit = 100
	case filter.Limit > 1000:
		filter.Limit = 1000
	}
	sqlText += " LIMIT " + placeholder(filter.Limit)
	if filter.Offset > 0 {
		sqlText += " OFFSET " + placeholder(filter.Offset)
	}

	db := r.db.GetDB().(*sql.DB)
	rows, err := db.QueryContext(ctx, sqlText, params...)
	if err != nil {
		return nil, fmt.Errorf("failed to query audit events: %w", err)
	}
	defer rows.Close()

	var events []*audit.AuditEvent
	for rows.Next() {
		ev, scanErr := r.scanAuditEvent(rows)
		if scanErr != nil {
			return nil, fmt.Errorf("failed to scan audit event: %w", scanErr)
		}
		events = append(events, ev)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating audit events: %w", iterErr)
	}
	return events, nil
}
// GetStats returns aggregated counts for the audit events matching filter.
//
// The optional EventTypes, TenantID, StartTime and EndTime criteria are ANDed
// together and applied to every aggregate. The result always carries the
// total count plus per-type, per-severity and per-status breakdowns. ByTime
// is populated only when filter.GroupBy is non-empty. "hour" buckets by hour
// and any other value buckets by day, keyed by the TO_CHAR-formatted bucket
// string.
//
// Each grouped query's result set is fully read, checked and closed before
// the next query starts, so only one pooled connection is held at a time.
func (r *AuditRepository) GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) {
	stats := &audit.AuditStats{
		ByType:     make(map[audit.EventType]int),
		BySeverity: make(map[audit.EventSeverity]int),
		ByStatus:   make(map[audit.EventStatus]int),
	}

	// Build the shared WHERE clause; values are bound as parameters.
	var conditions []string
	var args []interface{}
	argIndex := 1
	if len(filter.EventTypes) > 0 {
		conditions = append(conditions, fmt.Sprintf("type = ANY($%d)", argIndex))
		typeStrings := make([]string, len(filter.EventTypes))
		for i, t := range filter.EventTypes {
			typeStrings[i] = string(t)
		}
		args = append(args, pq.Array(typeStrings))
		argIndex++
	}
	if filter.TenantID != nil {
		conditions = append(conditions, fmt.Sprintf("tenant_id = $%d", argIndex))
		args = append(args, *filter.TenantID)
		argIndex++
	}
	if filter.StartTime != nil {
		conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex))
		args = append(args, *filter.StartTime)
		argIndex++
	}
	if filter.EndTime != nil {
		conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex))
		args = append(args, *filter.EndTime)
		argIndex++
	}
	whereClause := ""
	if len(conditions) > 0 {
		whereClause = "WHERE " + strings.Join(conditions, " AND ")
	}

	db := r.db.GetDB().(*sql.DB)

	totalQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause)
	if err := db.QueryRowContext(ctx, totalQuery, args...).Scan(&stats.TotalEvents); err != nil {
		return nil, fmt.Errorf("failed to get total event count: %w", err)
	}

	// Breakdowns land in maps, so no ORDER BY is needed. Column names are
	// fixed literals, never caller input.
	if err := r.groupedCounts(ctx, db, "type",
		fmt.Sprintf("SELECT type, COUNT(*) FROM audit_events %s GROUP BY type", whereClause),
		args, func(key string, count int) { stats.ByType[audit.EventType(key)] = count },
	); err != nil {
		return nil, err
	}
	if err := r.groupedCounts(ctx, db, "severity",
		fmt.Sprintf("SELECT severity, COUNT(*) FROM audit_events %s GROUP BY severity", whereClause),
		args, func(key string, count int) { stats.BySeverity[audit.EventSeverity(key)] = count },
	); err != nil {
		return nil, err
	}
	if err := r.groupedCounts(ctx, db, "status",
		fmt.Sprintf("SELECT status, COUNT(*) FROM audit_events %s GROUP BY status", whereClause),
		args, func(key string, count int) { stats.ByStatus[audit.EventStatus(key)] = count },
	); err != nil {
		return nil, err
	}

	if filter.GroupBy != "" {
		stats.ByTime = make(map[string]int)
		// Only these two literal formats can reach the SQL text.
		timeFormat := "YYYY-MM-DD"
		if filter.GroupBy == "hour" {
			timeFormat = "YYYY-MM-DD HH24:00"
		}
		timeQuery := fmt.Sprintf(
			"SELECT TO_CHAR(timestamp, '%s') AS time_group, COUNT(*) FROM audit_events %s GROUP BY time_group",
			timeFormat, whereClause)
		if err := r.groupedCounts(ctx, db, "time", timeQuery, args,
			func(key string, count int) { stats.ByTime[key] = count },
		); err != nil {
			return nil, err
		}
	}

	return stats, nil
}

// groupedCounts runs query, which must select exactly two columns (a text
// group key and its count), and passes each row to add. label names the
// breakdown in error messages. The result set is closed, and its iteration
// error checked, before returning.
func (r *AuditRepository) groupedCounts(ctx context.Context, db *sql.DB, label, query string, args []interface{}, add func(key string, count int)) error {
	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return fmt.Errorf("failed to get %s stats: %w", label, err)
	}
	defer rows.Close()

	for rows.Next() {
		var key string
		var count int
		if err := rows.Scan(&key, &count); err != nil {
			return fmt.Errorf("failed to scan %s stats: %w", label, err)
		}
		add(key, count)
	}
	if err := rows.Err(); err != nil {
		return fmt.Errorf("error iterating %s stats: %w", label, err)
	}
	return nil
}
// DeleteOldEvents deletes every audit event whose timestamp is strictly
// earlier than olderThan and reports how many rows were removed.
func (r *AuditRepository) DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) {
	const deleteSQL = `DELETE FROM audit_events WHERE timestamp < $1`

	res, execErr := r.db.GetDB().(*sql.DB).ExecContext(ctx, deleteSQL, olderThan)
	if execErr != nil {
		return 0, fmt.Errorf("failed to delete old audit events: %w", execErr)
	}

	deleted, countErr := res.RowsAffected()
	if countErr != nil {
		return 0, fmt.Errorf("failed to get rows affected: %w", countErr)
	}
	return int(deleted), nil
}
// GetByID returns the audit event with the given ID.
//
// When no row matches, the returned error wraps sql.ErrNoRows. Callers can
// therefore detect the not-found case with errors.Is(err, sql.ErrNoRows)
// instead of matching on the message text.
func (r *AuditRepository) GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error) {
	query := `
		SELECT id, type, severity, status, timestamp,
		       actor_id, actor_type, actor_ip, user_agent, tenant_id,
		       resource_id, resource_type, action, description, details,
		       request_id, session_id, tags, metadata
		FROM audit_events
		WHERE id = $1
	`
	db := r.db.GetDB().(*sql.DB)
	event, err := r.scanAuditEvent(db.QueryRowContext(ctx, query, eventID))
	// scanAuditEvent returns the Scan error unwrapped, so a direct comparison
	// with the sentinel is sufficient here.
	if err == sql.ErrNoRows {
		return nil, fmt.Errorf("audit event with ID '%s' not found: %w", eventID, err)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get audit event: %w", err)
	}
	return event, nil
}
// GetByRequestID returns every audit event recorded under requestID, oldest
// first. If the row stream fails part-way, an error is returned rather than
// a silently truncated slice. It returns a nil slice when nothing matches.
func (r *AuditRepository) GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error) {
	query := `
		SELECT id, type, severity, status, timestamp,
		       actor_id, actor_type, actor_ip, user_agent, tenant_id,
		       resource_id, resource_type, action, description, details,
		       request_id, session_id, tags, metadata
		FROM audit_events
		WHERE request_id = $1
		ORDER BY timestamp ASC
	`
	db := r.db.GetDB().(*sql.DB)
	rows, err := db.QueryContext(ctx, query, requestID)
	if err != nil {
		return nil, fmt.Errorf("failed to query audit events by request ID: %w", err)
	}
	defer rows.Close()

	var events []*audit.AuditEvent
	for rows.Next() {
		event, err := r.scanAuditEvent(rows)
		if err != nil {
			return nil, fmt.Errorf("failed to scan audit event: %w", err)
		}
		events = append(events, event)
	}
	// rows.Next returns false on both exhaustion and failure; Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating audit events: %w", err)
	}
	return events, nil
}
// GetBySession returns every audit event recorded under sessionID, oldest
// first. If the row stream fails part-way, an error is returned rather than
// a silently truncated slice. It returns a nil slice when nothing matches.
func (r *AuditRepository) GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error) {
	query := `
		SELECT id, type, severity, status, timestamp,
		       actor_id, actor_type, actor_ip, user_agent, tenant_id,
		       resource_id, resource_type, action, description, details,
		       request_id, session_id, tags, metadata
		FROM audit_events
		WHERE session_id = $1
		ORDER BY timestamp ASC
	`
	db := r.db.GetDB().(*sql.DB)
	rows, err := db.QueryContext(ctx, query, sessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to query audit events by session ID: %w", err)
	}
	defer rows.Close()

	var events []*audit.AuditEvent
	for rows.Next() {
		event, err := r.scanAuditEvent(rows)
		if err != nil {
			return nil, fmt.Errorf("failed to scan audit event: %w", err)
		}
		events = append(events, event)
	}
	// rows.Next returns false on both exhaustion and failure; Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating audit events: %w", err)
	}
	return events, nil
}
// GetByActor returns one page of the audit events recorded for actorID,
// newest first.
//
// limit is clamped to [1, 1000], and a non-positive limit means 100. A
// negative offset is treated as 0 because PostgreSQL rejects negative OFFSET
// values. If the row stream fails part-way, an error is returned rather than
// a truncated page.
func (r *AuditRepository) GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error) {
	if limit <= 0 {
		limit = 100
	}
	if limit > 1000 {
		limit = 1000
	}
	if offset < 0 {
		offset = 0
	}
	query := `
		SELECT id, type, severity, status, timestamp,
		       actor_id, actor_type, actor_ip, user_agent, tenant_id,
		       resource_id, resource_type, action, description, details,
		       request_id, session_id, tags, metadata
		FROM audit_events
		WHERE actor_id = $1
		ORDER BY timestamp DESC
		LIMIT $2 OFFSET $3
	`
	db := r.db.GetDB().(*sql.DB)
	rows, err := db.QueryContext(ctx, query, actorID, limit, offset)
	if err != nil {
		return nil, fmt.Errorf("failed to query audit events by actor: %w", err)
	}
	defer rows.Close()

	var events []*audit.AuditEvent
	for rows.Next() {
		event, err := r.scanAuditEvent(rows)
		if err != nil {
			return nil, fmt.Errorf("failed to scan audit event: %w", err)
		}
		events = append(events, event)
	}
	// rows.Next returns false on both exhaustion and failure; Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating audit events: %w", err)
	}
	return events, nil
}
// GetByResource returns one page of the audit events recorded for the
// resource identified by (resourceType, resourceID), newest first.
//
// limit is clamped to [1, 1000], and a non-positive limit means 100. A
// negative offset is treated as 0 because PostgreSQL rejects negative OFFSET
// values. If the row stream fails part-way, an error is returned rather than
// a truncated page.
func (r *AuditRepository) GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error) {
	if limit <= 0 {
		limit = 100
	}
	if limit > 1000 {
		limit = 1000
	}
	if offset < 0 {
		offset = 0
	}
	query := `
		SELECT id, type, severity, status, timestamp,
		       actor_id, actor_type, actor_ip, user_agent, tenant_id,
		       resource_id, resource_type, action, description, details,
		       request_id, session_id, tags, metadata
		FROM audit_events
		WHERE resource_type = $1 AND resource_id = $2
		ORDER BY timestamp DESC
		LIMIT $3 OFFSET $4
	`
	db := r.db.GetDB().(*sql.DB)
	rows, err := db.QueryContext(ctx, query, resourceType, resourceID, limit, offset)
	if err != nil {
		return nil, fmt.Errorf("failed to query audit events by resource: %w", err)
	}
	defer rows.Close()

	var events []*audit.AuditEvent
	for rows.Next() {
		event, err := r.scanAuditEvent(rows)
		if err != nil {
			return nil, fmt.Errorf("failed to scan audit event: %w", err)
		}
		events = append(events, event)
	}
	// rows.Next returns false on both exhaustion and failure; Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating audit events: %w", err)
	}
	return events, nil
}
// scanAuditEvent reads the current row into a new AuditEvent.
//
// row must provide a Scan(dest ...interface{}) error method, as *sql.Row and
// *sql.Rows do, and be positioned on a row whose columns are selected in this
// order: id, type, severity, status, timestamp, actor_id, actor_type,
// actor_ip, user_agent, tenant_id, resource_id, resource_type, action,
// description, details, request_id, session_id, tags, metadata.
//
// NULL text columns become empty strings and a NULL tenant_id leaves TenantID
// nil. NULL or empty details/metadata leave Details/Metadata at their zero
// value. The Scan error is returned unwrapped so callers can compare it
// directly with sql.ErrNoRows.
func (r *AuditRepository) scanAuditEvent(row interface{}) (*audit.AuditEvent, error) {
	scanner, ok := row.(interface {
		Scan(dest ...interface{}) error
	})
	if !ok {
		return nil, fmt.Errorf("invalid row type")
	}

	event := &audit.AuditEvent{}
	var (
		typeStr, severityStr, statusStr        string
		actorID, actorType, actorIP, userAgent sql.NullString
		tenantID                               *uuid.UUID
		resourceID, resourceType               sql.NullString
		requestID, sessionID                   sql.NullString
		// NullString so rows with NULL JSON columns (e.g. written by other
		// producers) scan instead of failing with a NULL-to-string error.
		detailsJSON, metadataJSON sql.NullString
		tags                      pq.StringArray
	)

	if err := scanner.Scan(
		&event.ID,
		&typeStr,
		&severityStr,
		&statusStr,
		&event.Timestamp,
		&actorID,
		&actorType,
		&actorIP,
		&userAgent,
		&tenantID,
		&resourceID,
		&resourceType,
		&event.Action,
		&event.Description,
		&detailsJSON,
		&requestID,
		&sessionID,
		&tags,
		&metadataJSON,
	); err != nil {
		return nil, err
	}

	event.Type = audit.EventType(typeStr)
	event.Severity = audit.EventSeverity(severityStr)
	event.Status = audit.EventStatus(statusStr)

	// NullString.String is "" when the column was NULL, matching the zero value.
	event.ActorID = actorID.String
	event.ActorType = actorType.String
	event.ActorIP = actorIP.String
	event.UserAgent = userAgent.String
	event.TenantID = tenantID
	event.ResourceID = resourceID.String
	event.ResourceType = resourceType.String
	event.RequestID = requestID.String
	event.SessionID = sessionID.String

	event.Tags = []string(tags)

	if detailsJSON.Valid && detailsJSON.String != "" {
		if err := json.Unmarshal([]byte(detailsJSON.String), &event.Details); err != nil {
			return nil, fmt.Errorf("failed to unmarshal details JSON: %w", err)
		}
	}
	if metadataJSON.Valid && metadataJSON.String != "" {
		if err := json.Unmarshal([]byte(metadataJSON.String), &event.Metadata); err != nil {
			return nil, fmt.Errorf("failed to unmarshal metadata JSON: %w", err)
		}
	}
	return event, nil
}