Files
skybridge/test/mock_repositories.go
2025-08-25 21:28:14 -04:00

817 lines
19 KiB
Go

package test
import (
	"context"
	"fmt"
	"sort"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/kms/api-key-service/internal/audit"
	"github.com/kms/api-key-service/internal/domain"
	"github.com/kms/api-key-service/internal/repository"
)
// MockDatabaseProvider is a no-op repository.DatabaseProvider for tests.
// Every operation succeeds immediately; no real database is involved.
type MockDatabaseProvider struct {
	// mu is not referenced by any of the methods defined here.
	mu sync.RWMutex
}

// NewMockDatabaseProvider returns a fresh MockDatabaseProvider as a
// repository.DatabaseProvider.
func NewMockDatabaseProvider() repository.DatabaseProvider {
	provider := &MockDatabaseProvider{}
	return provider
}

// GetDB returns the provider itself, because there is no underlying handle.
func (p *MockDatabaseProvider) GetDB() interface{} {
	var handle interface{} = p
	return handle
}

// Ping always reports a healthy connection.
func (p *MockDatabaseProvider) Ping(ctx context.Context) error {
	return nil
}

// Close has nothing to release and always succeeds.
func (p *MockDatabaseProvider) Close() error {
	return nil
}

// BeginTx hands out a new no-op MockTransactionProvider.
func (p *MockDatabaseProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) {
	tx := &MockTransactionProvider{}
	return tx, nil
}

// Migrate ignores migrationPath and always succeeds.
func (p *MockDatabaseProvider) Migrate(ctx context.Context, migrationPath string) error {
	return nil
}
// MockTransactionProvider is a no-op transaction handed out by
// MockDatabaseProvider.BeginTx. Commit and Rollback always succeed, and
// nothing is persisted or undone.
type MockTransactionProvider struct{}

// Commit reports success without doing anything.
func (tx *MockTransactionProvider) Commit() error { return nil }

// Rollback reports success without doing anything.
func (tx *MockTransactionProvider) Rollback() error { return nil }

// GetTx returns the provider itself as the opaque transaction handle.
func (tx *MockTransactionProvider) GetTx() interface{} { return tx }
// MockApplicationRepository is an in-memory repository.ApplicationRepository.
// Applications are keyed by AppID, and access is guarded by a read/write mutex.
type MockApplicationRepository struct {
	mu           sync.RWMutex
	applications map[string]*domain.Application
}

// NewMockApplicationRepository returns an empty in-memory application store.
func NewMockApplicationRepository() repository.ApplicationRepository {
	repo := &MockApplicationRepository{}
	repo.applications = make(map[string]*domain.Application)
	return repo
}

// Create stores a copy of app under app.AppID.
//
// It stamps CreatedAt and UpdatedAt on the caller's value before copying. It
// fails, without touching app, when the AppID is already taken.
func (m *MockApplicationRepository) Create(ctx context.Context, app *domain.Application) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if _, taken := m.applications[app.AppID]; taken {
		return fmt.Errorf("application with ID '%s' already exists", app.AppID)
	}

	stamp := time.Now()
	app.CreatedAt, app.UpdatedAt = stamp, stamp

	// Keep a snapshot so later edits to app by the caller don't leak into the store.
	stored := *app
	m.applications[stored.AppID] = &stored
	return nil
}

