org
This commit is contained in:
352
kms/internal/repository/interfaces.go
Normal file
352
kms/internal/repository/interfaces.go
Normal file
@ -0,0 +1,352 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/kms/api-key-service/internal/audit"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
)
|
||||
|
||||
// ApplicationRepository defines the interface for application data operations
|
||||
type ApplicationRepository interface {
|
||||
// Create creates a new application
|
||||
Create(ctx context.Context, app *domain.Application) error
|
||||
|
||||
// GetByID retrieves an application by its ID
|
||||
GetByID(ctx context.Context, appID string) (*domain.Application, error)
|
||||
|
||||
// List retrieves applications with pagination
|
||||
List(ctx context.Context, limit, offset int) ([]*domain.Application, error)
|
||||
|
||||
// Update updates an existing application
|
||||
Update(ctx context.Context, appID string, updates *domain.UpdateApplicationRequest) (*domain.Application, error)
|
||||
|
||||
// Delete deletes an application
|
||||
Delete(ctx context.Context, appID string) error
|
||||
|
||||
// Exists checks if an application exists
|
||||
Exists(ctx context.Context, appID string) (bool, error)
|
||||
}
|
||||
|
||||
// StaticTokenRepository defines the interface for static token data operations
|
||||
type StaticTokenRepository interface {
|
||||
// Create creates a new static token
|
||||
Create(ctx context.Context, token *domain.StaticToken) error
|
||||
|
||||
// GetByID retrieves a static token by its ID
|
||||
GetByID(ctx context.Context, tokenID uuid.UUID) (*domain.StaticToken, error)
|
||||
|
||||
// GetByKeyHash retrieves a static token by its key hash
|
||||
GetByKeyHash(ctx context.Context, keyHash string) (*domain.StaticToken, error)
|
||||
|
||||
// GetByAppID retrieves all static tokens for an application
|
||||
GetByAppID(ctx context.Context, appID string) ([]*domain.StaticToken, error)
|
||||
|
||||
// List retrieves static tokens with pagination
|
||||
List(ctx context.Context, limit, offset int) ([]*domain.StaticToken, error)
|
||||
|
||||
// Delete deletes a static token
|
||||
Delete(ctx context.Context, tokenID uuid.UUID) error
|
||||
|
||||
// Exists checks if a static token exists
|
||||
Exists(ctx context.Context, tokenID uuid.UUID) (bool, error)
|
||||
}
|
||||
|
||||
// PermissionRepository defines the interface for permission data operations
|
||||
type PermissionRepository interface {
|
||||
// CreateAvailablePermission creates a new available permission
|
||||
CreateAvailablePermission(ctx context.Context, permission *domain.AvailablePermission) error
|
||||
|
||||
// GetAvailablePermission retrieves an available permission by ID
|
||||
GetAvailablePermission(ctx context.Context, permissionID uuid.UUID) (*domain.AvailablePermission, error)
|
||||
|
||||
// GetAvailablePermissionByScope retrieves an available permission by scope
|
||||
GetAvailablePermissionByScope(ctx context.Context, scope string) (*domain.AvailablePermission, error)
|
||||
|
||||
// ListAvailablePermissions retrieves available permissions with pagination and filtering
|
||||
ListAvailablePermissions(ctx context.Context, category string, includeSystem bool, limit, offset int) ([]*domain.AvailablePermission, error)
|
||||
|
||||
// UpdateAvailablePermission updates an available permission
|
||||
UpdateAvailablePermission(ctx context.Context, permissionID uuid.UUID, permission *domain.AvailablePermission) error
|
||||
|
||||
// DeleteAvailablePermission deletes an available permission
|
||||
DeleteAvailablePermission(ctx context.Context, permissionID uuid.UUID) error
|
||||
|
||||
// ValidatePermissionScopes checks if all given scopes exist and are valid
|
||||
ValidatePermissionScopes(ctx context.Context, scopes []string) ([]string, error) // returns invalid scopes
|
||||
|
||||
// GetPermissionHierarchy returns all parent and child permissions for given scopes
|
||||
GetPermissionHierarchy(ctx context.Context, scopes []string) ([]*domain.AvailablePermission, error)
|
||||
}
|
||||
|
||||
// GrantedPermissionRepository defines the interface for granted permission operations
|
||||
type GrantedPermissionRepository interface {
|
||||
// GrantPermissions grants multiple permissions to a token
|
||||
GrantPermissions(ctx context.Context, grants []*domain.GrantedPermission) error
|
||||
|
||||
// GetGrantedPermissions retrieves all granted permissions for a token
|
||||
GetGrantedPermissions(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID) ([]*domain.GrantedPermission, error)
|
||||
|
||||
// GetGrantedPermissionScopes retrieves only the scopes for a token (more efficient)
|
||||
GetGrantedPermissionScopes(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID) ([]string, error)
|
||||
|
||||
// RevokePermission revokes a specific permission from a token
|
||||
RevokePermission(ctx context.Context, grantID uuid.UUID, revokedBy string) error
|
||||
|
||||
// RevokeAllPermissions revokes all permissions from a token
|
||||
RevokeAllPermissions(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, revokedBy string) error
|
||||
|
||||
// HasPermission checks if a token has a specific permission
|
||||
HasPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scope string) (bool, error)
|
||||
|
||||
// HasAnyPermission checks if a token has any of the specified permissions
|
||||
HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error)
|
||||
}
|
||||
|
||||
// SessionRepository defines the interface for user session data operations
|
||||
type SessionRepository interface {
|
||||
// Create creates a new user session
|
||||
Create(ctx context.Context, session *domain.UserSession) error
|
||||
|
||||
// GetByID retrieves a session by its ID
|
||||
GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error)
|
||||
|
||||
// GetByUserID retrieves all sessions for a user
|
||||
GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error)
|
||||
|
||||
// GetByUserAndApp retrieves sessions for a specific user and application
|
||||
GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error)
|
||||
|
||||
// GetActiveByUserID retrieves all active sessions for a user
|
||||
GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error)
|
||||
|
||||
// List retrieves sessions with filtering and pagination
|
||||
List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error)
|
||||
|
||||
// Update updates an existing session
|
||||
Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error
|
||||
|
||||
// UpdateActivity updates the last activity timestamp for a session
|
||||
UpdateActivity(ctx context.Context, sessionID uuid.UUID) error
|
||||
|
||||
// Revoke revokes a session
|
||||
Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error
|
||||
|
||||
// RevokeAllByUser revokes all sessions for a user
|
||||
RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error
|
||||
|
||||
// RevokeAllByUserAndApp revokes all sessions for a user and application
|
||||
RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error
|
||||
|
||||
// ExpireOldSessions marks expired sessions as expired
|
||||
ExpireOldSessions(ctx context.Context) (int, error)
|
||||
|
||||
// DeleteExpiredSessions removes expired sessions older than the specified duration
|
||||
DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error)
|
||||
|
||||
// Exists checks if a session exists
|
||||
Exists(ctx context.Context, sessionID uuid.UUID) (bool, error)
|
||||
|
||||
// GetSessionCount returns the total number of sessions for a user
|
||||
GetSessionCount(ctx context.Context, userID string) (int, error)
|
||||
|
||||
// GetActiveSessionCount returns the number of active sessions for a user
|
||||
GetActiveSessionCount(ctx context.Context, userID string) (int, error)
|
||||
}
|
||||
|
||||
// DatabaseProvider defines the interface for database operations
|
||||
type DatabaseProvider interface {
|
||||
// GetDB returns the underlying database connection
|
||||
GetDB() interface{}
|
||||
|
||||
// Ping checks the database connection
|
||||
Ping(ctx context.Context) error
|
||||
|
||||
// Close closes all database connections
|
||||
Close() error
|
||||
|
||||
// BeginTx starts a database transaction
|
||||
BeginTx(ctx context.Context) (TransactionProvider, error)
|
||||
}
|
||||
|
||||
// TransactionProvider defines the interface for database transaction operations
|
||||
type TransactionProvider interface {
|
||||
// Commit commits the transaction
|
||||
Commit() error
|
||||
|
||||
// Rollback rolls back the transaction
|
||||
Rollback() error
|
||||
|
||||
// GetTx returns the underlying transaction
|
||||
GetTx() interface{}
|
||||
}
|
||||
|
||||
// CacheProvider defines the interface for caching operations
|
||||
type CacheProvider interface {
|
||||
// Get retrieves a value from cache
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
|
||||
// Set stores a value in cache with expiration
|
||||
Set(ctx context.Context, key string, value []byte, expiration time.Duration) error
|
||||
|
||||
// Delete removes a value from cache
|
||||
Delete(ctx context.Context, key string) error
|
||||
|
||||
// Exists checks if a key exists in cache
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Flush clears all cache entries
|
||||
Flush(ctx context.Context) error
|
||||
|
||||
// Close closes the cache connection
|
||||
Close() error
|
||||
}
|
||||
|
||||
// TokenProvider defines the interface for token operations
|
||||
type TokenProvider interface {
|
||||
// GenerateUserToken generates a JWT token for user authentication
|
||||
GenerateUserToken(ctx context.Context, userToken *domain.UserToken, hmacKey string) (string, error)
|
||||
|
||||
// ValidateUserToken validates and parses a JWT token
|
||||
ValidateUserToken(ctx context.Context, token string, hmacKey string) (*domain.UserToken, error)
|
||||
|
||||
// GenerateStaticToken generates a static API key
|
||||
GenerateStaticToken(ctx context.Context) (string, error)
|
||||
|
||||
// HashStaticToken creates a secure hash of a static token
|
||||
HashStaticToken(ctx context.Context, token string) (string, error)
|
||||
|
||||
// ValidateStaticToken validates a static token against its hash
|
||||
ValidateStaticToken(ctx context.Context, token, hash string) (bool, error)
|
||||
|
||||
// RenewUserToken renews a user token while preserving max validity
|
||||
RenewUserToken(ctx context.Context, currentToken *domain.UserToken, renewalDuration time.Duration, hmacKey string) (string, error)
|
||||
}
|
||||
|
||||
// HashProvider defines the interface for cryptographic hashing operations
|
||||
type HashProvider interface {
|
||||
// Hash creates a secure hash of the input
|
||||
Hash(ctx context.Context, input string) (string, error)
|
||||
|
||||
// Compare compares an input against a hash
|
||||
Compare(ctx context.Context, input, hash string) (bool, error)
|
||||
|
||||
// GenerateKey generates a secure random key
|
||||
GenerateKey(ctx context.Context, length int) (string, error)
|
||||
}
|
||||
|
||||
// LoggerProvider defines the interface for logging operations
|
||||
type LoggerProvider interface {
|
||||
// Info logs an info level message
|
||||
Info(ctx context.Context, msg string, fields ...interface{})
|
||||
|
||||
// Warn logs a warning level message
|
||||
Warn(ctx context.Context, msg string, fields ...interface{})
|
||||
|
||||
// Error logs an error level message
|
||||
Error(ctx context.Context, msg string, err error, fields ...interface{})
|
||||
|
||||
// Debug logs a debug level message
|
||||
Debug(ctx context.Context, msg string, fields ...interface{})
|
||||
|
||||
// With returns a logger with additional fields
|
||||
With(fields ...interface{}) LoggerProvider
|
||||
}
|
||||
|
||||
// ConfigProvider defines the interface for configuration operations
|
||||
type ConfigProvider interface {
|
||||
// GetString retrieves a string configuration value
|
||||
GetString(key string) string
|
||||
|
||||
// GetInt retrieves an integer configuration value
|
||||
GetInt(key string) int
|
||||
|
||||
// GetBool retrieves a boolean configuration value
|
||||
GetBool(key string) bool
|
||||
|
||||
// GetDuration retrieves a duration configuration value
|
||||
GetDuration(key string) time.Duration
|
||||
|
||||
// GetStringSlice retrieves a string slice configuration value
|
||||
GetStringSlice(key string) []string
|
||||
|
||||
// IsSet checks if a configuration key is set
|
||||
IsSet(key string) bool
|
||||
|
||||
// Validate validates all required configuration values
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// AuthenticationProvider defines the interface for user authentication
|
||||
type AuthenticationProvider interface {
|
||||
// GetUserID extracts the user ID from the request context/headers
|
||||
GetUserID(ctx context.Context) (string, error)
|
||||
|
||||
// ValidateUser validates if the user is authentic
|
||||
ValidateUser(ctx context.Context, userID string) error
|
||||
|
||||
// GetUserClaims retrieves additional user information/claims
|
||||
GetUserClaims(ctx context.Context, userID string) (map[string]string, error)
|
||||
|
||||
// Name returns the provider name for identification
|
||||
Name() string
|
||||
}
|
||||
|
||||
// RateLimitProvider defines the interface for rate limiting operations
|
||||
type RateLimitProvider interface {
|
||||
// Allow checks if a request should be allowed for the given identifier
|
||||
Allow(ctx context.Context, identifier string) (bool, error)
|
||||
|
||||
// Remaining returns the number of remaining requests for the identifier
|
||||
Remaining(ctx context.Context, identifier string) (int, error)
|
||||
|
||||
// Reset returns when the rate limit will reset for the identifier
|
||||
Reset(ctx context.Context, identifier string) (time.Time, error)
|
||||
}
|
||||
|
||||
// MetricsProvider defines the interface for metrics collection
|
||||
type MetricsProvider interface {
|
||||
// IncrementCounter increments a counter metric
|
||||
IncrementCounter(ctx context.Context, name string, labels map[string]string)
|
||||
|
||||
// RecordHistogram records a value in a histogram
|
||||
RecordHistogram(ctx context.Context, name string, value float64, labels map[string]string)
|
||||
|
||||
// SetGauge sets a gauge metric value
|
||||
SetGauge(ctx context.Context, name string, value float64, labels map[string]string)
|
||||
|
||||
// RecordDuration records the duration of an operation
|
||||
RecordDuration(ctx context.Context, name string, duration time.Duration, labels map[string]string)
|
||||
}
|
||||
|
||||
// AuditRepository defines the interface for audit event storage operations
|
||||
type AuditRepository interface {
|
||||
// Create stores a new audit event
|
||||
Create(ctx context.Context, event *audit.AuditEvent) error
|
||||
|
||||
// Query retrieves audit events based on filter criteria
|
||||
Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error)
|
||||
|
||||
// GetStats returns aggregated statistics for audit events
|
||||
GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error)
|
||||
|
||||
// DeleteOldEvents removes audit events older than the specified time
|
||||
DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error)
|
||||
|
||||
// GetByID retrieves a specific audit event by its ID
|
||||
GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error)
|
||||
|
||||
// GetByRequestID retrieves all audit events for a specific request
|
||||
GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error)
|
||||
|
||||
// GetBySession retrieves all audit events for a specific session
|
||||
GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error)
|
||||
|
||||
// GetByActor retrieves audit events for a specific actor
|
||||
GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error)
|
||||
|
||||
// GetByResource retrieves audit events for a specific resource
|
||||
GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error)
|
||||
}
|
||||
387
kms/internal/repository/postgres/application_repository.go
Normal file
387
kms/internal/repository/postgres/application_repository.go
Normal file
@ -0,0 +1,387 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
)
|
||||
|
||||
// ApplicationRepository implements the ApplicationRepository interface for PostgreSQL
|
||||
type ApplicationRepository struct {
|
||||
db repository.DatabaseProvider
|
||||
}
|
||||
|
||||
// NewApplicationRepository creates a new PostgreSQL application repository
|
||||
func NewApplicationRepository(db repository.DatabaseProvider) repository.ApplicationRepository {
|
||||
return &ApplicationRepository{db: db}
|
||||
}
|
||||
|
||||
// Create creates a new application
|
||||
func (r *ApplicationRepository) Create(ctx context.Context, app *domain.Application) error {
|
||||
query := `
|
||||
INSERT INTO applications (
|
||||
app_id, app_link, type, callback_url, hmac_key, token_prefix,
|
||||
token_renewal_duration, max_token_duration,
|
||||
owner_type, owner_name, owner_owner,
|
||||
created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
now := time.Now()
|
||||
|
||||
// Convert application types to string array
|
||||
typeStrings := make([]string, len(app.Type))
|
||||
for i, t := range app.Type {
|
||||
typeStrings[i] = string(t)
|
||||
}
|
||||
|
||||
_, err := db.ExecContext(ctx, query,
|
||||
app.AppID,
|
||||
app.AppLink,
|
||||
pq.Array(typeStrings),
|
||||
app.CallbackURL,
|
||||
app.HMACKey,
|
||||
app.TokenPrefix,
|
||||
app.TokenRenewalDuration.Duration.Nanoseconds(),
|
||||
app.MaxTokenDuration.Duration.Nanoseconds(),
|
||||
string(app.Owner.Type),
|
||||
app.Owner.Name,
|
||||
app.Owner.Owner,
|
||||
now,
|
||||
now,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return fmt.Errorf("application with ID '%s' already exists", app.AppID)
|
||||
}
|
||||
return fmt.Errorf("failed to create application: %w", err)
|
||||
}
|
||||
|
||||
app.CreatedAt = now
|
||||
app.UpdatedAt = now
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID retrieves an application by its ID
|
||||
func (r *ApplicationRepository) GetByID(ctx context.Context, appID string) (*domain.Application, error) {
|
||||
query := `
|
||||
SELECT app_id, app_link, type, callback_url, hmac_key, token_prefix,
|
||||
token_renewal_duration, max_token_duration,
|
||||
owner_type, owner_name, owner_owner,
|
||||
created_at, updated_at
|
||||
FROM applications
|
||||
WHERE app_id = $1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
row := db.QueryRowContext(ctx, query, appID)
|
||||
|
||||
app := &domain.Application{}
|
||||
var typeStrings pq.StringArray
|
||||
var tokenRenewalNanos, maxTokenNanos int64
|
||||
var ownerType string
|
||||
|
||||
err := row.Scan(
|
||||
&app.AppID,
|
||||
&app.AppLink,
|
||||
&typeStrings,
|
||||
&app.CallbackURL,
|
||||
&app.HMACKey,
|
||||
&app.TokenPrefix,
|
||||
&tokenRenewalNanos,
|
||||
&maxTokenNanos,
|
||||
&ownerType,
|
||||
&app.Owner.Name,
|
||||
&app.Owner.Owner,
|
||||
&app.CreatedAt,
|
||||
&app.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("application with ID '%s' not found", appID)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get application: %w", err)
|
||||
}
|
||||
|
||||
// Convert string array to application types
|
||||
app.Type = make([]domain.ApplicationType, len(typeStrings))
|
||||
for i, t := range typeStrings {
|
||||
app.Type[i] = domain.ApplicationType(t)
|
||||
}
|
||||
|
||||
// Convert nanoseconds to duration
|
||||
app.TokenRenewalDuration = domain.Duration{Duration: time.Duration(tokenRenewalNanos)}
|
||||
app.MaxTokenDuration = domain.Duration{Duration: time.Duration(maxTokenNanos)}
|
||||
|
||||
// Convert owner type
|
||||
app.Owner.Type = domain.OwnerType(ownerType)
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// List retrieves applications with pagination
|
||||
func (r *ApplicationRepository) List(ctx context.Context, limit, offset int) ([]*domain.Application, error) {
|
||||
query := `
|
||||
SELECT app_id, app_link, type, callback_url, hmac_key, token_prefix,
|
||||
token_renewal_duration, max_token_duration,
|
||||
owner_type, owner_name, owner_owner,
|
||||
created_at, updated_at
|
||||
FROM applications
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, limit, offset)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list applications: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var applications []*domain.Application
|
||||
|
||||
for rows.Next() {
|
||||
app := &domain.Application{}
|
||||
var typeStrings pq.StringArray
|
||||
var tokenRenewalNanos, maxTokenNanos int64
|
||||
var ownerType string
|
||||
|
||||
err := rows.Scan(
|
||||
&app.AppID,
|
||||
&app.AppLink,
|
||||
&typeStrings,
|
||||
&app.CallbackURL,
|
||||
&app.HMACKey,
|
||||
&app.TokenPrefix,
|
||||
&tokenRenewalNanos,
|
||||
&maxTokenNanos,
|
||||
&ownerType,
|
||||
&app.Owner.Name,
|
||||
&app.Owner.Owner,
|
||||
&app.CreatedAt,
|
||||
&app.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan application: %w", err)
|
||||
}
|
||||
|
||||
// Convert string array to application types
|
||||
app.Type = make([]domain.ApplicationType, len(typeStrings))
|
||||
for i, t := range typeStrings {
|
||||
app.Type[i] = domain.ApplicationType(t)
|
||||
}
|
||||
|
||||
// Convert nanoseconds to duration
|
||||
app.TokenRenewalDuration = domain.Duration{Duration: time.Duration(tokenRenewalNanos)}
|
||||
app.MaxTokenDuration = domain.Duration{Duration: time.Duration(maxTokenNanos)}
|
||||
|
||||
// Convert owner type
|
||||
app.Owner.Type = domain.OwnerType(ownerType)
|
||||
|
||||
applications = append(applications, app)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to iterate applications: %w", err)
|
||||
}
|
||||
|
||||
return applications, nil
|
||||
}
|
||||
|
||||
// Update updates an existing application
|
||||
func (r *ApplicationRepository) Update(ctx context.Context, appID string, updates *domain.UpdateApplicationRequest) (*domain.Application, error) {
|
||||
// Build secure dynamic update query using a whitelist approach
|
||||
var setParts []string
|
||||
var args []interface{}
|
||||
argIndex := 1
|
||||
|
||||
// Whitelist of allowed fields to prevent SQL injection
|
||||
allowedFields := map[string]string{
|
||||
"app_link": "app_link",
|
||||
"type": "type",
|
||||
"callback_url": "callback_url",
|
||||
"hmac_key": "hmac_key",
|
||||
"token_prefix": "token_prefix",
|
||||
"token_renewal_duration": "token_renewal_duration",
|
||||
"max_token_duration": "max_token_duration",
|
||||
"owner_type": "owner_type",
|
||||
"owner_name": "owner_name",
|
||||
"owner_owner": "owner_owner",
|
||||
}
|
||||
|
||||
if updates.AppLink != nil {
|
||||
if field, ok := allowedFields["app_link"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.AppLink)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.Type != nil {
|
||||
if field, ok := allowedFields["type"]; ok {
|
||||
typeStrings := make([]string, len(*updates.Type))
|
||||
for i, t := range *updates.Type {
|
||||
typeStrings[i] = string(t)
|
||||
}
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, pq.Array(typeStrings))
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.CallbackURL != nil {
|
||||
if field, ok := allowedFields["callback_url"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.CallbackURL)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.HMACKey != nil {
|
||||
if field, ok := allowedFields["hmac_key"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.HMACKey)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.TokenPrefix != nil {
|
||||
if field, ok := allowedFields["token_prefix"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.TokenPrefix)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.TokenRenewalDuration != nil {
|
||||
if field, ok := allowedFields["token_renewal_duration"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.TokenRenewalDuration.Duration.Nanoseconds())
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.MaxTokenDuration != nil {
|
||||
if field, ok := allowedFields["max_token_duration"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.MaxTokenDuration.Duration.Nanoseconds())
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.Owner != nil {
|
||||
if field, ok := allowedFields["owner_type"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, string(updates.Owner.Type))
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if field, ok := allowedFields["owner_name"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.Owner.Name)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if field, ok := allowedFields["owner_owner"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.Owner.Owner)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if len(setParts) == 0 {
|
||||
return r.GetByID(ctx, appID) // No updates, return current state
|
||||
}
|
||||
|
||||
// Always update the updated_at field - using literal field name for security
|
||||
setParts = append(setParts, fmt.Sprintf("updated_at = $%d", argIndex))
|
||||
args = append(args, time.Now())
|
||||
argIndex++
|
||||
|
||||
// Add WHERE clause parameter
|
||||
args = append(args, appID)
|
||||
|
||||
// Build the final query with properly parameterized placeholders
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE applications
|
||||
SET %s
|
||||
WHERE app_id = $%d
|
||||
`, strings.Join(setParts, ", "), argIndex)
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
result, err := db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update application: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return nil, fmt.Errorf("application with ID '%s' not found", appID)
|
||||
}
|
||||
|
||||
// Return updated application
|
||||
return r.GetByID(ctx, appID)
|
||||
}
|
||||
|
||||
// Delete deletes an application
|
||||
func (r *ApplicationRepository) Delete(ctx context.Context, appID string) error {
|
||||
query := `DELETE FROM applications WHERE app_id = $1`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
result, err := db.ExecContext(ctx, query, appID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete application: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("application with ID '%s' not found", appID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if an application exists
|
||||
func (r *ApplicationRepository) Exists(ctx context.Context, appID string) (bool, error) {
|
||||
query := `SELECT 1 FROM applications WHERE app_id = $1`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
var exists int
|
||||
err := db.QueryRowContext(ctx, query, appID).Scan(&exists)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to check application existence: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// isUniqueViolation checks if the error is a unique constraint violation
|
||||
func isUniqueViolation(err error) bool {
|
||||
if pqErr, ok := err.(*pq.Error); ok {
|
||||
return pqErr.Code == "23505" // unique_violation
|
||||
}
|
||||
return false
|
||||
}
|
||||
742
kms/internal/repository/postgres/audit_repository.go
Normal file
742
kms/internal/repository/postgres/audit_repository.go
Normal file
@ -0,0 +1,742 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/kms/api-key-service/internal/audit"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
)
|
||||
|
||||
// AuditRepository implements the AuditRepository interface for PostgreSQL
|
||||
type AuditRepository struct {
|
||||
db repository.DatabaseProvider
|
||||
}
|
||||
|
||||
// NewAuditRepository creates a new PostgreSQL audit repository
|
||||
func NewAuditRepository(db repository.DatabaseProvider) repository.AuditRepository {
|
||||
return &AuditRepository{db: db}
|
||||
}
|
||||
|
||||
// Create stores a new audit event
|
||||
func (r *AuditRepository) Create(ctx context.Context, event *audit.AuditEvent) error {
|
||||
query := `
|
||||
INSERT INTO audit_events (
|
||||
id, type, severity, status, timestamp,
|
||||
actor_id, actor_type, actor_ip, user_agent, tenant_id,
|
||||
resource_id, resource_type, action, description, details,
|
||||
request_id, session_id, tags, metadata
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
|
||||
$11, $12, $13, $14, $15, $16, $17, $18, $19
|
||||
)
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
|
||||
// Ensure event has an ID and timestamp
|
||||
if event.ID == uuid.Nil {
|
||||
event.ID = uuid.New()
|
||||
}
|
||||
if event.Timestamp.IsZero() {
|
||||
event.Timestamp = time.Now().UTC()
|
||||
}
|
||||
|
||||
// Convert details to JSON
|
||||
var detailsJSON []byte
|
||||
var err error
|
||||
if event.Details != nil {
|
||||
detailsJSON, err = json.Marshal(event.Details)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal event details: %w", err)
|
||||
}
|
||||
} else {
|
||||
detailsJSON = []byte("{}")
|
||||
}
|
||||
|
||||
// Convert metadata to JSON
|
||||
var metadataJSON []byte
|
||||
if event.Metadata != nil {
|
||||
metadataJSON, err = json.Marshal(event.Metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal event metadata: %w", err)
|
||||
}
|
||||
} else {
|
||||
metadataJSON = []byte("{}")
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
var actorID, actorType, actorIP, userAgent *string
|
||||
var tenantID *uuid.UUID
|
||||
var resourceID, resourceType *string
|
||||
var requestID, sessionID *string
|
||||
|
||||
if event.ActorID != "" {
|
||||
actorID = &event.ActorID
|
||||
}
|
||||
if event.ActorType != "" {
|
||||
actorType = &event.ActorType
|
||||
}
|
||||
if event.ActorIP != "" {
|
||||
actorIP = &event.ActorIP
|
||||
}
|
||||
if event.UserAgent != "" {
|
||||
userAgent = &event.UserAgent
|
||||
}
|
||||
if event.TenantID != nil {
|
||||
tenantID = event.TenantID
|
||||
}
|
||||
if event.ResourceID != "" {
|
||||
resourceID = &event.ResourceID
|
||||
}
|
||||
if event.ResourceType != "" {
|
||||
resourceType = &event.ResourceType
|
||||
}
|
||||
if event.RequestID != "" {
|
||||
requestID = &event.RequestID
|
||||
}
|
||||
if event.SessionID != "" {
|
||||
sessionID = &event.SessionID
|
||||
}
|
||||
|
||||
_, err = db.ExecContext(ctx, query,
|
||||
event.ID,
|
||||
string(event.Type),
|
||||
string(event.Severity),
|
||||
string(event.Status),
|
||||
event.Timestamp,
|
||||
actorID,
|
||||
actorType,
|
||||
actorIP,
|
||||
userAgent,
|
||||
tenantID,
|
||||
resourceID,
|
||||
resourceType,
|
||||
event.Action,
|
||||
event.Description,
|
||||
string(detailsJSON),
|
||||
requestID,
|
||||
sessionID,
|
||||
pq.Array(event.Tags),
|
||||
string(metadataJSON),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create audit event: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query retrieves audit events based on filter criteria
|
||||
func (r *AuditRepository) Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) {
|
||||
// Build dynamic query with filters
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
argIndex := 1
|
||||
|
||||
baseQuery := `
|
||||
SELECT id, type, severity, status, timestamp,
|
||||
actor_id, actor_type, actor_ip, user_agent, tenant_id,
|
||||
resource_id, resource_type, action, description, details,
|
||||
request_id, session_id, tags, metadata
|
||||
FROM audit_events
|
||||
`
|
||||
|
||||
// Add filters
|
||||
if len(filter.EventTypes) > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("type = ANY($%d)", argIndex))
|
||||
typeStrings := make([]string, len(filter.EventTypes))
|
||||
for i, t := range filter.EventTypes {
|
||||
typeStrings[i] = string(t)
|
||||
}
|
||||
args = append(args, pq.Array(typeStrings))
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if len(filter.Severities) > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("severity = ANY($%d)", argIndex))
|
||||
severityStrings := make([]string, len(filter.Severities))
|
||||
for i, s := range filter.Severities {
|
||||
severityStrings[i] = string(s)
|
||||
}
|
||||
args = append(args, pq.Array(severityStrings))
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if len(filter.Statuses) > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("status = ANY($%d)", argIndex))
|
||||
statusStrings := make([]string, len(filter.Statuses))
|
||||
for i, s := range filter.Statuses {
|
||||
statusStrings[i] = string(s)
|
||||
}
|
||||
args = append(args, pq.Array(statusStrings))
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.ActorID != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("actor_id = $%d", argIndex))
|
||||
args = append(args, filter.ActorID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.ActorType != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("actor_type = $%d", argIndex))
|
||||
args = append(args, filter.ActorType)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.TenantID != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("tenant_id = $%d", argIndex))
|
||||
args = append(args, *filter.TenantID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.ResourceID != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("resource_id = $%d", argIndex))
|
||||
args = append(args, filter.ResourceID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.ResourceType != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("resource_type = $%d", argIndex))
|
||||
args = append(args, filter.ResourceType)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.StartTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex))
|
||||
args = append(args, *filter.StartTime)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.EndTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex))
|
||||
args = append(args, *filter.EndTime)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if len(filter.Tags) > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("tags && $%d", argIndex))
|
||||
args = append(args, pq.Array(filter.Tags))
|
||||
argIndex++
|
||||
}
|
||||
|
||||
// Build WHERE clause
|
||||
if len(conditions) > 0 {
|
||||
baseQuery += " WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
// Add ORDER BY
|
||||
orderBy := "timestamp"
|
||||
if filter.OrderBy != "" {
|
||||
switch filter.OrderBy {
|
||||
case "timestamp", "type", "severity", "status":
|
||||
orderBy = filter.OrderBy
|
||||
}
|
||||
}
|
||||
|
||||
direction := "DESC"
|
||||
if !filter.OrderDesc {
|
||||
direction = "ASC"
|
||||
}
|
||||
|
||||
baseQuery += fmt.Sprintf(" ORDER BY %s %s", orderBy, direction)
|
||||
|
||||
// Add pagination
|
||||
if filter.Limit <= 0 {
|
||||
filter.Limit = 100
|
||||
}
|
||||
if filter.Limit > 1000 {
|
||||
filter.Limit = 1000
|
||||
}
|
||||
|
||||
baseQuery += fmt.Sprintf(" LIMIT $%d", argIndex)
|
||||
args = append(args, filter.Limit)
|
||||
argIndex++
|
||||
|
||||
if filter.Offset > 0 {
|
||||
baseQuery += fmt.Sprintf(" OFFSET $%d", argIndex)
|
||||
args = append(args, filter.Offset)
|
||||
}
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, baseQuery, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query audit events: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var events []*audit.AuditEvent
|
||||
for rows.Next() {
|
||||
event, err := r.scanAuditEvent(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan audit event: %w", err)
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating audit events: %w", err)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// GetStats returns aggregated statistics for audit events
|
||||
func (r *AuditRepository) GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) {
|
||||
stats := &audit.AuditStats{
|
||||
ByType: make(map[audit.EventType]int),
|
||||
BySeverity: make(map[audit.EventSeverity]int),
|
||||
ByStatus: make(map[audit.EventStatus]int),
|
||||
}
|
||||
|
||||
// Build base conditions
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
argIndex := 1
|
||||
|
||||
if len(filter.EventTypes) > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("type = ANY($%d)", argIndex))
|
||||
typeStrings := make([]string, len(filter.EventTypes))
|
||||
for i, t := range filter.EventTypes {
|
||||
typeStrings[i] = string(t)
|
||||
}
|
||||
args = append(args, pq.Array(typeStrings))
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.TenantID != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("tenant_id = $%d", argIndex))
|
||||
args = append(args, *filter.TenantID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.StartTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex))
|
||||
args = append(args, *filter.StartTime)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.EndTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex))
|
||||
args = append(args, *filter.EndTime)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = "WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
|
||||
// Get total count
|
||||
totalQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause)
|
||||
err := db.QueryRowContext(ctx, totalQuery, args...).Scan(&stats.TotalEvents)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get total event count: %w", err)
|
||||
}
|
||||
|
||||
// Get stats by type
|
||||
typeQuery := fmt.Sprintf(`
|
||||
SELECT type, COUNT(*)
|
||||
FROM audit_events %s
|
||||
GROUP BY type
|
||||
ORDER BY COUNT(*) DESC
|
||||
`, whereClause)
|
||||
|
||||
rows, err := db.QueryContext(ctx, typeQuery, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get type stats: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var eventType string
|
||||
var count int
|
||||
if err := rows.Scan(&eventType, &count); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan type stats: %w", err)
|
||||
}
|
||||
stats.ByType[audit.EventType(eventType)] = count
|
||||
}
|
||||
|
||||
// Get stats by severity
|
||||
severityQuery := fmt.Sprintf(`
|
||||
SELECT severity, COUNT(*)
|
||||
FROM audit_events %s
|
||||
GROUP BY severity
|
||||
ORDER BY COUNT(*) DESC
|
||||
`, whereClause)
|
||||
|
||||
rows, err = db.QueryContext(ctx, severityQuery, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get severity stats: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var severity string
|
||||
var count int
|
||||
if err := rows.Scan(&severity, &count); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan severity stats: %w", err)
|
||||
}
|
||||
stats.BySeverity[audit.EventSeverity(severity)] = count
|
||||
}
|
||||
|
||||
// Get stats by status
|
||||
statusQuery := fmt.Sprintf(`
|
||||
SELECT status, COUNT(*)
|
||||
FROM audit_events %s
|
||||
GROUP BY status
|
||||
ORDER BY COUNT(*) DESC
|
||||
`, whereClause)
|
||||
|
||||
rows, err = db.QueryContext(ctx, statusQuery, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get status stats: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var status string
|
||||
var count int
|
||||
if err := rows.Scan(&status, &count); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan status stats: %w", err)
|
||||
}
|
||||
stats.ByStatus[audit.EventStatus(status)] = count
|
||||
}
|
||||
|
||||
// Get time-based stats if requested
|
||||
if filter.GroupBy != "" {
|
||||
stats.ByTime = make(map[string]int)
|
||||
|
||||
var timeFormat string
|
||||
switch filter.GroupBy {
|
||||
case "hour":
|
||||
timeFormat = "YYYY-MM-DD HH24:00"
|
||||
case "day":
|
||||
timeFormat = "YYYY-MM-DD"
|
||||
default:
|
||||
timeFormat = "YYYY-MM-DD"
|
||||
}
|
||||
|
||||
timeQuery := fmt.Sprintf(`
|
||||
SELECT TO_CHAR(timestamp, '%s') as time_group, COUNT(*)
|
||||
FROM audit_events %s
|
||||
GROUP BY time_group
|
||||
ORDER BY time_group DESC
|
||||
`, timeFormat, whereClause)
|
||||
|
||||
rows, err = db.QueryContext(ctx, timeQuery, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get time stats: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var timeGroup string
|
||||
var count int
|
||||
if err := rows.Scan(&timeGroup, &count); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan time stats: %w", err)
|
||||
}
|
||||
stats.ByTime[timeGroup] = count
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// DeleteOldEvents removes audit events older than the specified time
|
||||
func (r *AuditRepository) DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) {
|
||||
query := `DELETE FROM audit_events WHERE timestamp < $1`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
result, err := db.ExecContext(ctx, query, olderThan)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to delete old audit events: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a specific audit event by its ID
|
||||
func (r *AuditRepository) GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error) {
|
||||
query := `
|
||||
SELECT id, type, severity, status, timestamp,
|
||||
actor_id, actor_type, actor_ip, user_agent, tenant_id,
|
||||
resource_id, resource_type, action, description, details,
|
||||
request_id, session_id, tags, metadata
|
||||
FROM audit_events
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
row := db.QueryRowContext(ctx, query, eventID)
|
||||
|
||||
event, err := r.scanAuditEvent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("audit event with ID '%s' not found", eventID)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get audit event: %w", err)
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
|
||||
// GetByRequestID retrieves all audit events for a specific request
|
||||
func (r *AuditRepository) GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error) {
|
||||
query := `
|
||||
SELECT id, type, severity, status, timestamp,
|
||||
actor_id, actor_type, actor_ip, user_agent, tenant_id,
|
||||
resource_id, resource_type, action, description, details,
|
||||
request_id, session_id, tags, metadata
|
||||
FROM audit_events
|
||||
WHERE request_id = $1
|
||||
ORDER BY timestamp ASC
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, requestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query audit events by request ID: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var events []*audit.AuditEvent
|
||||
for rows.Next() {
|
||||
event, err := r.scanAuditEvent(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan audit event: %w", err)
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// GetBySession retrieves all audit events for a specific session
|
||||
func (r *AuditRepository) GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error) {
|
||||
query := `
|
||||
SELECT id, type, severity, status, timestamp,
|
||||
actor_id, actor_type, actor_ip, user_agent, tenant_id,
|
||||
resource_id, resource_type, action, description, details,
|
||||
request_id, session_id, tags, metadata
|
||||
FROM audit_events
|
||||
WHERE session_id = $1
|
||||
ORDER BY timestamp ASC
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query audit events by session ID: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var events []*audit.AuditEvent
|
||||
for rows.Next() {
|
||||
event, err := r.scanAuditEvent(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan audit event: %w", err)
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// GetByActor retrieves audit events for a specific actor
|
||||
func (r *AuditRepository) GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
if limit > 1000 {
|
||||
limit = 1000
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, type, severity, status, timestamp,
|
||||
actor_id, actor_type, actor_ip, user_agent, tenant_id,
|
||||
resource_id, resource_type, action, description, details,
|
||||
request_id, session_id, tags, metadata
|
||||
FROM audit_events
|
||||
WHERE actor_id = $1
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, actorID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query audit events by actor: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var events []*audit.AuditEvent
|
||||
for rows.Next() {
|
||||
event, err := r.scanAuditEvent(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan audit event: %w", err)
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// GetByResource retrieves audit events for a specific resource
|
||||
func (r *AuditRepository) GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
if limit > 1000 {
|
||||
limit = 1000
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, type, severity, status, timestamp,
|
||||
actor_id, actor_type, actor_ip, user_agent, tenant_id,
|
||||
resource_id, resource_type, action, description, details,
|
||||
request_id, session_id, tags, metadata
|
||||
FROM audit_events
|
||||
WHERE resource_type = $1 AND resource_id = $2
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT $3 OFFSET $4
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, resourceType, resourceID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query audit events by resource: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var events []*audit.AuditEvent
|
||||
for rows.Next() {
|
||||
event, err := r.scanAuditEvent(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan audit event: %w", err)
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// scanAuditEvent scans a database row into an AuditEvent struct
|
||||
func (r *AuditRepository) scanAuditEvent(row interface{}) (*audit.AuditEvent, error) {
|
||||
event := &audit.AuditEvent{}
|
||||
|
||||
var typeStr, severityStr, statusStr string
|
||||
var actorID, actorType, actorIP, userAgent sql.NullString
|
||||
var tenantID *uuid.UUID
|
||||
var resourceID, resourceType sql.NullString
|
||||
var detailsJSON, metadataJSON string
|
||||
var requestID, sessionID sql.NullString
|
||||
var tags pq.StringArray
|
||||
|
||||
var scanner interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}
|
||||
|
||||
switch v := row.(type) {
|
||||
case *sql.Row:
|
||||
scanner = v
|
||||
case *sql.Rows:
|
||||
scanner = v
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid row type")
|
||||
}
|
||||
|
||||
err := scanner.Scan(
|
||||
&event.ID,
|
||||
&typeStr,
|
||||
&severityStr,
|
||||
&statusStr,
|
||||
&event.Timestamp,
|
||||
&actorID,
|
||||
&actorType,
|
||||
&actorIP,
|
||||
&userAgent,
|
||||
&tenantID,
|
||||
&resourceID,
|
||||
&resourceType,
|
||||
&event.Action,
|
||||
&event.Description,
|
||||
&detailsJSON,
|
||||
&requestID,
|
||||
&sessionID,
|
||||
&tags,
|
||||
&metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert string enums to types
|
||||
event.Type = audit.EventType(typeStr)
|
||||
event.Severity = audit.EventSeverity(severityStr)
|
||||
event.Status = audit.EventStatus(statusStr)
|
||||
|
||||
// Handle nullable fields
|
||||
if actorID.Valid {
|
||||
event.ActorID = actorID.String
|
||||
}
|
||||
if actorType.Valid {
|
||||
event.ActorType = actorType.String
|
||||
}
|
||||
if actorIP.Valid {
|
||||
event.ActorIP = actorIP.String
|
||||
}
|
||||
if userAgent.Valid {
|
||||
event.UserAgent = userAgent.String
|
||||
}
|
||||
if tenantID != nil {
|
||||
event.TenantID = tenantID
|
||||
}
|
||||
if resourceID.Valid {
|
||||
event.ResourceID = resourceID.String
|
||||
}
|
||||
if resourceType.Valid {
|
||||
event.ResourceType = resourceType.String
|
||||
}
|
||||
if requestID.Valid {
|
||||
event.RequestID = requestID.String
|
||||
}
|
||||
if sessionID.Valid {
|
||||
event.SessionID = sessionID.String
|
||||
}
|
||||
|
||||
// Convert tags
|
||||
event.Tags = []string(tags)
|
||||
|
||||
// Parse JSON fields
|
||||
if detailsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(detailsJSON), &event.Details); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal details JSON: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if metadataJSON != "" {
|
||||
if err := json.Unmarshal([]byte(metadataJSON), &event.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal metadata JSON: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
693
kms/internal/repository/postgres/permission_repository.go
Normal file
693
kms/internal/repository/postgres/permission_repository.go
Normal file
@ -0,0 +1,693 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// PermissionRepository implements the PermissionRepository interface for PostgreSQL
|
||||
type PermissionRepository struct {
|
||||
db repository.DatabaseProvider
|
||||
}
|
||||
|
||||
// NewPermissionRepository creates a new PostgreSQL permission repository
|
||||
func NewPermissionRepository(db repository.DatabaseProvider) repository.PermissionRepository {
|
||||
return &PermissionRepository{db: db}
|
||||
}
|
||||
|
||||
// CreateAvailablePermission creates a new available permission
|
||||
func (r *PermissionRepository) CreateAvailablePermission(ctx context.Context, permission *domain.AvailablePermission) error {
|
||||
query := `
|
||||
INSERT INTO available_permissions (
|
||||
id, scope, name, description, category, parent_scope,
|
||||
is_system, created_by, updated_by, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
now := time.Now()
|
||||
|
||||
if permission.ID == uuid.Nil {
|
||||
permission.ID = uuid.New()
|
||||
}
|
||||
|
||||
_, err := db.ExecContext(ctx, query,
|
||||
permission.ID,
|
||||
permission.Scope,
|
||||
permission.Name,
|
||||
permission.Description,
|
||||
permission.Category,
|
||||
permission.ParentScope,
|
||||
permission.IsSystem,
|
||||
permission.CreatedBy,
|
||||
permission.UpdatedBy,
|
||||
now,
|
||||
now,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create available permission: %w", err)
|
||||
}
|
||||
|
||||
permission.CreatedAt = now
|
||||
permission.UpdatedAt = now
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAvailablePermission retrieves an available permission by ID
|
||||
func (r *PermissionRepository) GetAvailablePermission(ctx context.Context, permissionID uuid.UUID) (*domain.AvailablePermission, error) {
|
||||
query := `
|
||||
SELECT id, scope, name, description, category, parent_scope,
|
||||
is_system, created_at, created_by, updated_at, updated_by
|
||||
FROM available_permissions
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
row := db.QueryRowContext(ctx, query, permissionID)
|
||||
|
||||
permission := &domain.AvailablePermission{}
|
||||
err := row.Scan(
|
||||
&permission.ID,
|
||||
&permission.Scope,
|
||||
&permission.Name,
|
||||
&permission.Description,
|
||||
&permission.Category,
|
||||
&permission.ParentScope,
|
||||
&permission.IsSystem,
|
||||
&permission.CreatedAt,
|
||||
&permission.CreatedBy,
|
||||
&permission.UpdatedAt,
|
||||
&permission.UpdatedBy,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("permission with ID '%s' not found", permissionID)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get available permission: %w", err)
|
||||
}
|
||||
|
||||
return permission, nil
|
||||
}
|
||||
|
||||
// GetAvailablePermissionByScope retrieves an available permission by scope
|
||||
func (r *PermissionRepository) GetAvailablePermissionByScope(ctx context.Context, scope string) (*domain.AvailablePermission, error) {
|
||||
query := `
|
||||
SELECT id, scope, name, description, category, parent_scope,
|
||||
is_system, created_at, created_by, updated_at, updated_by
|
||||
FROM available_permissions
|
||||
WHERE scope = $1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
row := db.QueryRowContext(ctx, query, scope)
|
||||
|
||||
permission := &domain.AvailablePermission{}
|
||||
err := row.Scan(
|
||||
&permission.ID,
|
||||
&permission.Scope,
|
||||
&permission.Name,
|
||||
&permission.Description,
|
||||
&permission.Category,
|
||||
&permission.ParentScope,
|
||||
&permission.IsSystem,
|
||||
&permission.CreatedAt,
|
||||
&permission.CreatedBy,
|
||||
&permission.UpdatedAt,
|
||||
&permission.UpdatedBy,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("permission with scope '%s' not found", scope)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get available permission by scope: %w", err)
|
||||
}
|
||||
|
||||
return permission, nil
|
||||
}
|
||||
|
||||
// ListAvailablePermissions retrieves available permissions with pagination and filtering
|
||||
func (r *PermissionRepository) ListAvailablePermissions(ctx context.Context, category string, includeSystem bool, limit, offset int) ([]*domain.AvailablePermission, error) {
|
||||
var args []interface{}
|
||||
var whereClauses []string
|
||||
argIndex := 1
|
||||
|
||||
// Build WHERE clause based on filters
|
||||
if category != "" {
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("category = $%d", argIndex))
|
||||
args = append(args, category)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if !includeSystem {
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("is_system = $%d", argIndex))
|
||||
args = append(args, false)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if len(whereClauses) > 0 {
|
||||
whereClause = "WHERE " + fmt.Sprintf("%s", whereClauses[0])
|
||||
for i := 1; i < len(whereClauses); i++ {
|
||||
whereClause += " AND " + whereClauses[i]
|
||||
}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, scope, name, description, category, parent_scope,
|
||||
is_system, created_at, created_by, updated_at, updated_by
|
||||
FROM available_permissions
|
||||
%s
|
||||
ORDER BY category, scope
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argIndex, argIndex+1)
|
||||
|
||||
args = append(args, limit, offset)
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list available permissions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var permissions []*domain.AvailablePermission
|
||||
for rows.Next() {
|
||||
permission := &domain.AvailablePermission{}
|
||||
err := rows.Scan(
|
||||
&permission.ID,
|
||||
&permission.Scope,
|
||||
&permission.Name,
|
||||
&permission.Description,
|
||||
&permission.Category,
|
||||
&permission.ParentScope,
|
||||
&permission.IsSystem,
|
||||
&permission.CreatedAt,
|
||||
&permission.CreatedBy,
|
||||
&permission.UpdatedAt,
|
||||
&permission.UpdatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan available permission: %w", err)
|
||||
}
|
||||
permissions = append(permissions, permission)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to iterate available permissions: %w", err)
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// UpdateAvailablePermission updates an available permission
|
||||
func (r *PermissionRepository) UpdateAvailablePermission(ctx context.Context, permissionID uuid.UUID, permission *domain.AvailablePermission) error {
|
||||
query := `
|
||||
UPDATE available_permissions
|
||||
SET scope = $2, name = $3, description = $4, category = $5,
|
||||
parent_scope = $6, is_system = $7, updated_by = $8, updated_at = $9
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
now := time.Now()
|
||||
|
||||
result, err := db.ExecContext(ctx, query,
|
||||
permissionID,
|
||||
permission.Scope,
|
||||
permission.Name,
|
||||
permission.Description,
|
||||
permission.Category,
|
||||
permission.ParentScope,
|
||||
permission.IsSystem,
|
||||
permission.UpdatedBy,
|
||||
now,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update available permission: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("permission with ID %s not found", permissionID)
|
||||
}
|
||||
|
||||
permission.UpdatedAt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteAvailablePermission deletes an available permission
|
||||
func (r *PermissionRepository) DeleteAvailablePermission(ctx context.Context, permissionID uuid.UUID) error {
|
||||
// First check if the permission has any child permissions
|
||||
checkChildrenQuery := `
|
||||
SELECT COUNT(*) FROM available_permissions
|
||||
WHERE parent_scope = (SELECT scope FROM available_permissions WHERE id = $1)
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
var childCount int
|
||||
err := db.QueryRowContext(ctx, checkChildrenQuery, permissionID).Scan(&childCount)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check for child permissions: %w", err)
|
||||
}
|
||||
|
||||
if childCount > 0 {
|
||||
return fmt.Errorf("cannot delete permission: it has %d child permissions", childCount)
|
||||
}
|
||||
|
||||
// Check if the permission is granted to any tokens
|
||||
checkGrantsQuery := `
|
||||
SELECT COUNT(*) FROM granted_permissions
|
||||
WHERE permission_id = $1 AND revoked = false
|
||||
`
|
||||
|
||||
var grantCount int
|
||||
err = db.QueryRowContext(ctx, checkGrantsQuery, permissionID).Scan(&grantCount)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check for active grants: %w", err)
|
||||
}
|
||||
|
||||
if grantCount > 0 {
|
||||
return fmt.Errorf("cannot delete permission: it is currently granted to %d tokens", grantCount)
|
||||
}
|
||||
|
||||
// Delete the permission
|
||||
deleteQuery := `DELETE FROM available_permissions WHERE id = $1`
|
||||
result, err := db.ExecContext(ctx, deleteQuery, permissionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete available permission: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("permission with ID %s not found", permissionID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePermissionScopes checks if all given scopes exist and are valid
|
||||
func (r *PermissionRepository) ValidatePermissionScopes(ctx context.Context, scopes []string) ([]string, error) {
|
||||
if len(scopes) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT scope
|
||||
FROM available_permissions
|
||||
WHERE scope = ANY($1)
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, pq.Array(scopes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate permission scopes: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
validScopes := make(map[string]bool)
|
||||
for rows.Next() {
|
||||
var scope string
|
||||
if err := rows.Scan(&scope); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan scope: %w", err)
|
||||
}
|
||||
validScopes[scope] = true
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating scopes: %w", err)
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, scope := range scopes {
|
||||
if validScopes[scope] {
|
||||
result = append(result, scope)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetPermissionHierarchy returns all parent and child permissions for given scopes
|
||||
func (r *PermissionRepository) GetPermissionHierarchy(ctx context.Context, scopes []string) ([]*domain.AvailablePermission, error) {
|
||||
if len(scopes) == 0 {
|
||||
return []*domain.AvailablePermission{}, nil
|
||||
}
|
||||
|
||||
// Use recursive CTE to get full hierarchy
|
||||
query := `
|
||||
WITH RECURSIVE permission_hierarchy AS (
|
||||
-- Base case: get permissions matching the input scopes
|
||||
SELECT id, scope, name, description, category, parent_scope,
|
||||
is_system, created_at, created_by, updated_at, updated_by, 0 as level
|
||||
FROM available_permissions
|
||||
WHERE scope = ANY($1)
|
||||
|
||||
UNION ALL
|
||||
|
||||
-- Recursive case: get all parents and children
|
||||
SELECT ap.id, ap.scope, ap.name, ap.description, ap.category, ap.parent_scope,
|
||||
ap.is_system, ap.created_at, ap.created_by, ap.updated_at, ap.updated_by,
|
||||
ph.level + 1 as level
|
||||
FROM available_permissions ap
|
||||
JOIN permission_hierarchy ph ON (
|
||||
-- Get parents (where ap.scope = ph.parent_scope)
|
||||
ap.scope = ph.parent_scope
|
||||
OR
|
||||
-- Get children (where ap.parent_scope = ph.scope)
|
||||
ap.parent_scope = ph.scope
|
||||
)
|
||||
WHERE ph.level < 5 -- Prevent infinite recursion
|
||||
)
|
||||
SELECT DISTINCT id, scope, name, description, category, parent_scope,
|
||||
is_system, created_at, created_by, updated_at, updated_by
|
||||
FROM permission_hierarchy
|
||||
ORDER BY scope
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, pq.Array(scopes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get permission hierarchy: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var permissions []*domain.AvailablePermission
|
||||
for rows.Next() {
|
||||
permission := &domain.AvailablePermission{}
|
||||
err := rows.Scan(
|
||||
&permission.ID,
|
||||
&permission.Scope,
|
||||
&permission.Name,
|
||||
&permission.Description,
|
||||
&permission.Category,
|
||||
&permission.ParentScope,
|
||||
&permission.IsSystem,
|
||||
&permission.CreatedAt,
|
||||
&permission.CreatedBy,
|
||||
&permission.UpdatedAt,
|
||||
&permission.UpdatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan permission hierarchy: %w", err)
|
||||
}
|
||||
permissions = append(permissions, permission)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to iterate permission hierarchy: %w", err)
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// GrantedPermissionRepository implements the GrantedPermissionRepository interface for PostgreSQL
|
||||
type GrantedPermissionRepository struct {
|
||||
db repository.DatabaseProvider
|
||||
}
|
||||
|
||||
// NewGrantedPermissionRepository creates a new PostgreSQL granted permission repository
|
||||
func NewGrantedPermissionRepository(db repository.DatabaseProvider) repository.GrantedPermissionRepository {
|
||||
return &GrantedPermissionRepository{db: db}
|
||||
}
|
||||
|
||||
// GrantPermissions grants multiple permissions to a token
|
||||
func (r *GrantedPermissionRepository) GrantPermissions(ctx context.Context, grants []*domain.GrantedPermission) error {
|
||||
if len(grants) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
query := `
|
||||
INSERT INTO granted_permissions (
|
||||
id, token_type, token_id, permission_id, scope, created_by, created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (token_type, token_id, permission_id) DO NOTHING
|
||||
`
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare statement: %w", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
now := time.Now()
|
||||
for _, grant := range grants {
|
||||
if grant.ID == uuid.Nil {
|
||||
grant.ID = uuid.New()
|
||||
}
|
||||
|
||||
_, err = stmt.ExecContext(ctx,
|
||||
grant.ID,
|
||||
string(grant.TokenType),
|
||||
grant.TokenID,
|
||||
grant.PermissionID,
|
||||
grant.Scope,
|
||||
grant.CreatedBy,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to grant permission: %w", err)
|
||||
}
|
||||
|
||||
grant.CreatedAt = now
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGrantedPermissions retrieves all granted permissions for a token
|
||||
func (r *GrantedPermissionRepository) GetGrantedPermissions(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID) ([]*domain.GrantedPermission, error) {
|
||||
query := `
|
||||
SELECT id, token_type, token_id, permission_id, scope, created_at, created_by, revoked
|
||||
FROM granted_permissions
|
||||
WHERE token_type = $1 AND token_id = $2 AND revoked = false
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, string(tokenType), tokenID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query granted permissions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var permissions []*domain.GrantedPermission
|
||||
for rows.Next() {
|
||||
perm := &domain.GrantedPermission{}
|
||||
var tokenTypeStr string
|
||||
|
||||
err := rows.Scan(
|
||||
&perm.ID,
|
||||
&tokenTypeStr,
|
||||
&perm.TokenID,
|
||||
&perm.PermissionID,
|
||||
&perm.Scope,
|
||||
&perm.CreatedAt,
|
||||
&perm.CreatedBy,
|
||||
&perm.Revoked,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan granted permission: %w", err)
|
||||
}
|
||||
|
||||
perm.TokenType = domain.TokenType(tokenTypeStr)
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating granted permissions: %w", err)
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// GetGrantedPermissionScopes retrieves only the scopes for a token (more efficient)
|
||||
func (r *GrantedPermissionRepository) GetGrantedPermissionScopes(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID) ([]string, error) {
|
||||
query := `
|
||||
SELECT scope
|
||||
FROM granted_permissions
|
||||
WHERE token_type = $1 AND token_id = $2 AND revoked = false
|
||||
ORDER BY scope ASC
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, string(tokenType), tokenID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query granted permission scopes: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var scopes []string
|
||||
for rows.Next() {
|
||||
var scope string
|
||||
if err := rows.Scan(&scope); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan permission scope: %w", err)
|
||||
}
|
||||
scopes = append(scopes, scope)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating permission scopes: %w", err)
|
||||
}
|
||||
|
||||
return scopes, nil
|
||||
}
|
||||
|
||||
// RevokePermission revokes a specific permission from a token
|
||||
func (r *GrantedPermissionRepository) RevokePermission(ctx context.Context, grantID uuid.UUID, revokedBy string) error {
|
||||
query := `
|
||||
UPDATE granted_permissions
|
||||
SET revoked = true, revoked_by = $2, revoked_at = $3
|
||||
WHERE id = $1 AND revoked = false
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
now := time.Now()
|
||||
|
||||
result, err := db.ExecContext(ctx, query, grantID, revokedBy, now)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke permission: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("permission grant with ID %s not found or already revoked", grantID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllPermissions revokes all permissions from a token
|
||||
func (r *GrantedPermissionRepository) RevokeAllPermissions(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, revokedBy string) error {
|
||||
query := `
|
||||
UPDATE granted_permissions
|
||||
SET revoked = true, revoked_by = $3, revoked_at = $4
|
||||
WHERE token_type = $1 AND token_id = $2 AND revoked = false
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
now := time.Now()
|
||||
|
||||
result, err := db.ExecContext(ctx, query, tokenType, tokenID, revokedBy, now)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke all permissions: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
// Note: rowsAffected being 0 is not necessarily an error here -
|
||||
// the token might not have had any active permissions
|
||||
_ = rowsAffected
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasPermission checks if a token has a specific permission
|
||||
func (r *GrantedPermissionRepository) HasPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scope string) (bool, error) {
|
||||
query := `
|
||||
SELECT 1
|
||||
FROM granted_permissions gp
|
||||
JOIN available_permissions ap ON gp.permission_id = ap.id
|
||||
WHERE gp.token_type = $1
|
||||
AND gp.token_id = $2
|
||||
AND gp.scope = $3
|
||||
AND gp.revoked = false
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
var exists int
|
||||
err := db.QueryRowContext(ctx, query, string(tokenType), tokenID, scope).Scan(&exists)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to check permission: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// HasAnyPermission checks if a token has any of the specified permissions
|
||||
func (r *GrantedPermissionRepository) HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error) {
|
||||
if len(scopes) == 0 {
|
||||
return make(map[string]bool), nil
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT gp.scope
|
||||
FROM granted_permissions gp
|
||||
JOIN available_permissions ap ON gp.permission_id = ap.id
|
||||
WHERE gp.token_type = $1
|
||||
AND gp.token_id = $2
|
||||
AND gp.scope = ANY($3)
|
||||
AND gp.revoked = false
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, string(tokenType), tokenID, pq.Array(scopes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check permissions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := make(map[string]bool)
|
||||
// Initialize all scopes as false
|
||||
for _, scope := range scopes {
|
||||
result[scope] = false
|
||||
}
|
||||
|
||||
// Mark found permissions as true
|
||||
for rows.Next() {
|
||||
var scope string
|
||||
if err := rows.Scan(&scope); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan permission scope: %w", err)
|
||||
}
|
||||
result[scope] = true
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating permission results: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
624
kms/internal/repository/postgres/session_repository.go
Normal file
624
kms/internal/repository/postgres/session_repository.go
Normal file
@ -0,0 +1,624 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
)
|
||||
|
||||
// sessionRepository implements the SessionRepository interface
|
||||
type sessionRepository struct {
|
||||
db *sqlx.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewSessionRepository creates a new session repository
|
||||
func NewSessionRepository(db *sqlx.DB, logger *zap.Logger) repository.SessionRepository {
|
||||
return &sessionRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create creates a new user session
|
||||
func (r *sessionRepository) Create(ctx context.Context, session *domain.UserSession) error {
|
||||
r.logger.Debug("Creating new session",
|
||||
zap.String("user_id", session.UserID),
|
||||
zap.String("app_id", session.AppID),
|
||||
zap.String("session_type", string(session.SessionType)))
|
||||
|
||||
// Generate ID if not provided
|
||||
if session.ID == uuid.Nil {
|
||||
session.ID = uuid.New()
|
||||
}
|
||||
|
||||
// Set timestamps
|
||||
now := time.Now()
|
||||
session.CreatedAt = now
|
||||
session.UpdatedAt = now
|
||||
session.LastActivity = now
|
||||
|
||||
// Serialize metadata
|
||||
metadataJSON, err := json.Marshal(session.Metadata)
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to serialize session metadata").WithInternal(err)
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO user_sessions (
|
||||
id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at, metadata
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
|
||||
)`
|
||||
|
||||
_, err = r.db.ExecContext(ctx, query,
|
||||
session.ID,
|
||||
session.UserID,
|
||||
session.AppID,
|
||||
session.SessionType,
|
||||
session.Status,
|
||||
session.AccessToken,
|
||||
session.RefreshToken,
|
||||
session.IDToken,
|
||||
session.IPAddress,
|
||||
session.UserAgent,
|
||||
session.LastActivity,
|
||||
session.ExpiresAt,
|
||||
session.CreatedAt,
|
||||
session.UpdatedAt,
|
||||
metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to create session", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to create session").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("Session created successfully", zap.String("session_id", session.ID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a session by its ID
|
||||
func (r *sessionRepository) GetByID(ctx context.Context, sessionID uuid.UUID) (*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting session by ID", zap.String("session_id", sessionID.String()))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE id = $1`
|
||||
|
||||
var session domain.UserSession
|
||||
var metadataJSON []byte
|
||||
var revokedAt sql.NullTime
|
||||
var revokedBy sql.NullString
|
||||
|
||||
err := r.db.QueryRowContext(ctx, query, sessionID).Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.AppID,
|
||||
&session.SessionType,
|
||||
&session.Status,
|
||||
&session.AccessToken,
|
||||
&session.RefreshToken,
|
||||
&session.IDToken,
|
||||
&session.IPAddress,
|
||||
&session.UserAgent,
|
||||
&session.LastActivity,
|
||||
&session.ExpiresAt,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
&revokedAt,
|
||||
&revokedBy,
|
||||
&metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
r.logger.Error("Failed to get session by ID", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to retrieve session").WithInternal(err)
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if revokedAt.Valid {
|
||||
session.RevokedAt = &revokedAt.Time
|
||||
}
|
||||
if revokedBy.Valid {
|
||||
session.RevokedBy = &revokedBy.String
|
||||
}
|
||||
|
||||
// Deserialize metadata
|
||||
if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
|
||||
r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
|
||||
session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
|
||||
}
|
||||
|
||||
r.logger.Debug("Session retrieved successfully", zap.String("session_id", sessionID.String()))
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// GetByUserID retrieves all sessions for a user
|
||||
func (r *sessionRepository) GetByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting sessions by user ID", zap.String("user_id", userID))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC`
|
||||
|
||||
return r.scanSessions(ctx, query, userID)
|
||||
}
|
||||
|
||||
// GetByUserAndApp retrieves sessions for a specific user and application
|
||||
func (r *sessionRepository) GetByUserAndApp(ctx context.Context, userID, appID string) ([]*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting sessions by user and app",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1 AND app_id = $2
|
||||
ORDER BY created_at DESC`
|
||||
|
||||
return r.scanSessions(ctx, query, userID, appID)
|
||||
}
|
||||
|
||||
// GetActiveByUserID retrieves all active sessions for a user
|
||||
func (r *sessionRepository) GetActiveByUserID(ctx context.Context, userID string) ([]*domain.UserSession, error) {
|
||||
r.logger.Debug("Getting active sessions by user ID", zap.String("user_id", userID))
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1 AND status = $2 AND expires_at > NOW()
|
||||
ORDER BY last_activity DESC`
|
||||
|
||||
return r.scanSessions(ctx, query, userID, domain.SessionStatusActive)
|
||||
}
|
||||
|
||||
// List retrieves sessions with filtering and pagination
|
||||
func (r *sessionRepository) List(ctx context.Context, req *domain.SessionListRequest) (*domain.SessionListResponse, error) {
|
||||
r.logger.Debug("Listing sessions with filters",
|
||||
zap.String("user_id", req.UserID),
|
||||
zap.String("app_id", req.AppID),
|
||||
zap.Int("limit", req.Limit),
|
||||
zap.Int("offset", req.Offset))
|
||||
|
||||
// Build WHERE clause dynamically
|
||||
whereClause := "WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
argIndex := 1
|
||||
|
||||
if req.UserID != "" {
|
||||
whereClause += fmt.Sprintf(" AND user_id = $%d", argIndex)
|
||||
args = append(args, req.UserID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.AppID != "" {
|
||||
whereClause += fmt.Sprintf(" AND app_id = $%d", argIndex)
|
||||
args = append(args, req.AppID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
whereClause += fmt.Sprintf(" AND status = $%d", argIndex)
|
||||
args = append(args, *req.Status)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.SessionType != nil {
|
||||
whereClause += fmt.Sprintf(" AND session_type = $%d", argIndex)
|
||||
args = append(args, *req.SessionType)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if req.TenantID != "" {
|
||||
whereClause += fmt.Sprintf(" AND metadata->>'tenant_id' = $%d", argIndex)
|
||||
args = append(args, req.TenantID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
// Get total count
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM user_sessions %s", whereClause)
|
||||
var total int
|
||||
err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to get session count", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to count sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
// Get sessions with pagination
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, user_id, app_id, session_type, status, access_token,
|
||||
refresh_token, id_token, ip_address, user_agent,
|
||||
last_activity, expires_at, created_at, updated_at,
|
||||
revoked_at, revoked_by, metadata
|
||||
FROM user_sessions
|
||||
%s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d`, whereClause, argIndex, argIndex+1)
|
||||
|
||||
args = append(args, req.Limit, req.Offset)
|
||||
sessions, err := r.scanSessions(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &domain.SessionListResponse{
|
||||
Sessions: sessions,
|
||||
Total: total,
|
||||
Limit: req.Limit,
|
||||
Offset: req.Offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Update updates an existing session
|
||||
func (r *sessionRepository) Update(ctx context.Context, sessionID uuid.UUID, updates *domain.UpdateSessionRequest) error {
|
||||
r.logger.Debug("Updating session", zap.String("session_id", sessionID.String()))
|
||||
|
||||
// Build UPDATE clause dynamically
|
||||
setParts := []string{"updated_at = NOW()"}
|
||||
args := []interface{}{}
|
||||
argIndex := 1
|
||||
|
||||
if updates.Status != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("status = $%d", argIndex))
|
||||
args = append(args, *updates.Status)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.LastActivity != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("last_activity = $%d", argIndex))
|
||||
args = append(args, *updates.LastActivity)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.ExpiresAt != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("expires_at = $%d", argIndex))
|
||||
args = append(args, *updates.ExpiresAt)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.IPAddress != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("ip_address = $%d", argIndex))
|
||||
args = append(args, *updates.IPAddress)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.UserAgent != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("user_agent = $%d", argIndex))
|
||||
args = append(args, *updates.UserAgent)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if len(setParts) == 1 {
|
||||
return errors.NewValidationError("No fields to update")
|
||||
}
|
||||
|
||||
// Build the complete query
|
||||
setClause := fmt.Sprintf("%s", setParts[0])
|
||||
for i := 1; i < len(setParts); i++ {
|
||||
setClause += fmt.Sprintf(", %s", setParts[i])
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE user_sessions SET %s WHERE id = $%d", setClause, argIndex)
|
||||
args = append(args, sessionID)
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to update session", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to update session").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
|
||||
r.logger.Debug("Session updated successfully", zap.String("session_id", sessionID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateActivity updates the last activity timestamp for a session
|
||||
func (r *sessionRepository) UpdateActivity(ctx context.Context, sessionID uuid.UUID) error {
|
||||
r.logger.Debug("Updating session activity", zap.String("session_id", sessionID.String()))
|
||||
|
||||
query := `UPDATE user_sessions SET last_activity = NOW(), updated_at = NOW() WHERE id = $1`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, sessionID)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to update session activity", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to update session activity").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Revoke revokes a session
|
||||
func (r *sessionRepository) Revoke(ctx context.Context, sessionID uuid.UUID, revokedBy string) error {
|
||||
r.logger.Debug("Revoking session",
|
||||
zap.String("session_id", sessionID.String()),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
|
||||
WHERE id = $3`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, sessionID)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to revoke session", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to revoke session").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return errors.NewNotFoundError("Session not found")
|
||||
}
|
||||
|
||||
r.logger.Debug("Session revoked successfully", zap.String("session_id", sessionID.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllByUser revokes all sessions for a user
|
||||
func (r *sessionRepository) RevokeAllByUser(ctx context.Context, userID string, revokedBy string) error {
|
||||
r.logger.Debug("Revoking all sessions for user",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
|
||||
WHERE user_id = $3 AND status = $4`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, domain.SessionStatusActive)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to revoke user sessions", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to revoke user sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("User sessions revoked",
|
||||
zap.String("user_id", userID),
|
||||
zap.Int64("sessions_revoked", rowsAffected))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllByUserAndApp revokes all sessions for a user and application
|
||||
func (r *sessionRepository) RevokeAllByUserAndApp(ctx context.Context, userID, appID string, revokedBy string) error {
|
||||
r.logger.Debug("Revoking all sessions for user and app",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID),
|
||||
zap.String("revoked_by", revokedBy))
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, revoked_at = NOW(), revoked_by = $2, updated_at = NOW()
|
||||
WHERE user_id = $3 AND app_id = $4 AND status = $5`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusRevoked, revokedBy, userID, appID, domain.SessionStatusActive)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to revoke user app sessions", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to revoke user app sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("User app sessions revoked",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID),
|
||||
zap.Int64("sessions_revoked", rowsAffected))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExpireOldSessions marks expired sessions as expired
|
||||
func (r *sessionRepository) ExpireOldSessions(ctx context.Context) (int, error) {
|
||||
r.logger.Debug("Expiring old sessions")
|
||||
|
||||
query := `
|
||||
UPDATE user_sessions
|
||||
SET status = $1, updated_at = NOW()
|
||||
WHERE expires_at < NOW() AND status = $2`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, domain.SessionStatusActive)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to expire old sessions", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to expire old sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("Old sessions expired", zap.Int64("sessions_expired", rowsAffected))
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
// DeleteExpiredSessions removes expired sessions older than the specified duration
|
||||
func (r *sessionRepository) DeleteExpiredSessions(ctx context.Context, olderThan time.Duration) (int, error) {
|
||||
r.logger.Debug("Deleting expired sessions", zap.Duration("older_than", olderThan))
|
||||
|
||||
cutoffTime := time.Now().Add(-olderThan)
|
||||
query := `DELETE FROM user_sessions WHERE status = $1 AND updated_at < $2`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, domain.SessionStatusExpired, cutoffTime)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to delete expired sessions", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to delete expired sessions").WithInternal(err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, errors.NewInternalError("Failed to get affected rows").WithInternal(err)
|
||||
}
|
||||
|
||||
r.logger.Debug("Expired sessions deleted", zap.Int64("sessions_deleted", rowsAffected))
|
||||
return int(rowsAffected), nil
|
||||
}
|
||||
|
||||
// Exists checks if a session exists
|
||||
func (r *sessionRepository) Exists(ctx context.Context, sessionID uuid.UUID) (bool, error) {
|
||||
r.logger.Debug("Checking if session exists", zap.String("session_id", sessionID.String()))
|
||||
|
||||
query := `SELECT EXISTS(SELECT 1 FROM user_sessions WHERE id = $1)`
|
||||
|
||||
var exists bool
|
||||
err := r.db.QueryRowContext(ctx, query, sessionID).Scan(&exists)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to check session existence", zap.Error(err))
|
||||
return false, errors.NewInternalError("Failed to check session existence").WithInternal(err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// GetSessionCount returns the total number of sessions for a user
|
||||
func (r *sessionRepository) GetSessionCount(ctx context.Context, userID string) (int, error) {
|
||||
r.logger.Debug("Getting session count for user", zap.String("user_id", userID))
|
||||
|
||||
query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1`
|
||||
|
||||
var count int
|
||||
err := r.db.QueryRowContext(ctx, query, userID).Scan(&count)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to get session count", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to get session count").WithInternal(err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetActiveSessionCount returns the number of active sessions for a user
|
||||
func (r *sessionRepository) GetActiveSessionCount(ctx context.Context, userID string) (int, error) {
|
||||
r.logger.Debug("Getting active session count for user", zap.String("user_id", userID))
|
||||
|
||||
query := `SELECT COUNT(*) FROM user_sessions WHERE user_id = $1 AND status = $2 AND expires_at > NOW()`
|
||||
|
||||
var count int
|
||||
err := r.db.QueryRowContext(ctx, query, userID, domain.SessionStatusActive).Scan(&count)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to get active session count", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to get active session count").WithInternal(err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// scanSessions is a helper method to scan multiple sessions from query results
|
||||
func (r *sessionRepository) scanSessions(ctx context.Context, query string, args ...interface{}) ([]*domain.UserSession, error) {
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to execute session query", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to retrieve sessions").WithInternal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []*domain.UserSession
|
||||
for rows.Next() {
|
||||
var session domain.UserSession
|
||||
var metadataJSON []byte
|
||||
var revokedAt sql.NullTime
|
||||
var revokedBy sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.AppID,
|
||||
&session.SessionType,
|
||||
&session.Status,
|
||||
&session.AccessToken,
|
||||
&session.RefreshToken,
|
||||
&session.IDToken,
|
||||
&session.IPAddress,
|
||||
&session.UserAgent,
|
||||
&session.LastActivity,
|
||||
&session.ExpiresAt,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
&revokedAt,
|
||||
&revokedBy,
|
||||
&metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to scan session row", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to scan session data").WithInternal(err)
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if revokedAt.Valid {
|
||||
session.RevokedAt = &revokedAt.Time
|
||||
}
|
||||
if revokedBy.Valid {
|
||||
session.RevokedBy = &revokedBy.String
|
||||
}
|
||||
|
||||
// Deserialize metadata
|
||||
if err := json.Unmarshal(metadataJSON, &session.Metadata); err != nil {
|
||||
r.logger.Warn("Failed to deserialize session metadata", zap.Error(err))
|
||||
session.Metadata = domain.SessionMetadata{} // Use empty metadata on error
|
||||
}
|
||||
|
||||
sessions = append(sessions, &session)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
r.logger.Error("Error iterating session rows", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to iterate session results").WithInternal(err)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
290
kms/internal/repository/postgres/token_repository.go
Normal file
290
kms/internal/repository/postgres/token_repository.go
Normal file
@ -0,0 +1,290 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
)
|
||||
|
||||
// StaticTokenRepository implements the StaticTokenRepository interface for PostgreSQL
|
||||
type StaticTokenRepository struct {
|
||||
db repository.DatabaseProvider
|
||||
}
|
||||
|
||||
// NewStaticTokenRepository creates a new PostgreSQL static token repository
|
||||
func NewStaticTokenRepository(db repository.DatabaseProvider) repository.StaticTokenRepository {
|
||||
return &StaticTokenRepository{db: db}
|
||||
}
|
||||
|
||||
// Create creates a new static token
|
||||
func (r *StaticTokenRepository) Create(ctx context.Context, token *domain.StaticToken) error {
|
||||
query := `
|
||||
INSERT INTO static_tokens (
|
||||
id, app_id, owner_type, owner_name, owner_owner,
|
||||
key_hash, type, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
now := time.Now()
|
||||
|
||||
_, err := db.ExecContext(ctx, query,
|
||||
token.ID,
|
||||
token.AppID,
|
||||
string(token.Owner.Type),
|
||||
token.Owner.Name,
|
||||
token.Owner.Owner,
|
||||
token.KeyHash,
|
||||
string(token.Type),
|
||||
now,
|
||||
now,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create static token: %w", err)
|
||||
}
|
||||
|
||||
token.CreatedAt = now
|
||||
token.UpdatedAt = now
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a static token by its ID
|
||||
func (r *StaticTokenRepository) GetByID(ctx context.Context, tokenID uuid.UUID) (*domain.StaticToken, error) {
|
||||
query := `
|
||||
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||
key_hash, type, created_at, updated_at
|
||||
FROM static_tokens
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
row := db.QueryRowContext(ctx, query, tokenID)
|
||||
|
||||
token := &domain.StaticToken{}
|
||||
var ownerType, ownerName, ownerOwner string
|
||||
|
||||
err := row.Scan(
|
||||
&token.ID,
|
||||
&token.AppID,
|
||||
&ownerType,
|
||||
&ownerName,
|
||||
&ownerOwner,
|
||||
&token.KeyHash,
|
||||
&token.Type,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("static token with ID '%s' not found", tokenID)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get static token: %w", err)
|
||||
}
|
||||
|
||||
token.Owner = domain.Owner{
|
||||
Type: domain.OwnerType(ownerType),
|
||||
Name: ownerName,
|
||||
Owner: ownerOwner,
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetByKeyHash retrieves a static token by its key hash
|
||||
func (r *StaticTokenRepository) GetByKeyHash(ctx context.Context, keyHash string) (*domain.StaticToken, error) {
|
||||
query := `
|
||||
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||
key_hash, type, created_at, updated_at
|
||||
FROM static_tokens
|
||||
WHERE key_hash = $1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
row := db.QueryRowContext(ctx, query, keyHash)
|
||||
|
||||
token := &domain.StaticToken{}
|
||||
var ownerType, ownerName, ownerOwner string
|
||||
|
||||
err := row.Scan(
|
||||
&token.ID,
|
||||
&token.AppID,
|
||||
&ownerType,
|
||||
&ownerName,
|
||||
&ownerOwner,
|
||||
&token.KeyHash,
|
||||
&token.Type,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("static token with hash not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get static token by hash: %w", err)
|
||||
}
|
||||
|
||||
token.Owner = domain.Owner{
|
||||
Type: domain.OwnerType(ownerType),
|
||||
Name: ownerName,
|
||||
Owner: ownerOwner,
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetByAppID retrieves all static tokens for an application
|
||||
func (r *StaticTokenRepository) GetByAppID(ctx context.Context, appID string) ([]*domain.StaticToken, error) {
|
||||
query := `
|
||||
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||
key_hash, type, created_at, updated_at
|
||||
FROM static_tokens
|
||||
WHERE app_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, appID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query static tokens: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tokens []*domain.StaticToken
|
||||
for rows.Next() {
|
||||
token := &domain.StaticToken{}
|
||||
var ownerType, ownerName, ownerOwner string
|
||||
|
||||
err := rows.Scan(
|
||||
&token.ID,
|
||||
&token.AppID,
|
||||
&ownerType,
|
||||
&ownerName,
|
||||
&ownerOwner,
|
||||
&token.KeyHash,
|
||||
&token.Type,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan static token: %w", err)
|
||||
}
|
||||
|
||||
token.Owner = domain.Owner{
|
||||
Type: domain.OwnerType(ownerType),
|
||||
Name: ownerName,
|
||||
Owner: ownerOwner,
|
||||
}
|
||||
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating static tokens: %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// List retrieves static tokens with pagination
|
||||
func (r *StaticTokenRepository) List(ctx context.Context, limit, offset int) ([]*domain.StaticToken, error) {
|
||||
query := `
|
||||
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||
key_hash, type, created_at, updated_at
|
||||
FROM static_tokens
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, limit, offset)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query static tokens: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tokens []*domain.StaticToken
|
||||
for rows.Next() {
|
||||
token := &domain.StaticToken{}
|
||||
var ownerType, ownerName, ownerOwner string
|
||||
|
||||
err := rows.Scan(
|
||||
&token.ID,
|
||||
&token.AppID,
|
||||
&ownerType,
|
||||
&ownerName,
|
||||
&ownerOwner,
|
||||
&token.KeyHash,
|
||||
&token.Type,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan static token: %w", err)
|
||||
}
|
||||
|
||||
token.Owner = domain.Owner{
|
||||
Type: domain.OwnerType(ownerType),
|
||||
Name: ownerName,
|
||||
Owner: ownerOwner,
|
||||
}
|
||||
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating static tokens: %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// Delete deletes a static token
|
||||
func (r *StaticTokenRepository) Delete(ctx context.Context, tokenID uuid.UUID) error {
|
||||
query := `DELETE FROM static_tokens WHERE id = $1`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
result, err := db.ExecContext(ctx, query, tokenID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete static token: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("static token with ID '%s' not found", tokenID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a static token exists
|
||||
func (r *StaticTokenRepository) Exists(ctx context.Context, tokenID uuid.UUID) (bool, error) {
|
||||
query := `SELECT 1 FROM static_tokens WHERE id = $1`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
var exists int
|
||||
err := db.QueryRowContext(ctx, query, tokenID).Scan(&exists)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to check static token existence: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
Reference in New Issue
Block a user