This commit is contained in:
2025-08-25 21:28:14 -04:00
parent efa2ee5b59
commit 19364fcc76
11 changed files with 1208 additions and 13 deletions

View File

@ -146,8 +146,8 @@ The project uses podman-compose for all testing environments and database operat
### End-to-End Testing
```bash
# Start test environment with podman-compose
podman-compose up -d
# Start test environment with podman-compose, guaranteeing that it updates with --build
podman-compose up -d --build
# Wait for services to be ready
sleep 10

View File

@ -63,9 +63,10 @@ func main() {
tokenRepo := postgres.NewStaticTokenRepository(db)
permRepo := postgres.NewPermissionRepository(db)
grantRepo := postgres.NewGrantedPermissionRepository(db)
auditRepo := postgres.NewAuditRepository(db)
// Initialize services
appService := services.NewApplicationService(appRepo, logger)
appService := services.NewApplicationService(appRepo, auditRepo, logger)
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, cfg.GetString("INTERNAL_HMAC_KEY"), cfg, logger)
authService := services.NewAuthenticationService(cfg, logger, permRepo)

View File

@ -55,8 +55,29 @@ func (h *ApplicationHandler) Create(c *gin.Context) {
return
}
// Validate input
validationErrors := h.validator.ValidateApplicationRequest(req.AppID, req.AppLink, req.CallbackURL, []string{})
// Validate input (skip permissions validation for application creation)
var validationErrors []validation.ValidationError
// Validate app ID
if result := h.validator.ValidateAppID(req.AppID); !result.Valid {
validationErrors = append(validationErrors, result.Errors...)
}
// Validate app link URL
if result := h.validator.ValidateURL(req.AppLink, "app_link"); !result.Valid {
validationErrors = append(validationErrors, result.Errors...)
}
// Validate callback URL
if result := h.validator.ValidateURL(req.CallbackURL, "callback_url"); !result.Valid {
validationErrors = append(validationErrors, result.Errors...)
}
// Validate token prefix if provided
if result := h.validator.ValidateTokenPrefix(req.TokenPrefix); !result.Valid {
validationErrors = append(validationErrors, result.Errors...)
}
if len(validationErrors) > 0 {
h.logger.Warn("Application validation failed",
zap.String("user_id", userID),

View File

@ -5,6 +5,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/kms/api-key-service/internal/audit"
"github.com/kms/api-key-service/internal/domain"
)
@ -319,3 +320,33 @@ type MetricsProvider interface {
// RecordDuration records the duration of an operation
RecordDuration(ctx context.Context, name string, duration time.Duration, labels map[string]string)
}
// AuditRepository defines the interface for audit event storage operations
type AuditRepository interface {
// Create stores a new audit event
Create(ctx context.Context, event *audit.AuditEvent) error
// Query retrieves audit events based on filter criteria
Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error)
// GetStats returns aggregated statistics for audit events
GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error)
// DeleteOldEvents removes audit events older than the specified time
DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error)
// GetByID retrieves a specific audit event by its ID
GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error)
// GetByRequestID retrieves all audit events for a specific request
GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error)
// GetBySession retrieves all audit events for a specific session
GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error)
// GetByActor retrieves audit events for a specific actor
GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error)
// GetByResource retrieves audit events for a specific resource
GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error)
}

View File

@ -0,0 +1,742 @@
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/kms/api-key-service/internal/audit"
"github.com/kms/api-key-service/internal/repository"
)
// AuditRepository implements the AuditRepository interface for PostgreSQL
type AuditRepository struct {
db repository.DatabaseProvider
}
// NewAuditRepository creates a new PostgreSQL audit repository
func NewAuditRepository(db repository.DatabaseProvider) repository.AuditRepository {
return &AuditRepository{db: db}
}
// Create stores a new audit event
func (r *AuditRepository) Create(ctx context.Context, event *audit.AuditEvent) error {
query := `
INSERT INTO audit_events (
id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
$11, $12, $13, $14, $15, $16, $17, $18, $19
)
`
db := r.db.GetDB().(*sql.DB)
// Ensure event has an ID and timestamp
if event.ID == uuid.Nil {
event.ID = uuid.New()
}
if event.Timestamp.IsZero() {
event.Timestamp = time.Now().UTC()
}
// Convert details to JSON
var detailsJSON []byte
var err error
if event.Details != nil {
detailsJSON, err = json.Marshal(event.Details)
if err != nil {
return fmt.Errorf("failed to marshal event details: %w", err)
}
} else {
detailsJSON = []byte("{}")
}
// Convert metadata to JSON
var metadataJSON []byte
if event.Metadata != nil {
metadataJSON, err = json.Marshal(event.Metadata)
if err != nil {
return fmt.Errorf("failed to marshal event metadata: %w", err)
}
} else {
metadataJSON = []byte("{}")
}
// Handle nullable fields
var actorID, actorType, actorIP, userAgent *string
var tenantID *uuid.UUID
var resourceID, resourceType *string
var requestID, sessionID *string
if event.ActorID != "" {
actorID = &event.ActorID
}
if event.ActorType != "" {
actorType = &event.ActorType
}
if event.ActorIP != "" {
actorIP = &event.ActorIP
}
if event.UserAgent != "" {
userAgent = &event.UserAgent
}
if event.TenantID != nil {
tenantID = event.TenantID
}
if event.ResourceID != "" {
resourceID = &event.ResourceID
}
if event.ResourceType != "" {
resourceType = &event.ResourceType
}
if event.RequestID != "" {
requestID = &event.RequestID
}
if event.SessionID != "" {
sessionID = &event.SessionID
}
_, err = db.ExecContext(ctx, query,
event.ID,
string(event.Type),
string(event.Severity),
string(event.Status),
event.Timestamp,
actorID,
actorType,
actorIP,
userAgent,
tenantID,
resourceID,
resourceType,
event.Action,
event.Description,
string(detailsJSON),
requestID,
sessionID,
pq.Array(event.Tags),
string(metadataJSON),
)
if err != nil {
return fmt.Errorf("failed to create audit event: %w", err)
}
return nil
}
// Query retrieves audit events based on filter criteria
func (r *AuditRepository) Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) {
// Build dynamic query with filters
var conditions []string
var args []interface{}
argIndex := 1
baseQuery := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
`
// Add filters
if len(filter.EventTypes) > 0 {
conditions = append(conditions, fmt.Sprintf("type = ANY($%d)", argIndex))
typeStrings := make([]string, len(filter.EventTypes))
for i, t := range filter.EventTypes {
typeStrings[i] = string(t)
}
args = append(args, pq.Array(typeStrings))
argIndex++
}
if len(filter.Severities) > 0 {
conditions = append(conditions, fmt.Sprintf("severity = ANY($%d)", argIndex))
severityStrings := make([]string, len(filter.Severities))
for i, s := range filter.Severities {
severityStrings[i] = string(s)
}
args = append(args, pq.Array(severityStrings))
argIndex++
}
if len(filter.Statuses) > 0 {
conditions = append(conditions, fmt.Sprintf("status = ANY($%d)", argIndex))
statusStrings := make([]string, len(filter.Statuses))
for i, s := range filter.Statuses {
statusStrings[i] = string(s)
}
args = append(args, pq.Array(statusStrings))
argIndex++
}
if filter.ActorID != "" {
conditions = append(conditions, fmt.Sprintf("actor_id = $%d", argIndex))
args = append(args, filter.ActorID)
argIndex++
}
if filter.ActorType != "" {
conditions = append(conditions, fmt.Sprintf("actor_type = $%d", argIndex))
args = append(args, filter.ActorType)
argIndex++
}
if filter.TenantID != nil {
conditions = append(conditions, fmt.Sprintf("tenant_id = $%d", argIndex))
args = append(args, *filter.TenantID)
argIndex++
}
if filter.ResourceID != "" {
conditions = append(conditions, fmt.Sprintf("resource_id = $%d", argIndex))
args = append(args, filter.ResourceID)
argIndex++
}
if filter.ResourceType != "" {
conditions = append(conditions, fmt.Sprintf("resource_type = $%d", argIndex))
args = append(args, filter.ResourceType)
argIndex++
}
if filter.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex))
args = append(args, *filter.StartTime)
argIndex++
}
if filter.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex))
args = append(args, *filter.EndTime)
argIndex++
}
if len(filter.Tags) > 0 {
conditions = append(conditions, fmt.Sprintf("tags && $%d", argIndex))
args = append(args, pq.Array(filter.Tags))
argIndex++
}
// Build WHERE clause
if len(conditions) > 0 {
baseQuery += " WHERE " + strings.Join(conditions, " AND ")
}
// Add ORDER BY
orderBy := "timestamp"
if filter.OrderBy != "" {
switch filter.OrderBy {
case "timestamp", "type", "severity", "status":
orderBy = filter.OrderBy
}
}
direction := "DESC"
if !filter.OrderDesc {
direction = "ASC"
}
baseQuery += fmt.Sprintf(" ORDER BY %s %s", orderBy, direction)
// Add pagination
if filter.Limit <= 0 {
filter.Limit = 100
}
if filter.Limit > 1000 {
filter.Limit = 1000
}
baseQuery += fmt.Sprintf(" LIMIT $%d", argIndex)
args = append(args, filter.Limit)
argIndex++
if filter.Offset > 0 {
baseQuery += fmt.Sprintf(" OFFSET $%d", argIndex)
args = append(args, filter.Offset)
}
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, baseQuery, args...)
if err != nil {
return nil, fmt.Errorf("failed to query audit events: %w", err)
}
defer rows.Close()
var events []*audit.AuditEvent
for rows.Next() {
event, err := r.scanAuditEvent(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating audit events: %w", err)
}
return events, nil
}
// GetStats returns aggregated statistics for audit events
func (r *AuditRepository) GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) {
stats := &audit.AuditStats{
ByType: make(map[audit.EventType]int),
BySeverity: make(map[audit.EventSeverity]int),
ByStatus: make(map[audit.EventStatus]int),
}
// Build base conditions
var conditions []string
var args []interface{}
argIndex := 1
if len(filter.EventTypes) > 0 {
conditions = append(conditions, fmt.Sprintf("type = ANY($%d)", argIndex))
typeStrings := make([]string, len(filter.EventTypes))
for i, t := range filter.EventTypes {
typeStrings[i] = string(t)
}
args = append(args, pq.Array(typeStrings))
argIndex++
}
if filter.TenantID != nil {
conditions = append(conditions, fmt.Sprintf("tenant_id = $%d", argIndex))
args = append(args, *filter.TenantID)
argIndex++
}
if filter.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex))
args = append(args, *filter.StartTime)
argIndex++
}
if filter.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex))
args = append(args, *filter.EndTime)
argIndex++
}
whereClause := ""
if len(conditions) > 0 {
whereClause = "WHERE " + strings.Join(conditions, " AND ")
}
db := r.db.GetDB().(*sql.DB)
// Get total count
totalQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause)
err := db.QueryRowContext(ctx, totalQuery, args...).Scan(&stats.TotalEvents)
if err != nil {
return nil, fmt.Errorf("failed to get total event count: %w", err)
}
// Get stats by type
typeQuery := fmt.Sprintf(`
SELECT type, COUNT(*)
FROM audit_events %s
GROUP BY type
ORDER BY COUNT(*) DESC
`, whereClause)
rows, err := db.QueryContext(ctx, typeQuery, args...)
if err != nil {
return nil, fmt.Errorf("failed to get type stats: %w", err)
}
defer rows.Close()
for rows.Next() {
var eventType string
var count int
if err := rows.Scan(&eventType, &count); err != nil {
return nil, fmt.Errorf("failed to scan type stats: %w", err)
}
stats.ByType[audit.EventType(eventType)] = count
}
// Get stats by severity
severityQuery := fmt.Sprintf(`
SELECT severity, COUNT(*)
FROM audit_events %s
GROUP BY severity
ORDER BY COUNT(*) DESC
`, whereClause)
rows, err = db.QueryContext(ctx, severityQuery, args...)
if err != nil {
return nil, fmt.Errorf("failed to get severity stats: %w", err)
}
defer rows.Close()
for rows.Next() {
var severity string
var count int
if err := rows.Scan(&severity, &count); err != nil {
return nil, fmt.Errorf("failed to scan severity stats: %w", err)
}
stats.BySeverity[audit.EventSeverity(severity)] = count
}
// Get stats by status
statusQuery := fmt.Sprintf(`
SELECT status, COUNT(*)
FROM audit_events %s
GROUP BY status
ORDER BY COUNT(*) DESC
`, whereClause)
rows, err = db.QueryContext(ctx, statusQuery, args...)
if err != nil {
return nil, fmt.Errorf("failed to get status stats: %w", err)
}
defer rows.Close()
for rows.Next() {
var status string
var count int
if err := rows.Scan(&status, &count); err != nil {
return nil, fmt.Errorf("failed to scan status stats: %w", err)
}
stats.ByStatus[audit.EventStatus(status)] = count
}
// Get time-based stats if requested
if filter.GroupBy != "" {
stats.ByTime = make(map[string]int)
var timeFormat string
switch filter.GroupBy {
case "hour":
timeFormat = "YYYY-MM-DD HH24:00"
case "day":
timeFormat = "YYYY-MM-DD"
default:
timeFormat = "YYYY-MM-DD"
}
timeQuery := fmt.Sprintf(`
SELECT TO_CHAR(timestamp, '%s') as time_group, COUNT(*)
FROM audit_events %s
GROUP BY time_group
ORDER BY time_group DESC
`, timeFormat, whereClause)
rows, err = db.QueryContext(ctx, timeQuery, args...)
if err != nil {
return nil, fmt.Errorf("failed to get time stats: %w", err)
}
defer rows.Close()
for rows.Next() {
var timeGroup string
var count int
if err := rows.Scan(&timeGroup, &count); err != nil {
return nil, fmt.Errorf("failed to scan time stats: %w", err)
}
stats.ByTime[timeGroup] = count
}
}
return stats, nil
}
// DeleteOldEvents removes audit events older than the specified time
func (r *AuditRepository) DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) {
query := `DELETE FROM audit_events WHERE timestamp < $1`
db := r.db.GetDB().(*sql.DB)
result, err := db.ExecContext(ctx, query, olderThan)
if err != nil {
return 0, fmt.Errorf("failed to delete old audit events: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return int(rowsAffected), nil
}
// GetByID retrieves a specific audit event by its ID
func (r *AuditRepository) GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error) {
query := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
WHERE id = $1
`
db := r.db.GetDB().(*sql.DB)
row := db.QueryRowContext(ctx, query, eventID)
event, err := r.scanAuditEvent(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("audit event with ID '%s' not found", eventID)
}
return nil, fmt.Errorf("failed to get audit event: %w", err)
}
return event, nil
}
// GetByRequestID retrieves all audit events for a specific request
func (r *AuditRepository) GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error) {
query := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
WHERE request_id = $1
ORDER BY timestamp ASC
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, requestID)
if err != nil {
return nil, fmt.Errorf("failed to query audit events by request ID: %w", err)
}
defer rows.Close()
var events []*audit.AuditEvent
for rows.Next() {
event, err := r.scanAuditEvent(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
return events, nil
}
// GetBySession retrieves all audit events for a specific session
func (r *AuditRepository) GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error) {
query := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
WHERE session_id = $1
ORDER BY timestamp ASC
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, sessionID)
if err != nil {
return nil, fmt.Errorf("failed to query audit events by session ID: %w", err)
}
defer rows.Close()
var events []*audit.AuditEvent
for rows.Next() {
event, err := r.scanAuditEvent(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
return events, nil
}
// GetByActor retrieves audit events for a specific actor
func (r *AuditRepository) GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error) {
if limit <= 0 {
limit = 100
}
if limit > 1000 {
limit = 1000
}
query := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
WHERE actor_id = $1
ORDER BY timestamp DESC
LIMIT $2 OFFSET $3
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, actorID, limit, offset)
if err != nil {
return nil, fmt.Errorf("failed to query audit events by actor: %w", err)
}
defer rows.Close()
var events []*audit.AuditEvent
for rows.Next() {
event, err := r.scanAuditEvent(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
return events, nil
}
// GetByResource retrieves audit events for a specific resource
func (r *AuditRepository) GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error) {
if limit <= 0 {
limit = 100
}
if limit > 1000 {
limit = 1000
}
query := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
WHERE resource_type = $1 AND resource_id = $2
ORDER BY timestamp DESC
LIMIT $3 OFFSET $4
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, resourceType, resourceID, limit, offset)
if err != nil {
return nil, fmt.Errorf("failed to query audit events by resource: %w", err)
}
defer rows.Close()
var events []*audit.AuditEvent
for rows.Next() {
event, err := r.scanAuditEvent(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
return events, nil
}
// scanAuditEvent scans a database row into an AuditEvent struct
func (r *AuditRepository) scanAuditEvent(row interface{}) (*audit.AuditEvent, error) {
event := &audit.AuditEvent{}
var typeStr, severityStr, statusStr string
var actorID, actorType, actorIP, userAgent sql.NullString
var tenantID *uuid.UUID
var resourceID, resourceType sql.NullString
var detailsJSON, metadataJSON string
var requestID, sessionID sql.NullString
var tags pq.StringArray
var scanner interface {
Scan(dest ...interface{}) error
}
switch v := row.(type) {
case *sql.Row:
scanner = v
case *sql.Rows:
scanner = v
default:
return nil, fmt.Errorf("invalid row type")
}
err := scanner.Scan(
&event.ID,
&typeStr,
&severityStr,
&statusStr,
&event.Timestamp,
&actorID,
&actorType,
&actorIP,
&userAgent,
&tenantID,
&resourceID,
&resourceType,
&event.Action,
&event.Description,
&detailsJSON,
&requestID,
&sessionID,
&tags,
&metadataJSON,
)
if err != nil {
return nil, err
}
// Convert string enums to types
event.Type = audit.EventType(typeStr)
event.Severity = audit.EventSeverity(severityStr)
event.Status = audit.EventStatus(statusStr)
// Handle nullable fields
if actorID.Valid {
event.ActorID = actorID.String
}
if actorType.Valid {
event.ActorType = actorType.String
}
if actorIP.Valid {
event.ActorIP = actorIP.String
}
if userAgent.Valid {
event.UserAgent = userAgent.String
}
if tenantID != nil {
event.TenantID = tenantID
}
if resourceID.Valid {
event.ResourceID = resourceID.String
}
if resourceType.Valid {
event.ResourceType = resourceType.String
}
if requestID.Valid {
event.RequestID = requestID.String
}
if sessionID.Valid {
event.SessionID = sessionID.String
}
// Convert tags
event.Tags = []string(tags)
// Parse JSON fields
if detailsJSON != "" {
if err := json.Unmarshal([]byte(detailsJSON), &event.Details); err != nil {
return nil, fmt.Errorf("failed to unmarshal details JSON: %w", err)
}
}
if metadataJSON != "" {
if err := json.Unmarshal([]byte(metadataJSON), &event.Metadata); err != nil {
return nil, fmt.Errorf("failed to unmarshal metadata JSON: %w", err)
}
}
return event, nil
}

View File

@ -5,30 +5,61 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"time"
"github.com/go-playground/validator/v10"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/audit"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/repository"
)
// applicationService implements the ApplicationService interface
type applicationService struct {
appRepo repository.ApplicationRepository
logger *zap.Logger
validator *validator.Validate
appRepo repository.ApplicationRepository
auditRepo repository.AuditRepository
auditLogger audit.AuditLogger
logger *zap.Logger
validator *validator.Validate
}
// NewApplicationService creates a new application service
func NewApplicationService(appRepo repository.ApplicationRepository, logger *zap.Logger) ApplicationService {
func NewApplicationService(appRepo repository.ApplicationRepository, auditRepo repository.AuditRepository, logger *zap.Logger) ApplicationService {
// Create audit logger with audit package's repository interface
auditRepoImpl := &auditRepositoryAdapter{repo: auditRepo}
auditLogger := audit.NewAuditLogger(nil, logger, auditRepoImpl) // config can be nil for now
return &applicationService{
appRepo: appRepo,
logger: logger,
validator: validator.New(),
appRepo: appRepo,
auditRepo: auditRepo,
auditLogger: auditLogger,
logger: logger,
validator: validator.New(),
}
}
// auditRepositoryAdapter adapts repository.AuditRepository to audit.AuditRepository
type auditRepositoryAdapter struct {
repo repository.AuditRepository
}
func (a *auditRepositoryAdapter) Create(ctx context.Context, event *audit.AuditEvent) error {
return a.repo.Create(ctx, event)
}
func (a *auditRepositoryAdapter) Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) {
return a.repo.Query(ctx, filter)
}
func (a *auditRepositoryAdapter) GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) {
return a.repo.GetStats(ctx, filter)
}
func (a *auditRepositoryAdapter) DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) {
return a.repo.DeleteOldEvents(ctx, olderThan)
}
// Create creates a new application
func (s *applicationService) Create(ctx context.Context, req *domain.CreateApplicationRequest, userID string) (*domain.Application, error) {
s.logger.Info("Creating application", zap.String("app_id", req.AppID), zap.String("user_id", userID))
@ -75,9 +106,43 @@ func (s *applicationService) Create(ctx context.Context, req *domain.CreateAppli
if err := s.appRepo.Create(ctx, app); err != nil {
s.logger.Error("Failed to create application", zap.Error(err), zap.String("app_id", req.AppID))
// Log audit event for failed creation
s.auditLogger.LogEvent(ctx, audit.NewAuditEventBuilder(audit.EventTypeAppCreated).
WithSeverity(audit.SeverityError).
WithStatus(audit.StatusFailure).
WithActor(userID, "user", "").
WithResource(req.AppID, "application").
WithAction("create").
WithDescription(fmt.Sprintf("Failed to create application %s", req.AppID)).
WithDetails(map[string]interface{}{
"error": err.Error(),
"app_id": req.AppID,
"user_id": userID,
}).
Build())
return nil, fmt.Errorf("failed to create application: %w", err)
}
// Log successful creation
s.auditLogger.LogEvent(ctx, audit.NewAuditEventBuilder(audit.EventTypeAppCreated).
WithSeverity(audit.SeverityInfo).
WithStatus(audit.StatusSuccess).
WithActor(userID, "user", "").
WithResource(app.AppID, "application").
WithAction("create").
WithDescription(fmt.Sprintf("Created application %s", app.AppID)).
WithDetails(map[string]interface{}{
"app_id": app.AppID,
"app_link": app.AppLink,
"type": app.Type,
"user_id": userID,
"owner_name": app.Owner.Name,
"owner_type": app.Owner.Type,
}).
Build())
s.logger.Info("Application created successfully", zap.String("app_id", app.AppID))
return app, nil
}

View File

@ -0,0 +1,27 @@
-- Migration: 004_add_audit_events (down)
-- Remove audit_events table and related objects
-- Drop the cleanup function
DROP FUNCTION IF EXISTS cleanup_old_audit_events(INTEGER);
-- Drop indexes first (they will be dropped automatically with the table, but explicit for clarity)
DROP INDEX IF EXISTS idx_audit_events_timestamp;
DROP INDEX IF EXISTS idx_audit_events_type;
DROP INDEX IF EXISTS idx_audit_events_severity;
DROP INDEX IF EXISTS idx_audit_events_status;
DROP INDEX IF EXISTS idx_audit_events_actor_id;
DROP INDEX IF EXISTS idx_audit_events_actor_type;
DROP INDEX IF EXISTS idx_audit_events_tenant_id;
DROP INDEX IF EXISTS idx_audit_events_resource;
DROP INDEX IF EXISTS idx_audit_events_request_id;
DROP INDEX IF EXISTS idx_audit_events_session_id;
DROP INDEX IF EXISTS idx_audit_events_details;
DROP INDEX IF EXISTS idx_audit_events_metadata;
DROP INDEX IF EXISTS idx_audit_events_tags;
DROP INDEX IF EXISTS idx_audit_events_actor_timestamp;
DROP INDEX IF EXISTS idx_audit_events_type_timestamp;
DROP INDEX IF EXISTS idx_audit_events_tenant_timestamp;
DROP INDEX IF EXISTS idx_audit_events_resource_timestamp;
-- Drop the audit_events table
DROP TABLE IF EXISTS audit_events;

View File

@ -0,0 +1,102 @@
-- Migration: 004_add_audit_events
-- Add audit_events table for comprehensive audit logging
-- Create audit_events table
CREATE TABLE IF NOT EXISTS audit_events (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
type VARCHAR(50) NOT NULL,
severity VARCHAR(20) NOT NULL CHECK (severity IN ('info', 'warning', 'error', 'critical')),
status VARCHAR(20) NOT NULL CHECK (status IN ('success', 'failure', 'pending')),
timestamp TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
-- Actor information
actor_id VARCHAR(255),
actor_type VARCHAR(50) CHECK (actor_type IN ('user', 'system', 'service')),
actor_ip INET,
user_agent TEXT,
-- Tenant information (for multi-tenancy support)
tenant_id UUID,
-- Resource information
resource_id VARCHAR(255),
resource_type VARCHAR(100),
-- Event details
action VARCHAR(100) NOT NULL,
description TEXT NOT NULL,
details JSONB DEFAULT '{}',
-- Request context
request_id VARCHAR(100),
session_id VARCHAR(255),
-- Additional metadata
tags TEXT[],
metadata JSONB DEFAULT '{}'
);
-- Create indexes for efficient querying
CREATE INDEX IF NOT EXISTS idx_audit_events_timestamp ON audit_events(timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_audit_events_type ON audit_events(type);
CREATE INDEX IF NOT EXISTS idx_audit_events_severity ON audit_events(severity);
CREATE INDEX IF NOT EXISTS idx_audit_events_status ON audit_events(status);
CREATE INDEX IF NOT EXISTS idx_audit_events_actor_id ON audit_events(actor_id);
CREATE INDEX IF NOT EXISTS idx_audit_events_actor_type ON audit_events(actor_type);
CREATE INDEX IF NOT EXISTS idx_audit_events_tenant_id ON audit_events(tenant_id) WHERE tenant_id IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_audit_events_resource ON audit_events(resource_type, resource_id) WHERE resource_id IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_audit_events_request_id ON audit_events(request_id) WHERE request_id IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_audit_events_session_id ON audit_events(session_id) WHERE session_id IS NOT NULL;
-- GIN indexes for JSONB columns
CREATE INDEX IF NOT EXISTS idx_audit_events_details ON audit_events USING GIN (details);
CREATE INDEX IF NOT EXISTS idx_audit_events_metadata ON audit_events USING GIN (metadata);
-- GIN index for tags array
CREATE INDEX IF NOT EXISTS idx_audit_events_tags ON audit_events USING GIN (tags);
-- Composite indexes for common query patterns
CREATE INDEX IF NOT EXISTS idx_audit_events_actor_timestamp ON audit_events(actor_id, timestamp DESC) WHERE actor_id IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_audit_events_type_timestamp ON audit_events(type, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_audit_events_tenant_timestamp ON audit_events(tenant_id, timestamp DESC) WHERE tenant_id IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_audit_events_resource_timestamp ON audit_events(resource_type, resource_id, timestamp DESC) WHERE resource_id IS NOT NULL;
-- Add comments for documentation
COMMENT ON TABLE audit_events IS 'Comprehensive audit log for all system events and user actions';
COMMENT ON COLUMN audit_events.id IS 'Unique event identifier';
COMMENT ON COLUMN audit_events.type IS 'Event type (e.g., auth.login, app.created)';
COMMENT ON COLUMN audit_events.severity IS 'Event severity level: info, warning, error, critical';
COMMENT ON COLUMN audit_events.status IS 'Event status: success, failure, pending';
COMMENT ON COLUMN audit_events.timestamp IS 'When the event occurred';
COMMENT ON COLUMN audit_events.actor_id IS 'ID of the user/system that triggered the event';
COMMENT ON COLUMN audit_events.actor_type IS 'Type of actor: user, system, service';
COMMENT ON COLUMN audit_events.actor_ip IS 'IP address of the actor';
COMMENT ON COLUMN audit_events.user_agent IS 'User agent string (for HTTP requests)';
COMMENT ON COLUMN audit_events.tenant_id IS 'Tenant ID for multi-tenant environments';
COMMENT ON COLUMN audit_events.resource_id IS 'ID of the resource being acted upon';
COMMENT ON COLUMN audit_events.resource_type IS 'Type of resource (e.g., application, token)';
COMMENT ON COLUMN audit_events.action IS 'Action performed';
COMMENT ON COLUMN audit_events.description IS 'Human-readable description of the event';
COMMENT ON COLUMN audit_events.details IS 'Additional structured details as JSON';
COMMENT ON COLUMN audit_events.request_id IS 'Request ID for tracing';
COMMENT ON COLUMN audit_events.session_id IS 'Session ID for user session tracking';
COMMENT ON COLUMN audit_events.tags IS 'Array of tags for categorization';
COMMENT ON COLUMN audit_events.metadata IS 'Additional metadata as JSON';
-- Create a function to automatically clean up old audit events (optional)
CREATE OR REPLACE FUNCTION cleanup_old_audit_events(retention_days INTEGER DEFAULT 365)
RETURNS INTEGER AS $$
DECLARE
deleted_count INTEGER;
BEGIN
-- Delete audit events older than retention period
DELETE FROM audit_events
WHERE timestamp < NOW() - (retention_days || ' days')::INTERVAL;
GET DIAGNOSTICS deleted_count = ROW_COUNT;
RETURN deleted_count;
END;
$$ LANGUAGE plpgsql;
COMMENT ON FUNCTION cleanup_old_audit_events(INTEGER) IS 'Function to clean up audit events older than specified days (default: 365 days)';

View File

@ -189,6 +189,7 @@ test_application_endpoints() {
"app_link": "https://example.com/test-app",
"type": ["static"],
"callback_url": "https://example.com/callback",
"token_prefix": "TEST",
"token_renewal_duration": 604800000000000,
"max_token_duration": 2592000000000000,
"owner": {

View File

@ -96,8 +96,11 @@ func (suite *IntegrationTestSuite) setupServer() {
// Create a no-op logger for tests
logger := zap.NewNop()
// Initialize repositories
auditRepo := NewMockAuditRepository()
// Initialize services
appService := services.NewApplicationService(appRepo, logger)
appService := services.NewApplicationService(appRepo, auditRepo, logger)
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, suite.cfg.GetString("INTERNAL_HMAC_KEY"), suite.cfg, logger)
authService := services.NewAuthenticationService(suite.cfg, logger, permRepo)

View File

@ -7,6 +7,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/kms/api-key-service/internal/audit"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/repository"
)
@ -612,3 +613,204 @@ func (m *MockGrantedPermissionRepository) HasAnyPermission(ctx context.Context,
return result, nil
}
// MockAuditRepository implements AuditRepository for testing
type MockAuditRepository struct {
mu sync.RWMutex
events []*audit.AuditEvent
}
func NewMockAuditRepository() repository.AuditRepository {
return &MockAuditRepository{
events: make([]*audit.AuditEvent, 0),
}
}
func (m *MockAuditRepository) Create(ctx context.Context, event *audit.AuditEvent) error {
m.mu.Lock()
defer m.mu.Unlock()
if event.ID == uuid.Nil {
event.ID = uuid.New()
}
if event.Timestamp.IsZero() {
event.Timestamp = time.Now().UTC()
}
m.events = append(m.events, event)
return nil
}
func (m *MockAuditRepository) Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var result []*audit.AuditEvent
for _, event := range m.events {
// Simple filtering logic for testing
if len(filter.EventTypes) > 0 {
found := false
for _, t := range filter.EventTypes {
if event.Type == t {
found = true
break
}
}
if !found {
continue
}
}
if filter.ActorID != "" && event.ActorID != filter.ActorID {
continue
}
if filter.ResourceID != "" && event.ResourceID != filter.ResourceID {
continue
}
if filter.ResourceType != "" && event.ResourceType != filter.ResourceType {
continue
}
result = append(result, event)
}
// Apply pagination
if filter.Offset >= len(result) {
return []*audit.AuditEvent{}, nil
}
end := filter.Offset + filter.Limit
if end > len(result) {
end = len(result)
}
return result[filter.Offset:end], nil
}
func (m *MockAuditRepository) GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) {
m.mu.RLock()
defer m.mu.RUnlock()
stats := &audit.AuditStats{
TotalEvents: len(m.events),
ByType: make(map[audit.EventType]int),
BySeverity: make(map[audit.EventSeverity]int),
ByStatus: make(map[audit.EventStatus]int),
}
for _, event := range m.events {
stats.ByType[event.Type]++
stats.BySeverity[event.Severity]++
stats.ByStatus[event.Status]++
}
return stats, nil
}
func (m *MockAuditRepository) DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) {
m.mu.Lock()
defer m.mu.Unlock()
var kept []*audit.AuditEvent
deleted := 0
for _, event := range m.events {
if event.Timestamp.Before(olderThan) {
deleted++
} else {
kept = append(kept, event)
}
}
m.events = kept
return deleted, nil
}
func (m *MockAuditRepository) GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error) {
m.mu.RLock()
defer m.mu.RUnlock()
for _, event := range m.events {
if event.ID == eventID {
return event, nil
}
}
return nil, fmt.Errorf("audit event with ID '%s' not found", eventID)
}
func (m *MockAuditRepository) GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var result []*audit.AuditEvent
for _, event := range m.events {
if event.RequestID == requestID {
result = append(result, event)
}
}
return result, nil
}
func (m *MockAuditRepository) GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var result []*audit.AuditEvent
for _, event := range m.events {
if event.SessionID == sessionID {
result = append(result, event)
}
}
return result, nil
}
func (m *MockAuditRepository) GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var matching []*audit.AuditEvent
for _, event := range m.events {
if event.ActorID == actorID {
matching = append(matching, event)
}
}
if offset >= len(matching) {
return []*audit.AuditEvent{}, nil
}
end := offset + limit
if end > len(matching) {
end = len(matching)
}
return matching[offset:end], nil
}
func (m *MockAuditRepository) GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var matching []*audit.AuditEvent
for _, event := range m.events {
if event.ResourceType == resourceType && event.ResourceID == resourceID {
matching = append(matching, event)
}
}
if offset >= len(matching) {
return []*audit.AuditEvent{}, nil
}
end := offset + limit
if end > len(matching) {
end = len(matching)
}
return matching[offset:end], nil
}