// GetByID returns a copy of the application stored under appID. It returns
// an error when there is no such application.
func (m *MockApplicationRepository) GetByID(ctx context.Context, appID string) (*domain.Application, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if stored, ok := m.applications[appID]; ok {
		snapshot := *stored
		return &snapshot, nil
	}
	return nil, fmt.Errorf("application with ID '%s' not found", appID)
}
// List returns up to limit applications after skipping the first offset of
// them. Each result is a copy of the stored value.
//
// Results are ordered by AppID so that consecutive pages are stable. Ranging
// over the map directly produces a different order on every call, which made
// offset-based paging skip or repeat entries.
//
// Edge cases: a negative offset is treated as zero, and a non-positive limit
// or an offset past the end yields a nil slice.
func (m *MockApplicationRepository) List(ctx context.Context, limit, offset int) ([]*domain.Application, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	ids := make([]string, 0, len(m.applications))
	for id := range m.applications {
		ids = append(ids, id)
	}
	sort.Strings(ids)

	if offset < 0 {
		offset = 0
	}
	if limit <= 0 || offset >= len(ids) {
		return nil, nil
	}
	end := len(ids)
	// Compare against the remaining count instead of computing offset+limit,
	// which could overflow for a very large limit.
	if limit < end-offset {
		end = offset + limit
	}

	apps := make([]*domain.Application, 0, end-offset)
	for _, id := range ids[offset:end] {
		// Return copies so callers cannot mutate stored state.
		appCopy := *m.applications[id]
		apps = append(apps, &appCopy)
	}
	return apps, nil
}
// Update modifies the stored application in place and returns a copy of the
// result.
//
// Only fields for which updates carries a non-nil pointer are overwritten.
// UpdatedAt is always refreshed. An unknown appID yields an error and leaves
// the store unchanged.
func (m *MockApplicationRepository) Update(ctx context.Context, appID string, updates *domain.UpdateApplicationRequest) (*domain.Application, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	current, found := m.applications[appID]
	if !found {
		return nil, fmt.Errorf("application with ID '%s' not found", appID)
	}

	// Edit the stored entry directly; a nil pointer means "leave unchanged".
	if v := updates.AppLink; v != nil {
		current.AppLink = *v
	}
	if v := updates.Type; v != nil {
		current.Type = *v
	}
	if v := updates.CallbackURL; v != nil {
		current.CallbackURL = *v
	}
	if v := updates.HMACKey; v != nil {
		current.HMACKey = *v
	}
	if v := updates.TokenRenewalDuration; v != nil {
		current.TokenRenewalDuration = *v
	}
	if v := updates.MaxTokenDuration; v != nil {
		current.MaxTokenDuration = *v
	}
	if v := updates.Owner; v != nil {
		current.Owner = *v
	}
	current.UpdatedAt = time.Now()

	result := *current
	return &result, nil
}

// Delete removes the application stored under appID. It returns an error
// when there is no such application.
func (m *MockApplicationRepository) Delete(ctx context.Context, appID string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	_, found := m.applications[appID]
	if !found {
		return fmt.Errorf("application with ID '%s' not found", appID)
	}
	delete(m.applications, appID)
	return nil
}

// Exists reports whether an application is stored under appID. It never
// returns an error.
func (m *MockApplicationRepository) Exists(ctx context.Context, appID string) (bool, error) {
	m.mu.RLock()
	_, found := m.applications[appID]
	m.mu.RUnlock()
	return found, nil
}
// MockStaticTokenRepository is an in-memory repository.StaticTokenRepository.
// Tokens are keyed by ID, and access is guarded by a read/write mutex.
type MockStaticTokenRepository struct {
	mu     sync.RWMutex
	tokens map[uuid.UUID]*domain.StaticToken
}

// NewMockStaticTokenRepository returns an empty in-memory token store.
func NewMockStaticTokenRepository() repository.StaticTokenRepository {
	repo := &MockStaticTokenRepository{}
	repo.tokens = make(map[uuid.UUID]*domain.StaticToken)
	return repo
}

// Create stores a copy of token.
//
// If token.ID is unset, a random ID is assigned first. CreatedAt and
// UpdatedAt are stamped on the caller's value. An existing entry with the
// same ID is overwritten without error.
func (m *MockStaticTokenRepository) Create(ctx context.Context, token *domain.StaticToken) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if token.ID == uuid.Nil {
		token.ID = uuid.New()
	}
	stamp := time.Now()
	token.CreatedAt, token.UpdatedAt = stamp, stamp

	stored := *token
	m.tokens[stored.ID] = &stored
	return nil
}

