This commit is contained in:
2025-08-26 19:16:41 -04:00
parent 7ca61eb712
commit 6725529b01
113 changed files with 0 additions and 337 deletions

View File

@ -0,0 +1,387 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/lib/pq"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/repository"
)
// ApplicationRepository implements the ApplicationRepository interface for PostgreSQL
type ApplicationRepository struct {
db repository.DatabaseProvider
}
// NewApplicationRepository creates a new PostgreSQL application repository
func NewApplicationRepository(db repository.DatabaseProvider) repository.ApplicationRepository {
return &ApplicationRepository{db: db}
}
// Create creates a new application
func (r *ApplicationRepository) Create(ctx context.Context, app *domain.Application) error {
query := `
INSERT INTO applications (
app_id, app_link, type, callback_url, hmac_key, token_prefix,
token_renewal_duration, max_token_duration,
owner_type, owner_name, owner_owner,
created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
`
db := r.db.GetDB().(*sql.DB)
now := time.Now()
// Convert application types to string array
typeStrings := make([]string, len(app.Type))
for i, t := range app.Type {
typeStrings[i] = string(t)
}
_, err := db.ExecContext(ctx, query,
app.AppID,
app.AppLink,
pq.Array(typeStrings),
app.CallbackURL,
app.HMACKey,
app.TokenPrefix,
app.TokenRenewalDuration.Duration.Nanoseconds(),
app.MaxTokenDuration.Duration.Nanoseconds(),
string(app.Owner.Type),
app.Owner.Name,
app.Owner.Owner,
now,
now,
)
if err != nil {
if isUniqueViolation(err) {
return fmt.Errorf("application with ID '%s' already exists", app.AppID)
}
return fmt.Errorf("failed to create application: %w", err)
}
app.CreatedAt = now
app.UpdatedAt = now
return nil
}
// GetByID retrieves an application by its ID
func (r *ApplicationRepository) GetByID(ctx context.Context, appID string) (*domain.Application, error) {
query := `
SELECT app_id, app_link, type, callback_url, hmac_key, token_prefix,
token_renewal_duration, max_token_duration,
owner_type, owner_name, owner_owner,
created_at, updated_at
FROM applications
WHERE app_id = $1
`
db := r.db.GetDB().(*sql.DB)
row := db.QueryRowContext(ctx, query, appID)
app := &domain.Application{}
var typeStrings pq.StringArray
var tokenRenewalNanos, maxTokenNanos int64
var ownerType string
err := row.Scan(
&app.AppID,
&app.AppLink,
&typeStrings,
&app.CallbackURL,
&app.HMACKey,
&app.TokenPrefix,
&tokenRenewalNanos,
&maxTokenNanos,
&ownerType,
&app.Owner.Name,
&app.Owner.Owner,
&app.CreatedAt,
&app.UpdatedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("application with ID '%s' not found", appID)
}
return nil, fmt.Errorf("failed to get application: %w", err)
}
// Convert string array to application types
app.Type = make([]domain.ApplicationType, len(typeStrings))
for i, t := range typeStrings {
app.Type[i] = domain.ApplicationType(t)
}
// Convert nanoseconds to duration
app.TokenRenewalDuration = domain.Duration{Duration: time.Duration(tokenRenewalNanos)}
app.MaxTokenDuration = domain.Duration{Duration: time.Duration(maxTokenNanos)}
// Convert owner type
app.Owner.Type = domain.OwnerType(ownerType)
return app, nil
}
// List retrieves applications with pagination
func (r *ApplicationRepository) List(ctx context.Context, limit, offset int) ([]*domain.Application, error) {
query := `
SELECT app_id, app_link, type, callback_url, hmac_key, token_prefix,
token_renewal_duration, max_token_duration,
owner_type, owner_name, owner_owner,
created_at, updated_at
FROM applications
ORDER BY created_at DESC
LIMIT $1 OFFSET $2
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, limit, offset)
if err != nil {
return nil, fmt.Errorf("failed to list applications: %w", err)
}
defer rows.Close()
var applications []*domain.Application
for rows.Next() {
app := &domain.Application{}
var typeStrings pq.StringArray
var tokenRenewalNanos, maxTokenNanos int64
var ownerType string
err := rows.Scan(
&app.AppID,
&app.AppLink,
&typeStrings,
&app.CallbackURL,
&app.HMACKey,
&app.TokenPrefix,
&tokenRenewalNanos,
&maxTokenNanos,
&ownerType,
&app.Owner.Name,
&app.Owner.Owner,
&app.CreatedAt,
&app.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan application: %w", err)
}
// Convert string array to application types
app.Type = make([]domain.ApplicationType, len(typeStrings))
for i, t := range typeStrings {
app.Type[i] = domain.ApplicationType(t)
}
// Convert nanoseconds to duration
app.TokenRenewalDuration = domain.Duration{Duration: time.Duration(tokenRenewalNanos)}
app.MaxTokenDuration = domain.Duration{Duration: time.Duration(maxTokenNanos)}
// Convert owner type
app.Owner.Type = domain.OwnerType(ownerType)
applications = append(applications, app)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("failed to iterate applications: %w", err)
}
return applications, nil
}
// Update updates an existing application
func (r *ApplicationRepository) Update(ctx context.Context, appID string, updates *domain.UpdateApplicationRequest) (*domain.Application, error) {
// Build secure dynamic update query using a whitelist approach
var setParts []string
var args []interface{}
argIndex := 1
// Whitelist of allowed fields to prevent SQL injection
allowedFields := map[string]string{
"app_link": "app_link",
"type": "type",
"callback_url": "callback_url",
"hmac_key": "hmac_key",
"token_prefix": "token_prefix",
"token_renewal_duration": "token_renewal_duration",
"max_token_duration": "max_token_duration",
"owner_type": "owner_type",
"owner_name": "owner_name",
"owner_owner": "owner_owner",
}
if updates.AppLink != nil {
if field, ok := allowedFields["app_link"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, *updates.AppLink)
argIndex++
}
}
if updates.Type != nil {
if field, ok := allowedFields["type"]; ok {
typeStrings := make([]string, len(*updates.Type))
for i, t := range *updates.Type {
typeStrings[i] = string(t)
}
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, pq.Array(typeStrings))
argIndex++
}
}
if updates.CallbackURL != nil {
if field, ok := allowedFields["callback_url"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, *updates.CallbackURL)
argIndex++
}
}
if updates.HMACKey != nil {
if field, ok := allowedFields["hmac_key"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, *updates.HMACKey)
argIndex++
}
}
if updates.TokenPrefix != nil {
if field, ok := allowedFields["token_prefix"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, *updates.TokenPrefix)
argIndex++
}
}
if updates.TokenRenewalDuration != nil {
if field, ok := allowedFields["token_renewal_duration"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, updates.TokenRenewalDuration.Duration.Nanoseconds())
argIndex++
}
}
if updates.MaxTokenDuration != nil {
if field, ok := allowedFields["max_token_duration"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, updates.MaxTokenDuration.Duration.Nanoseconds())
argIndex++
}
}
if updates.Owner != nil {
if field, ok := allowedFields["owner_type"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, string(updates.Owner.Type))
argIndex++
}
if field, ok := allowedFields["owner_name"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, updates.Owner.Name)
argIndex++
}
if field, ok := allowedFields["owner_owner"]; ok {
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
args = append(args, updates.Owner.Owner)
argIndex++
}
}
if len(setParts) == 0 {
return r.GetByID(ctx, appID) // No updates, return current state
}
// Always update the updated_at field - using literal field name for security
setParts = append(setParts, fmt.Sprintf("updated_at = $%d", argIndex))
args = append(args, time.Now())
argIndex++
// Add WHERE clause parameter
args = append(args, appID)
// Build the final query with properly parameterized placeholders
query := fmt.Sprintf(`
UPDATE applications
SET %s
WHERE app_id = $%d
`, strings.Join(setParts, ", "), argIndex)
db := r.db.GetDB().(*sql.DB)
result, err := db.ExecContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to update application: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return nil, fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return nil, fmt.Errorf("application with ID '%s' not found", appID)
}
// Return updated application
return r.GetByID(ctx, appID)
}
// Delete deletes an application
func (r *ApplicationRepository) Delete(ctx context.Context, appID string) error {
query := `DELETE FROM applications WHERE app_id = $1`
db := r.db.GetDB().(*sql.DB)
result, err := db.ExecContext(ctx, query, appID)
if err != nil {
return fmt.Errorf("failed to delete application: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("application with ID '%s' not found", appID)
}
return nil
}
// Exists checks if an application exists
func (r *ApplicationRepository) Exists(ctx context.Context, appID string) (bool, error) {
query := `SELECT 1 FROM applications WHERE app_id = $1`
db := r.db.GetDB().(*sql.DB)
var exists int
err := db.QueryRowContext(ctx, query, appID).Scan(&exists)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("failed to check application existence: %w", err)
}
return true, nil
}
// isUniqueViolation checks if the error is a unique constraint violation
func isUniqueViolation(err error) bool {
if pqErr, ok := err.(*pq.Error); ok {
return pqErr.Code == "23505" // unique_violation
}
return false
}

View File

@ -0,0 +1,742 @@
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/kms/api-key-service/internal/audit"
"github.com/kms/api-key-service/internal/repository"
)
// AuditRepository implements the AuditRepository interface for PostgreSQL
type AuditRepository struct {
db repository.DatabaseProvider
}
// NewAuditRepository creates a new PostgreSQL audit repository
func NewAuditRepository(db repository.DatabaseProvider) repository.AuditRepository {
return &AuditRepository{db: db}
}
// Create stores a new audit event
func (r *AuditRepository) Create(ctx context.Context, event *audit.AuditEvent) error {
query := `
INSERT INTO audit_events (
id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
$11, $12, $13, $14, $15, $16, $17, $18, $19
)
`
db := r.db.GetDB().(*sql.DB)
// Ensure event has an ID and timestamp
if event.ID == uuid.Nil {
event.ID = uuid.New()
}
if event.Timestamp.IsZero() {
event.Timestamp = time.Now().UTC()
}
// Convert details to JSON
var detailsJSON []byte
var err error
if event.Details != nil {
detailsJSON, err = json.Marshal(event.Details)
if err != nil {
return fmt.Errorf("failed to marshal event details: %w", err)
}
} else {
detailsJSON = []byte("{}")
}
// Convert metadata to JSON
var metadataJSON []byte
if event.Metadata != nil {
metadataJSON, err = json.Marshal(event.Metadata)
if err != nil {
return fmt.Errorf("failed to marshal event metadata: %w", err)
}
} else {
metadataJSON = []byte("{}")
}
// Handle nullable fields
var actorID, actorType, actorIP, userAgent *string
var tenantID *uuid.UUID
var resourceID, resourceType *string
var requestID, sessionID *string
if event.ActorID != "" {
actorID = &event.ActorID
}
if event.ActorType != "" {
actorType = &event.ActorType
}
if event.ActorIP != "" {
actorIP = &event.ActorIP
}
if event.UserAgent != "" {
userAgent = &event.UserAgent
}
if event.TenantID != nil {
tenantID = event.TenantID
}
if event.ResourceID != "" {
resourceID = &event.ResourceID
}
if event.ResourceType != "" {
resourceType = &event.ResourceType
}
if event.RequestID != "" {
requestID = &event.RequestID
}
if event.SessionID != "" {
sessionID = &event.SessionID
}
_, err = db.ExecContext(ctx, query,
event.ID,
string(event.Type),
string(event.Severity),
string(event.Status),
event.Timestamp,
actorID,
actorType,
actorIP,
userAgent,
tenantID,
resourceID,
resourceType,
event.Action,
event.Description,
string(detailsJSON),
requestID,
sessionID,
pq.Array(event.Tags),
string(metadataJSON),
)
if err != nil {
return fmt.Errorf("failed to create audit event: %w", err)
}
return nil
}
// Query retrieves audit events based on filter criteria
func (r *AuditRepository) Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) {
// Build dynamic query with filters
var conditions []string
var args []interface{}
argIndex := 1
baseQuery := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
`
// Add filters
if len(filter.EventTypes) > 0 {
conditions = append(conditions, fmt.Sprintf("type = ANY($%d)", argIndex))
typeStrings := make([]string, len(filter.EventTypes))
for i, t := range filter.EventTypes {
typeStrings[i] = string(t)
}
args = append(args, pq.Array(typeStrings))
argIndex++
}
if len(filter.Severities) > 0 {
conditions = append(conditions, fmt.Sprintf("severity = ANY($%d)", argIndex))
severityStrings := make([]string, len(filter.Severities))
for i, s := range filter.Severities {
severityStrings[i] = string(s)
}
args = append(args, pq.Array(severityStrings))
argIndex++
}
if len(filter.Statuses) > 0 {
conditions = append(conditions, fmt.Sprintf("status = ANY($%d)", argIndex))
statusStrings := make([]string, len(filter.Statuses))
for i, s := range filter.Statuses {
statusStrings[i] = string(s)
}
args = append(args, pq.Array(statusStrings))
argIndex++
}
if filter.ActorID != "" {
conditions = append(conditions, fmt.Sprintf("actor_id = $%d", argIndex))
args = append(args, filter.ActorID)
argIndex++
}
if filter.ActorType != "" {
conditions = append(conditions, fmt.Sprintf("actor_type = $%d", argIndex))
args = append(args, filter.ActorType)
argIndex++
}
if filter.TenantID != nil {
conditions = append(conditions, fmt.Sprintf("tenant_id = $%d", argIndex))
args = append(args, *filter.TenantID)
argIndex++
}
if filter.ResourceID != "" {
conditions = append(conditions, fmt.Sprintf("resource_id = $%d", argIndex))
args = append(args, filter.ResourceID)
argIndex++
}
if filter.ResourceType != "" {
conditions = append(conditions, fmt.Sprintf("resource_type = $%d", argIndex))
args = append(args, filter.ResourceType)
argIndex++
}
if filter.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex))
args = append(args, *filter.StartTime)
argIndex++
}
if filter.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex))
args = append(args, *filter.EndTime)
argIndex++
}
if len(filter.Tags) > 0 {
conditions = append(conditions, fmt.Sprintf("tags && $%d", argIndex))
args = append(args, pq.Array(filter.Tags))
argIndex++
}
// Build WHERE clause
if len(conditions) > 0 {
baseQuery += " WHERE " + strings.Join(conditions, " AND ")
}
// Add ORDER BY
orderBy := "timestamp"
if filter.OrderBy != "" {
switch filter.OrderBy {
case "timestamp", "type", "severity", "status":
orderBy = filter.OrderBy
}
}
direction := "DESC"
if !filter.OrderDesc {
direction = "ASC"
}
baseQuery += fmt.Sprintf(" ORDER BY %s %s", orderBy, direction)
// Add pagination
if filter.Limit <= 0 {
filter.Limit = 100
}
if filter.Limit > 1000 {
filter.Limit = 1000
}
baseQuery += fmt.Sprintf(" LIMIT $%d", argIndex)
args = append(args, filter.Limit)
argIndex++
if filter.Offset > 0 {
baseQuery += fmt.Sprintf(" OFFSET $%d", argIndex)
args = append(args, filter.Offset)
}
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, baseQuery, args...)
if err != nil {
return nil, fmt.Errorf("failed to query audit events: %w", err)
}
defer rows.Close()
var events []*audit.AuditEvent
for rows.Next() {
event, err := r.scanAuditEvent(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating audit events: %w", err)
}
return events, nil
}
// GetStats returns aggregated statistics for audit events
func (r *AuditRepository) GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) {
stats := &audit.AuditStats{
ByType: make(map[audit.EventType]int),
BySeverity: make(map[audit.EventSeverity]int),
ByStatus: make(map[audit.EventStatus]int),
}
// Build base conditions
var conditions []string
var args []interface{}
argIndex := 1
if len(filter.EventTypes) > 0 {
conditions = append(conditions, fmt.Sprintf("type = ANY($%d)", argIndex))
typeStrings := make([]string, len(filter.EventTypes))
for i, t := range filter.EventTypes {
typeStrings[i] = string(t)
}
args = append(args, pq.Array(typeStrings))
argIndex++
}
if filter.TenantID != nil {
conditions = append(conditions, fmt.Sprintf("tenant_id = $%d", argIndex))
args = append(args, *filter.TenantID)
argIndex++
}
if filter.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex))
args = append(args, *filter.StartTime)
argIndex++
}
if filter.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex))
args = append(args, *filter.EndTime)
argIndex++
}
whereClause := ""
if len(conditions) > 0 {
whereClause = "WHERE " + strings.Join(conditions, " AND ")
}
db := r.db.GetDB().(*sql.DB)
// Get total count
totalQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause)
err := db.QueryRowContext(ctx, totalQuery, args...).Scan(&stats.TotalEvents)
if err != nil {
return nil, fmt.Errorf("failed to get total event count: %w", err)
}
// Get stats by type
typeQuery := fmt.Sprintf(`
SELECT type, COUNT(*)
FROM audit_events %s
GROUP BY type
ORDER BY COUNT(*) DESC
`, whereClause)
rows, err := db.QueryContext(ctx, typeQuery, args...)
if err != nil {
return nil, fmt.Errorf("failed to get type stats: %w", err)
}
defer rows.Close()
for rows.Next() {
var eventType string
var count int
if err := rows.Scan(&eventType, &count); err != nil {
return nil, fmt.Errorf("failed to scan type stats: %w", err)
}
stats.ByType[audit.EventType(eventType)] = count
}
// Get stats by severity
severityQuery := fmt.Sprintf(`
SELECT severity, COUNT(*)
FROM audit_events %s
GROUP BY severity
ORDER BY COUNT(*) DESC
`, whereClause)
rows, err = db.QueryContext(ctx, severityQuery, args...)
if err != nil {
return nil, fmt.Errorf("failed to get severity stats: %w", err)
}
defer rows.Close()
for rows.Next() {
var severity string
var count int
if err := rows.Scan(&severity, &count); err != nil {
return nil, fmt.Errorf("failed to scan severity stats: %w", err)
}
stats.BySeverity[audit.EventSeverity(severity)] = count
}
// Get stats by status
statusQuery := fmt.Sprintf(`
SELECT status, COUNT(*)
FROM audit_events %s
GROUP BY status
ORDER BY COUNT(*) DESC
`, whereClause)
rows, err = db.QueryContext(ctx, statusQuery, args...)
if err != nil {
return nil, fmt.Errorf("failed to get status stats: %w", err)
}
defer rows.Close()
for rows.Next() {
var status string
var count int
if err := rows.Scan(&status, &count); err != nil {
return nil, fmt.Errorf("failed to scan status stats: %w", err)
}
stats.ByStatus[audit.EventStatus(status)] = count
}
// Get time-based stats if requested
if filter.GroupBy != "" {
stats.ByTime = make(map[string]int)
var timeFormat string
switch filter.GroupBy {
case "hour":
timeFormat = "YYYY-MM-DD HH24:00"
case "day":
timeFormat = "YYYY-MM-DD"
default:
timeFormat = "YYYY-MM-DD"
}
timeQuery := fmt.Sprintf(`
SELECT TO_CHAR(timestamp, '%s') as time_group, COUNT(*)
FROM audit_events %s
GROUP BY time_group
ORDER BY time_group DESC
`, timeFormat, whereClause)
rows, err = db.QueryContext(ctx, timeQuery, args...)
if err != nil {
return nil, fmt.Errorf("failed to get time stats: %w", err)
}
defer rows.Close()
for rows.Next() {
var timeGroup string
var count int
if err := rows.Scan(&timeGroup, &count); err != nil {
return nil, fmt.Errorf("failed to scan time stats: %w", err)
}
stats.ByTime[timeGroup] = count
}
}
return stats, nil
}
// DeleteOldEvents removes audit events older than the specified time
func (r *AuditRepository) DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) {
query := `DELETE FROM audit_events WHERE timestamp < $1`
db := r.db.GetDB().(*sql.DB)
result, err := db.ExecContext(ctx, query, olderThan)
if err != nil {
return 0, fmt.Errorf("failed to delete old audit events: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return int(rowsAffected), nil
}
// GetByID retrieves a specific audit event by its ID
func (r *AuditRepository) GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error) {
query := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
WHERE id = $1
`
db := r.db.GetDB().(*sql.DB)
row := db.QueryRowContext(ctx, query, eventID)
event, err := r.scanAuditEvent(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("audit event with ID '%s' not found", eventID)
}
return nil, fmt.Errorf("failed to get audit event: %w", err)
}
return event, nil
}
// GetByRequestID retrieves all audit events for a specific request
func (r *AuditRepository) GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error) {
query := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
WHERE request_id = $1
ORDER BY timestamp ASC
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, requestID)
if err != nil {
return nil, fmt.Errorf("failed to query audit events by request ID: %w", err)
}
defer rows.Close()
var events []*audit.AuditEvent
for rows.Next() {
event, err := r.scanAuditEvent(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
return events, nil
}
// GetBySession retrieves all audit events for a specific session
func (r *AuditRepository) GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error) {
query := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
WHERE session_id = $1
ORDER BY timestamp ASC
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, sessionID)
if err != nil {
return nil, fmt.Errorf("failed to query audit events by session ID: %w", err)
}
defer rows.Close()
var events []*audit.AuditEvent
for rows.Next() {
event, err := r.scanAuditEvent(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
return events, nil
}
// GetByActor retrieves audit events for a specific actor
func (r *AuditRepository) GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error) {
if limit <= 0 {
limit = 100
}
if limit > 1000 {
limit = 1000
}
query := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
WHERE actor_id = $1
ORDER BY timestamp DESC
LIMIT $2 OFFSET $3
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, actorID, limit, offset)
if err != nil {
return nil, fmt.Errorf("failed to query audit events by actor: %w", err)
}
defer rows.Close()
var events []*audit.AuditEvent
for rows.Next() {
event, err := r.scanAuditEvent(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
return events, nil
}
// GetByResource retrieves audit events for a specific resource
func (r *AuditRepository) GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error) {
if limit <= 0 {
limit = 100
}
if limit > 1000 {
limit = 1000
}
query := `
SELECT id, type, severity, status, timestamp,
actor_id, actor_type, actor_ip, user_agent, tenant_id,
resource_id, resource_type, action, description, details,
request_id, session_id, tags, metadata
FROM audit_events
WHERE resource_type = $1 AND resource_id = $2
ORDER BY timestamp DESC
LIMIT $3 OFFSET $4
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, resourceType, resourceID, limit, offset)
if err != nil {
return nil, fmt.Errorf("failed to query audit events by resource: %w", err)
}
defer rows.Close()
var events []*audit.AuditEvent
for rows.Next() {
event, err := r.scanAuditEvent(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
return events, nil
}
// scanAuditEvent scans a database row into an AuditEvent struct
func (r *AuditRepository) scanAuditEvent(row interface{}) (*audit.AuditEvent, error) {
event := &audit.AuditEvent{}
var typeStr, severityStr, statusStr string
var actorID, actorType, actorIP, userAgent sql.NullString
var tenantID *uuid.UUID
var resourceID, resourceType sql.NullString
var detailsJSON, metadataJSON string
var requestID, sessionID sql.NullString
var tags pq.StringArray
var scanner interface {
Scan(dest ...interface{}) error
}
switch v := row.(type) {
case *sql.Row:
scanner = v
case *sql.Rows:
scanner = v
default:
return nil, fmt.Errorf("invalid row type")
}
err := scanner.Scan(
&event.ID,
&typeStr,
&severityStr,
&statusStr,
&event.Timestamp,
&actorID,
&actorType,
&actorIP,
&userAgent,
&tenantID,
&resourceID,
&resourceType,
&event.Action,
&event.Description,
&detailsJSON,
&requestID,
&sessionID,
&tags,
&metadataJSON,
)
if err != nil {
return nil, err
}
// Convert string enums to types
event.Type = audit.EventType(typeStr)
event.Severity = audit.EventSeverity(severityStr)
event.Status = audit.EventStatus(statusStr)
// Handle nullable fields
if actorID.Valid {
event.ActorID = actorID.String
}
if actorType.Valid {
event.ActorType = actorType.String
}
if actorIP.Valid {
event.ActorIP = actorIP.String
}
if userAgent.Valid {
event.UserAgent = userAgent.String
}
if tenantID != nil {
event.TenantID = tenantID
}
if resourceID.Valid {
event.ResourceID = resourceID.String
}
if resourceType.Valid {
event.ResourceType = resourceType.String
}
if requestID.Valid {
event.RequestID = requestID.String
}
if sessionID.Valid {
event.SessionID = sessionID.String
}
// Convert tags
event.Tags = []string(tags)
// Parse JSON fields
if detailsJSON != "" {
if err := json.Unmarshal([]byte(detailsJSON), &event.Details); err != nil {
return nil, fmt.Errorf("failed to unmarshal details JSON: %w", err)
}
}
if metadataJSON != "" {
if err := json.Unmarshal([]byte(metadataJSON), &event.Metadata); err != nil {
return nil, fmt.Errorf("failed to unmarshal metadata JSON: %w", err)
}
}
return event, nil
}

View File

@ -0,0 +1,693 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/repository"
"github.com/lib/pq"
)
// PermissionRepository implements the PermissionRepository interface for PostgreSQL
type PermissionRepository struct {
db repository.DatabaseProvider
}
// NewPermissionRepository creates a new PostgreSQL permission repository
func NewPermissionRepository(db repository.DatabaseProvider) repository.PermissionRepository {
return &PermissionRepository{db: db}
}
// CreateAvailablePermission creates a new available permission
func (r *PermissionRepository) CreateAvailablePermission(ctx context.Context, permission *domain.AvailablePermission) error {
query := `
INSERT INTO available_permissions (
id, scope, name, description, category, parent_scope,
is_system, created_by, updated_by, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`
db := r.db.GetDB().(*sql.DB)
now := time.Now()
if permission.ID == uuid.Nil {
permission.ID = uuid.New()
}
_, err := db.ExecContext(ctx, query,
permission.ID,
permission.Scope,
permission.Name,
permission.Description,
permission.Category,
permission.ParentScope,
permission.IsSystem,
permission.CreatedBy,
permission.UpdatedBy,
now,
now,
)
if err != nil {
return fmt.Errorf("failed to create available permission: %w", err)
}
permission.CreatedAt = now
permission.UpdatedAt = now
return nil
}
// GetAvailablePermission retrieves an available permission by ID
func (r *PermissionRepository) GetAvailablePermission(ctx context.Context, permissionID uuid.UUID) (*domain.AvailablePermission, error) {
query := `
SELECT id, scope, name, description, category, parent_scope,
is_system, created_at, created_by, updated_at, updated_by
FROM available_permissions
WHERE id = $1
`
db := r.db.GetDB().(*sql.DB)
row := db.QueryRowContext(ctx, query, permissionID)
permission := &domain.AvailablePermission{}
err := row.Scan(
&permission.ID,
&permission.Scope,
&permission.Name,
&permission.Description,
&permission.Category,
&permission.ParentScope,
&permission.IsSystem,
&permission.CreatedAt,
&permission.CreatedBy,
&permission.UpdatedAt,
&permission.UpdatedBy,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("permission with ID '%s' not found", permissionID)
}
return nil, fmt.Errorf("failed to get available permission: %w", err)
}
return permission, nil
}
// GetAvailablePermissionByScope retrieves an available permission by scope
func (r *PermissionRepository) GetAvailablePermissionByScope(ctx context.Context, scope string) (*domain.AvailablePermission, error) {
query := `
SELECT id, scope, name, description, category, parent_scope,
is_system, created_at, created_by, updated_at, updated_by
FROM available_permissions
WHERE scope = $1
`
db := r.db.GetDB().(*sql.DB)
row := db.QueryRowContext(ctx, query, scope)
permission := &domain.AvailablePermission{}
err := row.Scan(
&permission.ID,
&permission.Scope,
&permission.Name,
&permission.Description,
&permission.Category,
&permission.ParentScope,
&permission.IsSystem,
&permission.CreatedAt,
&permission.CreatedBy,
&permission.UpdatedAt,
&permission.UpdatedBy,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("permission with scope '%s' not found", scope)
}
return nil, fmt.Errorf("failed to get available permission by scope: %w", err)
}
return permission, nil
}
// ListAvailablePermissions retrieves available permissions with pagination and filtering
func (r *PermissionRepository) ListAvailablePermissions(ctx context.Context, category string, includeSystem bool, limit, offset int) ([]*domain.AvailablePermission, error) {
var args []interface{}
var whereClauses []string
argIndex := 1
// Build WHERE clause based on filters
if category != "" {
whereClauses = append(whereClauses, fmt.Sprintf("category = $%d", argIndex))
args = append(args, category)
argIndex++
}
if !includeSystem {
whereClauses = append(whereClauses, fmt.Sprintf("is_system = $%d", argIndex))
args = append(args, false)
argIndex++
}
whereClause := ""
if len(whereClauses) > 0 {
whereClause = "WHERE " + fmt.Sprintf("%s", whereClauses[0])
for i := 1; i < len(whereClauses); i++ {
whereClause += " AND " + whereClauses[i]
}
}
query := fmt.Sprintf(`
SELECT id, scope, name, description, category, parent_scope,
is_system, created_at, created_by, updated_at, updated_by
FROM available_permissions
%s
ORDER BY category, scope
LIMIT $%d OFFSET $%d
`, whereClause, argIndex, argIndex+1)
args = append(args, limit, offset)
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to list available permissions: %w", err)
}
defer rows.Close()
var permissions []*domain.AvailablePermission
for rows.Next() {
permission := &domain.AvailablePermission{}
err := rows.Scan(
&permission.ID,
&permission.Scope,
&permission.Name,
&permission.Description,
&permission.Category,
&permission.ParentScope,
&permission.IsSystem,
&permission.CreatedAt,
&permission.CreatedBy,
&permission.UpdatedAt,
&permission.UpdatedBy,
)
if err != nil {
return nil, fmt.Errorf("failed to scan available permission: %w", err)
}
permissions = append(permissions, permission)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("failed to iterate available permissions: %w", err)
}
return permissions, nil
}
// UpdateAvailablePermission updates an available permission
func (r *PermissionRepository) UpdateAvailablePermission(ctx context.Context, permissionID uuid.UUID, permission *domain.AvailablePermission) error {
query := `
UPDATE available_permissions
SET scope = $2, name = $3, description = $4, category = $5,
parent_scope = $6, is_system = $7, updated_by = $8, updated_at = $9
WHERE id = $1
`
db := r.db.GetDB().(*sql.DB)
now := time.Now()
result, err := db.ExecContext(ctx, query,
permissionID,
permission.Scope,
permission.Name,
permission.Description,
permission.Category,
permission.ParentScope,
permission.IsSystem,
permission.UpdatedBy,
now,
)
if err != nil {
return fmt.Errorf("failed to update available permission: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("permission with ID %s not found", permissionID)
}
permission.UpdatedAt = now
return nil
}
// DeleteAvailablePermission deletes an available permission
func (r *PermissionRepository) DeleteAvailablePermission(ctx context.Context, permissionID uuid.UUID) error {
// First check if the permission has any child permissions
checkChildrenQuery := `
SELECT COUNT(*) FROM available_permissions
WHERE parent_scope = (SELECT scope FROM available_permissions WHERE id = $1)
`
db := r.db.GetDB().(*sql.DB)
var childCount int
err := db.QueryRowContext(ctx, checkChildrenQuery, permissionID).Scan(&childCount)
if err != nil {
return fmt.Errorf("failed to check for child permissions: %w", err)
}
if childCount > 0 {
return fmt.Errorf("cannot delete permission: it has %d child permissions", childCount)
}
// Check if the permission is granted to any tokens
checkGrantsQuery := `
SELECT COUNT(*) FROM granted_permissions
WHERE permission_id = $1 AND revoked = false
`
var grantCount int
err = db.QueryRowContext(ctx, checkGrantsQuery, permissionID).Scan(&grantCount)
if err != nil {
return fmt.Errorf("failed to check for active grants: %w", err)
}
if grantCount > 0 {
return fmt.Errorf("cannot delete permission: it is currently granted to %d tokens", grantCount)
}
// Delete the permission
deleteQuery := `DELETE FROM available_permissions WHERE id = $1`
result, err := db.ExecContext(ctx, deleteQuery, permissionID)
if err != nil {
return fmt.Errorf("failed to delete available permission: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("permission with ID %s not found", permissionID)
}
return nil
}
// ValidatePermissionScopes checks if all given scopes exist and are valid
func (r *PermissionRepository) ValidatePermissionScopes(ctx context.Context, scopes []string) ([]string, error) {
if len(scopes) == 0 {
return []string{}, nil
}
query := `
SELECT scope
FROM available_permissions
WHERE scope = ANY($1)
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, pq.Array(scopes))
if err != nil {
return nil, fmt.Errorf("failed to validate permission scopes: %w", err)
}
defer rows.Close()
validScopes := make(map[string]bool)
for rows.Next() {
var scope string
if err := rows.Scan(&scope); err != nil {
return nil, fmt.Errorf("failed to scan scope: %w", err)
}
validScopes[scope] = true
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating scopes: %w", err)
}
var result []string
for _, scope := range scopes {
if validScopes[scope] {
result = append(result, scope)
}
}
return result, nil
}
// GetPermissionHierarchy returns all parent and child permissions for given scopes
func (r *PermissionRepository) GetPermissionHierarchy(ctx context.Context, scopes []string) ([]*domain.AvailablePermission, error) {
if len(scopes) == 0 {
return []*domain.AvailablePermission{}, nil
}
// Use recursive CTE to get full hierarchy
query := `
WITH RECURSIVE permission_hierarchy AS (
-- Base case: get permissions matching the input scopes
SELECT id, scope, name, description, category, parent_scope,
is_system, created_at, created_by, updated_at, updated_by, 0 as level
FROM available_permissions
WHERE scope = ANY($1)
UNION ALL
-- Recursive case: get all parents and children
SELECT ap.id, ap.scope, ap.name, ap.description, ap.category, ap.parent_scope,
ap.is_system, ap.created_at, ap.created_by, ap.updated_at, ap.updated_by,
ph.level + 1 as level
FROM available_permissions ap
JOIN permission_hierarchy ph ON (
-- Get parents (where ap.scope = ph.parent_scope)
ap.scope = ph.parent_scope
OR
-- Get children (where ap.parent_scope = ph.scope)
ap.parent_scope = ph.scope
)
WHERE ph.level < 5 -- Prevent infinite recursion
)
SELECT DISTINCT id, scope, name, description, category, parent_scope,
is_system, created_at, created_by, updated_at, updated_by
FROM permission_hierarchy
ORDER BY scope
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, pq.Array(scopes))
if err != nil {
return nil, fmt.Errorf("failed to get permission hierarchy: %w", err)
}
defer rows.Close()
var permissions []*domain.AvailablePermission
for rows.Next() {
permission := &domain.AvailablePermission{}
err := rows.Scan(
&permission.ID,
&permission.Scope,
&permission.Name,
&permission.Description,
&permission.Category,
&permission.ParentScope,
&permission.IsSystem,
&permission.CreatedAt,
&permission.CreatedBy,
&permission.UpdatedAt,
&permission.UpdatedBy,
)
if err != nil {
return nil, fmt.Errorf("failed to scan permission hierarchy: %w", err)
}
permissions = append(permissions, permission)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("failed to iterate permission hierarchy: %w", err)
}
return permissions, nil
}
// GrantedPermissionRepository implements the GrantedPermissionRepository interface for PostgreSQL
type GrantedPermissionRepository struct {
db repository.DatabaseProvider
}
// NewGrantedPermissionRepository creates a new PostgreSQL granted permission repository
func NewGrantedPermissionRepository(db repository.DatabaseProvider) repository.GrantedPermissionRepository {
return &GrantedPermissionRepository{db: db}
}
// GrantPermissions grants multiple permissions to a token
func (r *GrantedPermissionRepository) GrantPermissions(ctx context.Context, grants []*domain.GrantedPermission) error {
if len(grants) == 0 {
return nil
}
db := r.db.GetDB().(*sql.DB)
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
query := `
INSERT INTO granted_permissions (
id, token_type, token_id, permission_id, scope, created_by, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (token_type, token_id, permission_id) DO NOTHING
`
stmt, err := tx.PrepareContext(ctx, query)
if err != nil {
return fmt.Errorf("failed to prepare statement: %w", err)
}
defer stmt.Close()
now := time.Now()
for _, grant := range grants {
if grant.ID == uuid.Nil {
grant.ID = uuid.New()
}
_, err = stmt.ExecContext(ctx,
grant.ID,
string(grant.TokenType),
grant.TokenID,
grant.PermissionID,
grant.Scope,
grant.CreatedBy,
now,
)
if err != nil {
return fmt.Errorf("failed to grant permission: %w", err)
}
grant.CreatedAt = now
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
// GetGrantedPermissions retrieves all granted permissions for a token
func (r *GrantedPermissionRepository) GetGrantedPermissions(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID) ([]*domain.GrantedPermission, error) {
query := `
SELECT id, token_type, token_id, permission_id, scope, created_at, created_by, revoked
FROM granted_permissions
WHERE token_type = $1 AND token_id = $2 AND revoked = false
ORDER BY created_at ASC
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, string(tokenType), tokenID)
if err != nil {
return nil, fmt.Errorf("failed to query granted permissions: %w", err)
}
defer rows.Close()
var permissions []*domain.GrantedPermission
for rows.Next() {
perm := &domain.GrantedPermission{}
var tokenTypeStr string
err := rows.Scan(
&perm.ID,
&tokenTypeStr,
&perm.TokenID,
&perm.PermissionID,
&perm.Scope,
&perm.CreatedAt,
&perm.CreatedBy,
&perm.Revoked,
)
if err != nil {
return nil, fmt.Errorf("failed to scan granted permission: %w", err)
}
perm.TokenType = domain.TokenType(tokenTypeStr)
permissions = append(permissions, perm)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating granted permissions: %w", err)
}
return permissions, nil
}
// GetGrantedPermissionScopes retrieves only the scopes for a token (more efficient)
func (r *GrantedPermissionRepository) GetGrantedPermissionScopes(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID) ([]string, error) {
query := `
SELECT scope
FROM granted_permissions
WHERE token_type = $1 AND token_id = $2 AND revoked = false
ORDER BY scope ASC
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, string(tokenType), tokenID)
if err != nil {
return nil, fmt.Errorf("failed to query granted permission scopes: %w", err)
}
defer rows.Close()
var scopes []string
for rows.Next() {
var scope string
if err := rows.Scan(&scope); err != nil {
return nil, fmt.Errorf("failed to scan permission scope: %w", err)
}
scopes = append(scopes, scope)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating permission scopes: %w", err)
}
return scopes, nil
}
// RevokePermission revokes a specific permission from a token
func (r *GrantedPermissionRepository) RevokePermission(ctx context.Context, grantID uuid.UUID, revokedBy string) error {
query := `
UPDATE granted_permissions
SET revoked = true, revoked_by = $2, revoked_at = $3
WHERE id = $1 AND revoked = false
`
db := r.db.GetDB().(*sql.DB)
now := time.Now()
result, err := db.ExecContext(ctx, query, grantID, revokedBy, now)
if err != nil {
return fmt.Errorf("failed to revoke permission: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("permission grant with ID %s not found or already revoked", grantID)
}
return nil
}
// RevokeAllPermissions revokes all permissions from a token
func (r *GrantedPermissionRepository) RevokeAllPermissions(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, revokedBy string) error {
query := `
UPDATE granted_permissions
SET revoked = true, revoked_by = $3, revoked_at = $4
WHERE token_type = $1 AND token_id = $2 AND revoked = false
`
db := r.db.GetDB().(*sql.DB)
now := time.Now()
result, err := db.ExecContext(ctx, query, tokenType, tokenID, revokedBy, now)
if err != nil {
return fmt.Errorf("failed to revoke all permissions: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
// Note: rowsAffected being 0 is not necessarily an error here -
// the token might not have had any active permissions
_ = rowsAffected
return nil
}
// HasPermission checks if a token has a specific permission
func (r *GrantedPermissionRepository) HasPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scope string) (bool, error) {
query := `
SELECT 1
FROM granted_permissions gp
JOIN available_permissions ap ON gp.permission_id = ap.id
WHERE gp.token_type = $1
AND gp.token_id = $2
AND gp.scope = $3
AND gp.revoked = false
LIMIT 1
`
db := r.db.GetDB().(*sql.DB)
var exists int
err := db.QueryRowContext(ctx, query, string(tokenType), tokenID, scope).Scan(&exists)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("failed to check permission: %w", err)
}
return true, nil
}
// HasAnyPermission checks if a token has any of the specified permissions
func (r *GrantedPermissionRepository) HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error) {
if len(scopes) == 0 {
return make(map[string]bool), nil
}
query := `
SELECT gp.scope
FROM granted_permissions gp
JOIN available_permissions ap ON gp.permission_id = ap.id
WHERE gp.token_type = $1
AND gp.token_id = $2
AND gp.scope = ANY($3)
AND gp.revoked = false
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, string(tokenType), tokenID, pq.Array(scopes))
if err != nil {
return nil, fmt.Errorf("failed to check permissions: %w", err)
}
defer rows.Close()
result := make(map[string]bool)
// Initialize all scopes as false
for _, scope := range scopes {
result[scope] = false
}
// Mark found permissions as true
for rows.Next() {
var scope string
if err := rows.Scan(&scope); err != nil {
return nil, fmt.Errorf("failed to scan permission scope: %w", err)
}
result[scope] = true
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating permission results: %w", err)
}
return result, nil
}

View File

@ -0,0 +1,624 @@
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
"github.com/kms/api-key-service/internal/repository"
)
// sessionRepository implements the SessionRepository interface
type sessionRepository struct {
db *sqlx.DB
logger *zap.Logger
}
// NewSessionRepository creates a new session repository
func NewSessionRepository(db *sqlx.DB, logger *zap.Logger) repository.SessionRepository {
return &sessionRepository{
db: db,
logger: logger,
}
}
// Create creates a new user session
func (r *sessionRepository) Create(ctx context.Context, session *domain.UserSession) error {
r.logger.Debug("Creating new session",
zap.String("user_id", session.UserID),
zap.String("app_id", session.AppID),
zap.String("session_type", string(session.SessionType)))
// Generate ID if not provided
if session.ID == uuid.Nil {
session.ID = uuid.New()
}
// Set timestamps
now := time.Now()
session.CreatedAt = now
session.UpdatedAt = now
session.LastActivity = now
// Serialize metadata
metadataJSON, err := json.Marshal(session.Metadata)
if err != nil {
return errors.NewInternalError("Failed to serialize session metadata").WithInternal(err)
}
query := `
INSERT INTO user_sessions (
id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at, metadata
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
)`
_, err = r.db.ExecContext(ctx, query,
session.ID,
session.UserID,
session.AppID,
session.SessionType,
session.Status,
session.AccessToken,
session.RefreshToken,
session.IDToken,
session.IPAddress,
session.UserAgent,
session.LastActivity,
session.ExpiresAt,
session.CreatedAt,
session.UpdatedAt,
metadataJSON,
)
if err != nil {
r.logger.Error("Failed to create session", zap.Error(err))
return errors.NewInternalError("Failed to create session").WithInternal(err)
}
r.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
return nil
}
// GetByID retrieves a session by its ID
func (r *sessionRepository) GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
r.logger.Debug("Getting session by ID", zap.String("session_id", sessionID.String()))
query := `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE id = $1`
var session domain.UserSession
var metadataJSON []byte
var revokedAt sql.NullTime
var revokedBy sql.NullString
err := r.db.QueryRowContext(ctx, query, sessionID).Scan(
&session.ID,
&session.UserID,
&session.AppID,
&session.SessionType,
&session.Status,
&session.AccessToken,
&session.RefreshToken,
&session.IDToken,
&session.IPAddress,
&session.UserAgent,
&session.LastActivity,
&session.ExpiresAt,
&session.CreatedAt,
&session.UpdatedAt,
&revokedAt,
&revokedBy,
&metadataJSON,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, errors.NewNotFoundError("Session not found")
}
r.logger.Error("Failed to get session by ID", zap.Error(err))
return nil, errors.NewInternalError("Failed to retrieve session").WithInternal(err)
}
// Handle nullable fields
if revokedAt.Valid {
session.RevokedAt = &revokedAt.Time
}
if revokedBy.Valid {
session.RevokedBy = &revokedBy.String
}
// Deserialize metadata
if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
}
r.logger.Debug("Session retrieved successfully", zap.String("session_id", sessionID.String()))
return &session, nil
}
// GetByUserID retrieves all sessions for a user
func (r *sessionRepository) GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
r.logger.Debug("Getting sessions by user ID", zap.String("user_id", userID))
query := `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE user_id = $1
ORDER BY created_at DESC`
return r.scanSessions(ctx, query, userID)
}
// GetByUserAndApp retrieves sessions for a specific user and application
func (r *sessionRepository) GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
r.logger.Debug("Getting sessions by user and app",
zap.String("user_id", userID),
zap.String("app_id", appID))
query := `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE user_id = $1 AND app_id = $2
ORDER BY created_at DESC`
return r.scanSessions(ctx, query, userID, appID)
}
// GetActiveByUserID retrieves all active sessions for a user
func (r *sessionRepository) GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
r.logger.Debug("Getting active sessions by user ID", zap.String("user_id", userID))
query := `
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
WHERE user_id = $1 AND status = $2 AND expires_at > NOW()
ORDER BY last_activity DESC`
return r.scanSessions(ctx, query, userID, domain.SessionStatusActive)
}
// List retrieves sessions with filtering and pagination
func (r *sessionRepository) List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
r.logger.Debug("Listing sessions with filters",
zap.String("user_id", req.UserID),
zap.String("app_id", req.AppID),
zap.Int("limit", req.Limit),
zap.Int("offset", req.Offset))
// Build WHERE clause dynamically
whereClause := "WHERE 1=1"
args := []interface{}{}
argIndex := 1
if req.UserID != "" {
whereClause += fmt.Sprintf(" AND user_id = $%d", argIndex)
args = append(args, req.UserID)
argIndex++
}
if req.AppID != "" {
whereClause += fmt.Sprintf(" AND app_id = $%d", argIndex)
args = append(args, req.AppID)
argIndex++
}
if req.Status != nil {
whereClause += fmt.Sprintf(" AND status = $%d", argIndex)
args = append(args, *req.Status)
argIndex++
}
if req.SessionType != nil {
whereClause += fmt.Sprintf(" AND session_type = $%d", argIndex)
args = append(args, *req.SessionType)
argIndex++
}
if req.TenantID != "" {
whereClause += fmt.Sprintf(" AND metadata->>'tenant_id' = $%d", argIndex)
args = append(args, req.TenantID)
argIndex++
}
// Get total count
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM user_sessions %s", whereClause)
var total int
err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
if err != nil {
r.logger.Error("Failed to get session count", zap.Error(err))
return nil, errors.NewInternalError("Failed to count sessions").WithInternal(err)
}
// Get sessions with pagination
query := fmt.Sprintf(`
SELECT id, user_id, app_id, session_type, status, access_token,
refresh_token, id_token, ip_address, user_agent,
last_activity, expires_at, created_at, updated_at,
revoked_at, revoked_by, metadata
FROM user_sessions
%s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d`, whereClause, argIndex, argIndex+1)
args = append(args, req.Limit, req.Offset)
sessions, err := r.scanSessions(ctx, query, args...)
if err != nil {
return nil, err
}
return &domain.SessionListResponse{
Sessions: sessions,
Total: total,
Limit: req.Limit,
Offset: req.Offset,
}, nil
}
// Update updates an existing session
func (r *sessionRepository) Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
r.logger.Debug("Updating session", zap.String("session_id", sessionID.String()))
// Build UPDATE clause dynamically
setParts := []string{"updated_at = NOW()"}
args := []interface{}{}
argIndex := 1
if updates.Status != nil {
setParts = append(setParts, fmt.Sprintf("status = $%d", argIndex))
args = append(args, *updates.Status)
argIndex++
}
if updates.LastActivity != nil {
setParts = append(setParts, fmt.Sprintf("last_activity = $%d", argIndex))
args = append(args, *updates.LastActivity)
argIndex++
}
if updates.ExpiresAt != nil {
setParts = append(setParts, fmt.Sprintf("expires_at = $%d", argIndex))
args = append(args, *updates.ExpiresAt)
argIndex++
}
if updates.IPAddress != nil {
setParts = append(setParts, fmt.Sprintf("ip_address = $%d", argIndex))
args = append(args, *updates.IPAddress)
argIndex++
}
if updates.UserAgent != nil {
setParts = append(setParts, fmt.Sprintf("user_agent = $%d", argIndex))
args = append(args, *updates.UserAgent)
argIndex++
}
if len(setParts) == 1 {
return errors.NewValidationError("No fields to update")
}
// Build the complete query
setClause := fmt.Sprintf("%s", setParts[0])
for i := 1; i < len(setParts); i++ {
setClause += fmt.Sprintf(", %s", setParts[i])
}
query := fmt.Sprintf("UPDATE user_sessions SET %s WHERE id = $%d", setClause, argIndex)
args = append(args, sessionID)
result, err := r.db.ExecContext(ctx, query, args...)
if err != nil {
r.logger.Error("Failed to update session", zap.Error(err))
return errors.NewInternalError("Failed to update session").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
if rowsAffected == 0 {
return errors.NewNotFoundError("Session not found")
}
r.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String()))
return nil
}
// UpdateActivity updates the last activity timestamp for a session
func (r *sessionRepository) UpdateActivity(ctx context.Context, sessionID uuid.UUID) error {
r.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))
query := `UPDATE user_sessions SET last_activity = NOW(), updated_at = NOW() WHERE id = $1`
result, err := r.db.ExecContext(ctx, query, sessionID)
if err != nil {
r.logger.Error("Failed to update session activity", zap.Error(err))
return errors.NewInternalError("Failed to update session activity").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
if rowsAffected == 0 {
return errors.NewNotFoundError("Session not found")
}
return nil
}
// Revoke revokes a session
func (r *sessionRepository) Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
r.logger.Debug("Revoking session",
zap.String("session_id", sessionID.String()),
zap.String("revoked_by", revokedBy))
query := `
UPDATE user_sessions
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
WHERE id = $3`
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, sessionID)
if err != nil {
r.logger.Error("Failed to revoke session", zap.Error(err))
return errors.NewInternalError("Failed to revoke session").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
if rowsAffected == 0 {
return errors.NewNotFoundError("Session not found")
}
r.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String()))
return nil
}
// RevokeAllByUser revokes all sessions for a user
func (r *sessionRepository) RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error {
r.logger.Debug("Revoking all sessions for user",
zap.String("user_id", userID),
zap.String("revoked_by", revokedBy))
query := `
UPDATE user_sessions
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
WHERE user_id = $3 AND status = $4`
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, domain.SessionStatusActive)
if err != nil {
r.logger.Error("Failed to revoke user sessions", zap.Error(err))
return errors.NewInternalError("Failed to revoke user sessions").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
r.logger.Debug("User sessions revoked",
zap.String("user_id", userID),
zap.Int64("sessions_revoked", rowsAffected))
return nil
}
// RevokeAllByUserAndApp revokes all sessions for a user and application
func (r *sessionRepository) RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error {
r.logger.Debug("Revoking all sessions for user and app",
zap.String("user_id", userID),
zap.String("app_id", appID),
zap.String("revoked_by", revokedBy))
query := `
UPDATE user_sessions
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
WHERE user_id = $3 AND app_id = $4 AND status = $5`
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, appID, domain.SessionStatusActive)
if err != nil {
r.logger.Error("Failed to revoke user app sessions", zap.Error(err))
return errors.NewInternalError("Failed to revoke user app sessions").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
r.logger.Debug("User app sessions revoked",
zap.String("user_id", userID),
zap.String("app_id", appID),
zap.Int64("sessions_revoked", rowsAffected))
return nil
}
// ExpireOldSessions marks expired sessions as expired
func (r *sessionRepository) ExpireOldSessions(ctx context.Context) (int, error) {
r.logger.Debug("Expiring old sessions")
query := `
UPDATE user_sessions
SET status = $1, updated_at = NOW()
WHERE expires_at < NOW() AND status = $2`
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, domain.SessionStatusActive)
if err != nil {
r.logger.Error("Failed to expire old sessions", zap.Error(err))
return 0, errors.NewInternalError("Failed to expire old sessions").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
r.logger.Debug("Old sessions expired", zap.Int64("sessions_expired", rowsAffected))
return int(rowsAffected), nil
}
// DeleteExpiredSessions removes expired sessions older than the specified duration
func (r *sessionRepository) DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error) {
r.logger.Debug("Deleting expired sessions", zap.Duration("older_than", olderThan))
cutoffTime := time.Now().Add(-olderThan)
query := `DELETE FROM user_sessions WHERE status = $1 AND updated_at < $2`
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, cutoffTime)
if err != nil {
r.logger.Error("Failed to delete expired sessions", zap.Error(err))
return 0, errors.NewInternalError("Failed to delete expired sessions").WithInternal(err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
}
r.logger.Debug("Expired sessions deleted", zap.Int64("sessions_deleted", rowsAffected))
return int(rowsAffected), nil
}
// Exists checks if a session exists
func (r *sessionRepository) Exists(ctx context.Context, sessionID uuid.UUID) (bool, error) {
r.logger.Debug("Checking if session exists", zap.String("session_id", sessionID.String()))
query := `SELECT EXISTS(SELECT 1 FROM user_sessions WHERE id = $1)`
var exists bool
err := r.db.QueryRowContext(ctx, query, sessionID).Scan(&exists)
if err != nil {
r.logger.Error("Failed to check session existence", zap.Error(err))
return false, errors.NewInternalError("Failed to check session existence").WithInternal(err)
}
return exists, nil
}
// GetSessionCount returns the total number of sessions for a user
func (r *sessionRepository) GetSessionCount(ctx context.Context, userID string) (int, error) {
r.logger.Debug("Getting session count for user", zap.String("user_id", userID))
query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1`
var count int
err := r.db.QueryRowContext(ctx, query, userID).Scan(&count)
if err != nil {
r.logger.Error("Failed to get session count", zap.Error(err))
return 0, errors.NewInternalError("Failed to get session count").WithInternal(err)
}
return count, nil
}
// GetActiveSessionCount returns the number of active sessions for a user
func (r *sessionRepository) GetActiveSessionCount(ctx context.Context, userID string) (int, error) {
r.logger.Debug("Getting active session count for user", zap.String("user_id", userID))
query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1 AND status = $2 AND expires_at > NOW()`
var count int
err := r.db.QueryRowContext(ctx, query, userID, domain.SessionStatusActive).Scan(&count)
if err != nil {
r.logger.Error("Failed to get active session count", zap.Error(err))
return 0, errors.NewInternalError("Failed to get active session count").WithInternal(err)
}
return count, nil
}
// scanSessions is a helper method to scan multiple sessions from query results
func (r *sessionRepository) scanSessions(ctx context.Context, query string, args ...interface{}) ([]*domain.UserSession, error) {
rows, err := r.db.QueryContext(ctx, query, args...)
if err != nil {
r.logger.Error("Failed to execute session query", zap.Error(err))
return nil, errors.NewInternalError("Failed to retrieve sessions").WithInternal(err)
}
defer rows.Close()
var sessions []*domain.UserSession
for rows.Next() {
var session domain.UserSession
var metadataJSON []byte
var revokedAt sql.NullTime
var revokedBy sql.NullString
err := rows.Scan(
&session.ID,
&session.UserID,
&session.AppID,
&session.SessionType,
&session.Status,
&session.AccessToken,
&session.RefreshToken,
&session.IDToken,
&session.IPAddress,
&session.UserAgent,
&session.LastActivity,
&session.ExpiresAt,
&session.CreatedAt,
&session.UpdatedAt,
&revokedAt,
&revokedBy,
&metadataJSON,
)
if err != nil {
r.logger.Error("Failed to scan session row", zap.Error(err))
return nil, errors.NewInternalError("Failed to scan session data").WithInternal(err)
}
// Handle nullable fields
if revokedAt.Valid {
session.RevokedAt = &revokedAt.Time
}
if revokedBy.Valid {
session.RevokedBy = &revokedBy.String
}
// Deserialize metadata
if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
}
sessions = append(sessions, &session)
}
if err := rows.Err(); err != nil {
r.logger.Error("Error iterating session rows", zap.Error(err))
return nil, errors.NewInternalError("Failed to iterate session results").WithInternal(err)
}
return sessions, nil
}

View File

@ -0,0 +1,290 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/repository"
)
// StaticTokenRepository implements the StaticTokenRepository interface for PostgreSQL
type StaticTokenRepository struct {
db repository.DatabaseProvider
}
// NewStaticTokenRepository creates a new PostgreSQL static token repository
func NewStaticTokenRepository(db repository.DatabaseProvider) repository.StaticTokenRepository {
return &StaticTokenRepository{db: db}
}
// Create creates a new static token
func (r *StaticTokenRepository) Create(ctx context.Context, token *domain.StaticToken) error {
query := `
INSERT INTO static_tokens (
id, app_id, owner_type, owner_name, owner_owner,
key_hash, type, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
db := r.db.GetDB().(*sql.DB)
now := time.Now()
_, err := db.ExecContext(ctx, query,
token.ID,
token.AppID,
string(token.Owner.Type),
token.Owner.Name,
token.Owner.Owner,
token.KeyHash,
string(token.Type),
now,
now,
)
if err != nil {
return fmt.Errorf("failed to create static token: %w", err)
}
token.CreatedAt = now
token.UpdatedAt = now
return nil
}
// GetByID retrieves a static token by its ID
func (r *StaticTokenRepository) GetByID(ctx context.Context, tokenID uuid.UUID) (*domain.StaticToken, error) {
query := `
SELECT id, app_id, owner_type, owner_name, owner_owner,
key_hash, type, created_at, updated_at
FROM static_tokens
WHERE id = $1
`
db := r.db.GetDB().(*sql.DB)
row := db.QueryRowContext(ctx, query, tokenID)
token := &domain.StaticToken{}
var ownerType, ownerName, ownerOwner string
err := row.Scan(
&token.ID,
&token.AppID,
&ownerType,
&ownerName,
&ownerOwner,
&token.KeyHash,
&token.Type,
&token.CreatedAt,
&token.UpdatedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("static token with ID '%s' not found", tokenID)
}
return nil, fmt.Errorf("failed to get static token: %w", err)
}
token.Owner = domain.Owner{
Type: domain.OwnerType(ownerType),
Name: ownerName,
Owner: ownerOwner,
}
return token, nil
}
// GetByKeyHash retrieves a static token by its key hash
func (r *StaticTokenRepository) GetByKeyHash(ctx context.Context, keyHash string) (*domain.StaticToken, error) {
query := `
SELECT id, app_id, owner_type, owner_name, owner_owner,
key_hash, type, created_at, updated_at
FROM static_tokens
WHERE key_hash = $1
`
db := r.db.GetDB().(*sql.DB)
row := db.QueryRowContext(ctx, query, keyHash)
token := &domain.StaticToken{}
var ownerType, ownerName, ownerOwner string
err := row.Scan(
&token.ID,
&token.AppID,
&ownerType,
&ownerName,
&ownerOwner,
&token.KeyHash,
&token.Type,
&token.CreatedAt,
&token.UpdatedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("static token with hash not found")
}
return nil, fmt.Errorf("failed to get static token by hash: %w", err)
}
token.Owner = domain.Owner{
Type: domain.OwnerType(ownerType),
Name: ownerName,
Owner: ownerOwner,
}
return token, nil
}
// GetByAppID retrieves all static tokens for an application
func (r *StaticTokenRepository) GetByAppID(ctx context.Context, appID string) ([]*domain.StaticToken, error) {
query := `
SELECT id, app_id, owner_type, owner_name, owner_owner,
key_hash, type, created_at, updated_at
FROM static_tokens
WHERE app_id = $1
ORDER BY created_at DESC
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, appID)
if err != nil {
return nil, fmt.Errorf("failed to query static tokens: %w", err)
}
defer rows.Close()
var tokens []*domain.StaticToken
for rows.Next() {
token := &domain.StaticToken{}
var ownerType, ownerName, ownerOwner string
err := rows.Scan(
&token.ID,
&token.AppID,
&ownerType,
&ownerName,
&ownerOwner,
&token.KeyHash,
&token.Type,
&token.CreatedAt,
&token.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan static token: %w", err)
}
token.Owner = domain.Owner{
Type: domain.OwnerType(ownerType),
Name: ownerName,
Owner: ownerOwner,
}
tokens = append(tokens, token)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating static tokens: %w", err)
}
return tokens, nil
}
// List retrieves static tokens with pagination
func (r *StaticTokenRepository) List(ctx context.Context, limit, offset int) ([]*domain.StaticToken, error) {
query := `
SELECT id, app_id, owner_type, owner_name, owner_owner,
key_hash, type, created_at, updated_at
FROM static_tokens
ORDER BY created_at DESC
LIMIT $1 OFFSET $2
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, limit, offset)
if err != nil {
return nil, fmt.Errorf("failed to query static tokens: %w", err)
}
defer rows.Close()
var tokens []*domain.StaticToken
for rows.Next() {
token := &domain.StaticToken{}
var ownerType, ownerName, ownerOwner string
err := rows.Scan(
&token.ID,
&token.AppID,
&ownerType,
&ownerName,
&ownerOwner,
&token.KeyHash,
&token.Type,
&token.CreatedAt,
&token.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan static token: %w", err)
}
token.Owner = domain.Owner{
Type: domain.OwnerType(ownerType),
Name: ownerName,
Owner: ownerOwner,
}
tokens = append(tokens, token)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating static tokens: %w", err)
}
return tokens, nil
}
// Delete deletes a static token
func (r *StaticTokenRepository) Delete(ctx context.Context, tokenID uuid.UUID) error {
query := `DELETE FROM static_tokens WHERE id = $1`
db := r.db.GetDB().(*sql.DB)
result, err := db.ExecContext(ctx, query, tokenID)
if err != nil {
return fmt.Errorf("failed to delete static token: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("static token with ID '%s' not found", tokenID)
}
return nil
}
// Exists checks if a static token exists
func (r *StaticTokenRepository) Exists(ctx context.Context, tokenID uuid.UUID) (bool, error) {
query := `SELECT 1 FROM static_tokens WHERE id = $1`
db := r.db.GetDB().(*sql.DB)
var exists int
err := db.QueryRowContext(ctx, query, tokenID).Scan(&exists)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("failed to check static token existence: %w", err)
}
return true, nil
}