// GetByID returns a copy of the token with the given ID. It returns an error
// when there is no such token.
func (m *MockStaticTokenRepository) GetByID(ctx context.Context, tokenID uuid.UUID) (*domain.StaticToken, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if stored, ok := m.tokens[tokenID]; ok {
		snapshot := *stored
		return &snapshot, nil
	}
	return nil, fmt.Errorf("token with ID '%s' not found", tokenID)
}

// GetByKeyHash returns a copy of a token whose KeyHash equals keyHash. It
// scans every token and returns an error when none matches.
func (m *MockStaticTokenRepository) GetByKeyHash(ctx context.Context, keyHash string) (*domain.StaticToken, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	for _, stored := range m.tokens {
		if stored.KeyHash != keyHash {
			continue
		}
		snapshot := *stored
		return &snapshot, nil
	}
	return nil, fmt.Errorf("token with key hash not found")
}

// GetByAppID returns copies of all tokens belonging to appID, in map
// iteration order. It returns a nil slice when there are none.
func (m *MockStaticTokenRepository) GetByAppID(ctx context.Context, appID string) ([]*domain.StaticToken, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var matches []*domain.StaticToken
	for _, stored := range m.tokens {
		if stored.AppID != appID {
			continue
		}
		snapshot := *stored
		matches = append(matches, &snapshot)
	}
	return matches, nil
}
// List returns up to limit tokens after skipping the first offset of them.
// Each result is a copy of the stored value.
//
// Results are ordered by the string form of the token ID so that consecutive
// pages are stable. Ranging over the map directly produces a different order
// on every call, which made offset-based paging skip or repeat tokens.
//
// Edge cases: a negative offset is treated as zero, and a non-positive limit
// or an offset past the end yields a nil slice.
func (m *MockStaticTokenRepository) List(ctx context.Context, limit, offset int) ([]*domain.StaticToken, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	ids := make([]uuid.UUID, 0, len(m.tokens))
	for id := range m.tokens {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(i, j int) bool { return ids[i].String() < ids[j].String() })

	if offset < 0 {
		offset = 0
	}
	if limit <= 0 || offset >= len(ids) {
		return nil, nil
	}
	end := len(ids)
	// Compare against the remaining count instead of computing offset+limit,
	// which could overflow for a very large limit.
	if limit < end-offset {
		end = offset + limit
	}

	tokens := make([]*domain.StaticToken, 0, end-offset)
	for _, id := range ids[offset:end] {
		tokenCopy := *m.tokens[id]
		tokens = append(tokens, &tokenCopy)
	}
	return tokens, nil
}
// Delete removes the token with the given ID. It returns an error when there
// is no such token.
func (m *MockStaticTokenRepository) Delete(ctx context.Context, tokenID uuid.UUID) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	_, found := m.tokens[tokenID]
	if !found {
		return fmt.Errorf("token with ID '%s' not found", tokenID)
	}
	delete(m.tokens, tokenID)
	return nil
}

// Exists reports whether a token with the given ID is stored. It never
// returns an error.
func (m *MockStaticTokenRepository) Exists(ctx context.Context, tokenID uuid.UUID) (bool, error) {
	m.mu.RLock()
	_, found := m.tokens[tokenID]
	m.mu.RUnlock()
	return found, nil
}
// MockPermissionRepository is an in-memory repository.PermissionRepository.
//
// permissions is keyed by permission ID. scopeIndex maps each scope to the ID
// of the permission that owns it. Both maps are guarded by mu.
type MockPermissionRepository struct {
	mu          sync.RWMutex
	permissions map[uuid.UUID]*domain.AvailablePermission
	scopeIndex  map[string]uuid.UUID
}

// NewMockPermissionRepository returns a permission repository pre-seeded with
// three non-system scopes: "repo.read", "repo.write" and "app.read".
//
// It panics if seeding fails. The seeds are fixed and have distinct scopes,
// so a failure means CreateAvailablePermission itself is broken. Tests should
// stop then, rather than run against a half-populated repository, which the
// previous silently ignored error allowed.
func NewMockPermissionRepository() repository.PermissionRepository {
	repo := &MockPermissionRepository{
		permissions: make(map[uuid.UUID]*domain.AvailablePermission),
		scopeIndex:  make(map[string]uuid.UUID),
	}

	ctx := context.Background()
	// CreatedAt/UpdatedAt are omitted: CreateAvailablePermission stamps both.
	defaultPerms := []*domain.AvailablePermission{
		{
			ID:          uuid.New(),
			Scope:       "repo.read",
			Name:        "Repository Read",
			Description: "Read repository data",
			Category:    "repository",
			IsSystem:    false,
			CreatedBy:   "system",
			UpdatedBy:   "system",
		},
		{
			ID:          uuid.New(),
			Scope:       "repo.write",
			Name:        "Repository Write",
			Description: "Write to repositories",
			Category:    "repository",
			IsSystem:    false,
			CreatedBy:   "system",
			UpdatedBy:   "system",
		},
		{
			ID:          uuid.New(),
			Scope:       "app.read",
			Name:        "Application Read",
			Description: "Read application data",
			Category:    "application",
			IsSystem:    false,
			CreatedBy:   "system",
			UpdatedBy:   "system",
		},
	}
	for _, perm := range defaultPerms {
		if err := repo.CreateAvailablePermission(ctx, perm); err != nil {
			panic(fmt.Sprintf("seeding default permission %q: %v", perm.Scope, err))
		}
	}
	return repo
}
// CreateAvailablePermission stores a copy of permission and indexes it by
// scope.
//
// It fails if the scope is already registered, or if permission.ID is set
// and already used by another permission. Both checks run before anything is
// mutated, so a rejected call leaves permission and the repository untouched.
//
// Previously:
//   - a random ID was assigned to the caller's value even when the
//     duplicate-scope check then failed;
//   - an existing ID was silently overwritten, leaving the old scope's index
//     entry pointing at a permission with a different scope.
//
// On success, a missing ID is generated, and CreatedAt and UpdatedAt are
// stamped on the caller's value.
func (m *MockPermissionRepository) CreateAvailablePermission(ctx context.Context, permission *domain.AvailablePermission) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if _, exists := m.scopeIndex[permission.Scope]; exists {
		return fmt.Errorf("permission with scope '%s' already exists", permission.Scope)
	}
	if permission.ID == uuid.Nil {
		permission.ID = uuid.New()
	} else if _, exists := m.permissions[permission.ID]; exists {
		return fmt.Errorf("permission with ID '%s' already exists", permission.ID)
	}

	now := time.Now()
	permission.CreatedAt = now
	permission.UpdatedAt = now

	permCopy := *permission
	m.permissions[permCopy.ID] = &permCopy
	m.scopeIndex[permCopy.Scope] = permCopy.ID
	return nil
}
// GetAvailablePermission returns a copy of the permission with the given ID.
// It returns an error when there is no such permission.
func (m *MockPermissionRepository) GetAvailablePermission(ctx context.Context, permissionID uuid.UUID) (*domain.AvailablePermission, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if stored, ok := m.permissions[permissionID]; ok {
		snapshot := *stored
		return &snapshot, nil
	}
	return nil, fmt.Errorf("permission with ID '%s' not found", permissionID)
}

// GetAvailablePermissionByScope resolves scope through the scope index and
// returns a copy of the permission it points at. It returns an error when the
// scope is not indexed.
func (m *MockPermissionRepository) GetAvailablePermissionByScope(ctx context.Context, scope string) (*domain.AvailablePermission, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	id, indexed := m.scopeIndex[scope]
	if !indexed {
		return nil, fmt.Errorf("permission with scope '%s' not found", scope)
	}
	// Assumes the index is consistent with permissions; a dangling index
	// entry would dereference nil here.
	snapshot := *m.permissions[id]
	return &snapshot, nil
}
// ListAvailablePermissions returns copies of the permissions that pass the
// filters, paginated by limit and offset.
//
// Filters:
//   - an empty category matches every category;
//   - system permissions are dropped unless includeSystem is true.
//
// Matches are ordered by Scope before paging, so consecutive pages are
// stable. Ranging over the map directly produces a different order per call,
// which made offset-based paging skip or repeat entries.
//
// Edge cases: a negative offset is treated as zero, and a non-positive limit
// or an offset past the end yields a nil slice.
func (m *MockPermissionRepository) ListAvailablePermissions(ctx context.Context, category string, includeSystem bool, limit, offset int) ([]*domain.AvailablePermission, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var matching []*domain.AvailablePermission
	for _, perm := range m.permissions {
		if category != "" && perm.Category != category {
			continue
		}
		if !includeSystem && perm.IsSystem {
			continue
		}
		matching = append(matching, perm)
	}
	sort.Slice(matching, func(i, j int) bool { return matching[i].Scope < matching[j].Scope })

	if offset < 0 {
		offset = 0
	}
	if limit <= 0 || offset >= len(matching) {
		return nil, nil
	}
	end := len(matching)
	// Compare against the remaining count instead of computing offset+limit,
	// which could overflow for a very large limit.
	if limit < end-offset {
		end = offset + limit
	}

	perms := make([]*domain.AvailablePermission, 0, end-offset)
	for _, perm := range matching[offset:end] {
		permCopy := *perm
		perms = append(perms, &permCopy)
	}
	return perms, nil
}
// UpdateAvailablePermission replaces the stored permission permissionID with
// a copy of permission.
//
// permission.ID is forced to permissionID and UpdatedAt is refreshed, both on
// the caller's value. All other fields, including CreatedAt and CreatedBy,
// are taken as supplied by the caller.
//
// When the scope changes, the scope index is moved to the new scope. The
// update is rejected if another permission already owns that scope.
// Previously the index was never updated: the old scope kept resolving to the
// renamed permission, and the new scope could not be found at all.
//
// It returns an error if permissionID is unknown.
func (m *MockPermissionRepository) UpdateAvailablePermission(ctx context.Context, permissionID uuid.UUID, permission *domain.AvailablePermission) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	existing, exists := m.permissions[permissionID]
	if !exists {
		return fmt.Errorf("permission with ID '%s' not found", permissionID)
	}
	oldScope := existing.Scope
	scopeChanged := permission.Scope != oldScope
	if scopeChanged {
		if _, taken := m.scopeIndex[permission.Scope]; taken {
			return fmt.Errorf("permission with scope '%s' already exists", permission.Scope)
		}
	}

	permission.ID = permissionID
	permission.UpdatedAt = time.Now()
	permCopy := *permission
	m.permissions[permissionID] = &permCopy

	if scopeChanged {
		delete(m.scopeIndex, oldScope)
		m.scopeIndex[permCopy.Scope] = permissionID
	}
	return nil
}
// DeleteAvailablePermission removes the permission with the given ID. It also
// removes the scope-index entry for that permission's scope. It returns an
// error when the ID is unknown.
func (m *MockPermissionRepository) DeleteAvailablePermission(ctx context.Context, permissionID uuid.UUID) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	victim, found := m.permissions[permissionID]
	if !found {
		return fmt.Errorf("permission with ID '%s' not found", permissionID)
	}
	delete(m.scopeIndex, victim.Scope)
	delete(m.permissions, permissionID)
	return nil
}

// ValidatePermissionScopes returns the subset of scopes that are registered,
// in input order (duplicates included). Unknown scopes are dropped, not
// reported as errors, and the result is nil when nothing matches.
func (m *MockPermissionRepository) ValidatePermissionScopes(ctx context.Context, scopes []string) ([]string, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var known []string
	for _, candidate := range scopes {
		if _, ok := m.scopeIndex[candidate]; !ok {
			continue
		}
		known = append(known, candidate)
	}
	return known, nil
}

// GetPermissionHierarchy returns copies of the registered permissions for the
// given scopes, in input order, skipping unknown scopes. This mock models no
// hierarchy beyond the direct scope lookup.
func (m *MockPermissionRepository) GetPermissionHierarchy(ctx context.Context, scopes []string) ([]*domain.AvailablePermission, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var found []*domain.AvailablePermission
	for _, scope := range scopes {
		id, ok := m.scopeIndex[scope]
		if !ok {
			continue
		}
		snapshot := *m.permissions[id]
		found = append(found, &snapshot)
	}
	return found, nil
}
// MockGrantedPermissionRepository is an in-memory
// repository.GrantedPermissionRepository.
//
// Grants are keyed by grant ID. Revocation only sets the grant's Revoked
// flag; revoked grants stay stored.
type MockGrantedPermissionRepository struct {
	mu     sync.RWMutex
	grants map[uuid.UUID]*domain.GrantedPermission
}

// NewMockGrantedPermissionRepository returns an empty in-memory grant store.
func NewMockGrantedPermissionRepository() repository.GrantedPermissionRepository {
	repo := &MockGrantedPermissionRepository{}
	repo.grants = make(map[uuid.UUID]*domain.GrantedPermission)
	return repo
}

// GrantPermissions stores a copy of every grant.
//
// Each grant without an ID gets a random one, and CreatedAt is stamped, both
// on the caller's values. An existing grant with the same ID is overwritten.
func (m *MockGrantedPermissionRepository) GrantPermissions(ctx context.Context, grants []*domain.GrantedPermission) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	for _, g := range grants {
		if g.ID == uuid.Nil {
			g.ID = uuid.New()
		}
		g.CreatedAt = time.Now()
		stored := *g
		m.grants[stored.ID] = &stored
	}
	return nil
}

// GetGrantedPermissions returns copies of the non-revoked grants for the
// given token, in map iteration order.
func (m *MockGrantedPermissionRepository) GetGrantedPermissions(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID) ([]*domain.GrantedPermission, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var active []*domain.GrantedPermission
	for _, g := range m.grants {
		if g.Revoked || g.TokenType != tokenType || g.TokenID != tokenID {
			continue
		}
		snapshot := *g
		active = append(active, &snapshot)
	}
	return active, nil
}

// GetGrantedPermissionScopes returns the scopes of the non-revoked grants for
// the given token, in map iteration order.
func (m *MockGrantedPermissionRepository) GetGrantedPermissionScopes(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID) ([]string, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var out []string
	for _, g := range m.grants {
		if g.Revoked || g.TokenType != tokenType || g.TokenID != tokenID {
			continue
		}
		out = append(out, g.Scope)
	}
	return out, nil
}

// RevokePermission marks a single grant as revoked.
//
// revokedBy is accepted for interface compatibility but is not recorded by
// this mock. Revoking an already-revoked grant succeeds. An unknown grantID
// yields an error.
func (m *MockGrantedPermissionRepository) RevokePermission(ctx context.Context, grantID uuid.UUID, revokedBy string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	target, found := m.grants[grantID]
	if !found {
		return fmt.Errorf("granted permission with ID '%s' not found", grantID)
	}
	target.Revoked = true
	return nil
}

// RevokeAllPermissions marks every grant for the given token as revoked.
// revokedBy is not recorded. The call succeeds even when nothing matches.
func (m *MockGrantedPermissionRepository) RevokeAllPermissions(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, revokedBy string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	for _, g := range m.grants {
		if g.TokenType != tokenType || g.TokenID != tokenID {
			continue
		}
		g.Revoked = true
	}
	return nil
}

// HasPermission reports whether the token holds a non-revoked grant for scope.
func (m *MockGrantedPermissionRepository) HasPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scope string) (bool, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	for _, g := range m.grants {
		held := !g.Revoked && g.TokenType == tokenType && g.TokenID == tokenID && g.Scope == scope
		if held {
			return true, nil
		}
	}
	return false, nil
}

// HasAnyPermission reports, for each requested scope, whether the token
// holds a non-revoked grant for it. Every requested scope appears as a key
// in the returned map.
func (m *MockGrantedPermissionRepository) HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	// Gather the token's active scopes once, then answer each query from that set.
	active := make(map[string]bool)
	for _, g := range m.grants {
		if !g.Revoked && g.TokenType == tokenType && g.TokenID == tokenID {
			active[g.Scope] = true
		}
	}

	answers := make(map[string]bool, len(scopes))
	for _, s := range scopes {
		answers[s] = active[s]
	}
	return answers, nil
}
// MockAuditRepository is an in-memory repository.AuditRepository.
//
// Events are kept in insertion order in a slice guarded by a read/write
// mutex. Stored events are the caller's pointers, not copies.
type MockAuditRepository struct {
	mu     sync.RWMutex
	events []*audit.AuditEvent
}

// NewMockAuditRepository returns an audit store with an empty, non-nil event
// list.
func NewMockAuditRepository() repository.AuditRepository {
	repo := &MockAuditRepository{}
	repo.events = make([]*audit.AuditEvent, 0)
	return repo
}

// Create appends event to the store.
//
// A random ID is filled in when event.ID is unset. The current UTC time is
// filled in when event.Timestamp is zero. Both changes are made on the
// caller's value, which is stored as-is (no copy).
func (m *MockAuditRepository) Create(ctx context.Context, event *audit.AuditEvent) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if event.ID == uuid.Nil {
		event.ID = uuid.New()
	}
	if event.Timestamp.IsZero() {
		event.Timestamp = time.Now().UTC()
	}
	m.events = append(m.events, event)
	return nil
}
// Query returns the stored events that match filter, in insertion order,
// paginated by filter.Offset and filter.Limit.
//
// Only EventTypes, ActorID, ResourceID and ResourceType are honored. An
// empty value means "don't filter on this field". Returned events are the
// stored pointers, not copies.
//
// Edge cases:
//   - a nil filter is rejected with an error instead of a nil-pointer panic;
//   - a negative Offset or Limit is treated as zero;
//   - the page end is computed without overflowing.
//
// Previously a negative Offset, a negative Limit or an overflowing
// Offset+Limit caused a slice-bounds panic. A zero Limit still yields an
// empty, non-nil slice.
func (m *MockAuditRepository) Query(ctx context.Context, filter *audit.AuditFilter) ([]*audit.AuditEvent, error) {
	if filter == nil {
		return nil, fmt.Errorf("audit filter must not be nil")
	}

	m.mu.RLock()
	defer m.mu.RUnlock()

	var result []*audit.AuditEvent
	for _, event := range m.events {
		if len(filter.EventTypes) > 0 {
			found := false
			for _, t := range filter.EventTypes {
				if event.Type == t {
					found = true
					break
				}
			}
			if !found {
				continue
			}
		}
		if filter.ActorID != "" && event.ActorID != filter.ActorID {
			continue
		}
		if filter.ResourceID != "" && event.ResourceID != filter.ResourceID {
			continue
		}
		if filter.ResourceType != "" && event.ResourceType != filter.ResourceType {
			continue
		}
		result = append(result, event)
	}

	offset, limit := filter.Offset, filter.Limit
	if offset < 0 {
		offset = 0
	}
	if limit < 0 {
		limit = 0
	}
	if offset >= len(result) {
		return []*audit.AuditEvent{}, nil
	}
	end := len(result)
	if limit < end-offset {
		end = offset + limit
	}
	return result[offset:end], nil
}
// GetStats tallies every stored event by type, severity and status.
//
// The filter argument is ignored: the counts always cover all stored events.
// Other AuditStats fields are left at their zero values.
func (m *MockAuditRepository) GetStats(ctx context.Context, filter *audit.AuditStatsFilter) (*audit.AuditStats, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	byType := make(map[audit.EventType]int)
	bySeverity := make(map[audit.EventSeverity]int)
	byStatus := make(map[audit.EventStatus]int)
	for _, ev := range m.events {
		byType[ev.Type]++
		bySeverity[ev.Severity]++
		byStatus[ev.Status]++
	}

	return &audit.AuditStats{
		TotalEvents: len(m.events),
		ByType:      byType,
		BySeverity:  bySeverity,
		ByStatus:    byStatus,
	}, nil
}

// DeleteOldEvents drops every event whose Timestamp is strictly before
// olderThan and returns how many were removed. Survivors keep their order.
// The list becomes nil when everything is removed.
func (m *MockAuditRepository) DeleteOldEvents(ctx context.Context, olderThan time.Time) (int, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	var survivors []*audit.AuditEvent
	for _, ev := range m.events {
		if !ev.Timestamp.Before(olderThan) {
			survivors = append(survivors, ev)
		}
	}
	removed := len(m.events) - len(survivors)
	m.events = survivors
	return removed, nil
}

// GetByID returns the stored event (not a copy) with the given ID. It returns
// an error when there is no such event.
func (m *MockAuditRepository) GetByID(ctx context.Context, eventID uuid.UUID) (*audit.AuditEvent, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	for _, ev := range m.events {
		if ev.ID != eventID {
			continue
		}
		return ev, nil
	}
	return nil, fmt.Errorf("audit event with ID '%s' not found", eventID)
}

// GetByRequestID returns the stored events whose RequestID equals requestID,
// in insertion order. It returns a nil slice when none match.
func (m *MockAuditRepository) GetByRequestID(ctx context.Context, requestID string) ([]*audit.AuditEvent, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var hits []*audit.AuditEvent
	for _, ev := range m.events {
		if ev.RequestID != requestID {
			continue
		}
		hits = append(hits, ev)
	}
	return hits, nil
}

// GetBySession returns the stored events whose SessionID equals sessionID,
// in insertion order. It returns a nil slice when none match.
func (m *MockAuditRepository) GetBySession(ctx context.Context, sessionID string) ([]*audit.AuditEvent, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var hits []*audit.AuditEvent
	for _, ev := range m.events {
		if ev.SessionID != sessionID {
			continue
		}
		hits = append(hits, ev)
	}
	return hits, nil
}
// GetByActor returns the stored events whose ActorID equals actorID, in
// insertion order. It skips the first offset matches and returns at most
// limit of them. Events are the stored pointers, not copies.
//
// A negative offset or limit is treated as zero, and the page end is computed
// without overflow. Previously those inputs, or an overflowing offset+limit,
// caused a slice-bounds panic. A zero limit yields an empty, non-nil slice.
func (m *MockAuditRepository) GetByActor(ctx context.Context, actorID string, limit, offset int) ([]*audit.AuditEvent, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var matching []*audit.AuditEvent
	for _, event := range m.events {
		if event.ActorID == actorID {
			matching = append(matching, event)
		}
	}

	if offset < 0 {
		offset = 0
	}
	if limit < 0 {
		limit = 0
	}
	if offset >= len(matching) {
		return []*audit.AuditEvent{}, nil
	}
	end := len(matching)
	if limit < end-offset {
		end = offset + limit
	}
	return matching[offset:end], nil
}
// GetByResource returns the stored events whose ResourceType and ResourceID
// both match, in insertion order. It skips the first offset matches and
// returns at most limit of them. Events are the stored pointers, not copies.
//
// A negative offset or limit is treated as zero, and the page end is computed
// without overflow. Previously those inputs, or an overflowing offset+limit,
// caused a slice-bounds panic. A zero limit yields an empty, non-nil slice.
func (m *MockAuditRepository) GetByResource(ctx context.Context, resourceType, resourceID string, limit, offset int) ([]*audit.AuditEvent, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var matching []*audit.AuditEvent
	for _, event := range m.events {
		if event.ResourceType == resourceType && event.ResourceID == resourceID {
			matching = append(matching, event)
		}
	}

	if offset < 0 {
		offset = 0
	}
	if limit < 0 {
		limit = 0
	}
	if offset >= len(matching) {
		return []*audit.AuditEvent{}, nil
	}
	end := len(matching)
	if limit < end-offset {
		end = offset + limit
	}
	return matching[offset:end], nil
}