package test import ( "context" "database/sql" "testing" "time" "github.com/DATA-DOG/go-sqlmock" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/RyanCopley/skybridge/kms/internal/domain" "github.com/RyanCopley/skybridge/kms/internal/repository" "github.com/RyanCopley/skybridge/kms/internal/repository/postgres" ) // SQLMockDatabaseProvider implements repository.DatabaseProvider for SQL testing type SQLMockDatabaseProvider struct { db *sql.DB } func (m *SQLMockDatabaseProvider) GetDB() interface{} { return m.db } func (m *SQLMockDatabaseProvider) Ping(ctx context.Context) error { return m.db.PingContext(ctx) } func (m *SQLMockDatabaseProvider) Close() error { return m.db.Close() } func (m *SQLMockDatabaseProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) { tx, err := m.db.BeginTx(ctx, nil) if err != nil { return nil, err } return &SQLMockTransactionProvider{tx: tx}, nil } func (m *SQLMockDatabaseProvider) Migrate(ctx context.Context, migrationPath string) error { return nil } // SQLMockTransactionProvider implements repository.TransactionProvider for SQL testing type SQLMockTransactionProvider struct { tx *sql.Tx } func (m *SQLMockTransactionProvider) Commit() error { return m.tx.Commit() } func (m *SQLMockTransactionProvider) Rollback() error { return m.tx.Rollback() } func (m *SQLMockTransactionProvider) GetTx() interface{} { return m.tx } func setupTokenRepositoryTest(t *testing.T) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) { db, mock, err := sqlmock.New() require.NoError(t, err) mockDB := &SQLMockDatabaseProvider{db: db} repo := postgres.NewStaticTokenRepository(mockDB) cleanup := func() { db.Close() } return repo.(*postgres.StaticTokenRepository), mock, cleanup } func setupTokenRepositoryTestBenchmark(b *testing.B) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) { db, mock, err := sqlmock.New() if err != nil { b.Fatal(err) } mockDB := &SQLMockDatabaseProvider{db: db} repo := postgres.NewStaticTokenRepository(mockDB) cleanup := func() { db.Close() } return repo.(*postgres.StaticTokenRepository), mock, cleanup } func TestStaticTokenRepository_Create(t *testing.T) { tests := []struct { name string token *domain.StaticToken setupMock func(sqlmock.Sqlmock) expectError bool errorMsg string }{ { name: "successful creation", token: &domain.StaticToken{ ID: uuid.New(), AppID: "test-app", Owner: domain.Owner{ Type: domain.OwnerTypeIndividual, Name: "test-user", Owner: "test-owner", }, KeyHash: "test-hash", Type: "hmac", }, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectExec(`INSERT INTO static_tokens`). WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(1, 1)) }, expectError: false, }, { name: "database error", token: &domain.StaticToken{ ID: uuid.New(), AppID: "test-app", Owner: domain.Owner{ Type: domain.OwnerTypeIndividual, Name: "test-user", Owner: "test-owner", }, KeyHash: "test-hash", Type: "hmac", }, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectExec(`INSERT INTO static_tokens`). WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnError(sql.ErrConnDone) }, expectError: true, errorMsg: "failed to create static token", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { repo, mock, cleanup := setupTokenRepositoryTest(t) defer cleanup() tt.setupMock(mock) ctx := context.Background() err := repo.Create(ctx, tt.token) if tt.expectError { assert.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) } else { assert.NoError(t, err) assert.NotZero(t, tt.token.CreatedAt) assert.NotZero(t, tt.token.UpdatedAt) } assert.NoError(t, mock.ExpectationsWereMet()) }) } } func TestStaticTokenRepository_GetByID(t *testing.T) { tokenID := uuid.New() now := time.Now() tests := []struct { name string tokenID uuid.UUID setupMock func(sqlmock.Sqlmock) expectError bool errorMsg string expectedToken *domain.StaticToken }{ { name: "successful retrieval", tokenID: tokenID, setupMock: func(mock sqlmock.Sqlmock) { rows := sqlmock.NewRows([]string{ "id", "app_id", "owner_type", "owner_name", "owner_owner", "key_hash", "type", "created_at", "updated_at", }).AddRow( tokenID, "test-app", "individual", "test-user", "test-owner", "test-hash", "user", now, now, ) mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`). WithArgs(tokenID). WillReturnRows(rows) }, expectError: false, expectedToken: &domain.StaticToken{ ID: tokenID, AppID: "test-app", Owner: domain.Owner{ Type: domain.OwnerTypeIndividual, Name: "test-user", Owner: "test-owner", }, KeyHash: "test-hash", Type: string(domain.TokenTypeUser), CreatedAt: now, UpdatedAt: now, }, }, { name: "token not found", tokenID: tokenID, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`). WithArgs(tokenID). WillReturnError(sql.ErrNoRows) }, expectError: true, errorMsg: "not found", }, { name: "database error", tokenID: tokenID, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`). WithArgs(tokenID). WillReturnError(sql.ErrConnDone) }, expectError: true, errorMsg: "failed to get static token", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { repo, mock, cleanup := setupTokenRepositoryTest(t) defer cleanup() tt.setupMock(mock) ctx := context.Background() token, err := repo.GetByID(ctx, tt.tokenID) if tt.expectError { assert.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) assert.Nil(t, token) } else { assert.NoError(t, err) assert.NotNil(t, token) assert.Equal(t, tt.expectedToken.ID, token.ID) assert.Equal(t, tt.expectedToken.AppID, token.AppID) assert.Equal(t, tt.expectedToken.Owner, token.Owner) assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash) assert.Equal(t, tt.expectedToken.Type, token.Type) } assert.NoError(t, mock.ExpectationsWereMet()) }) } } func TestStaticTokenRepository_GetByKeyHash(t *testing.T) { tokenID := uuid.New() now := time.Now() keyHash := "test-hash" tests := []struct { name string keyHash string setupMock func(sqlmock.Sqlmock) expectError bool errorMsg string expectedToken *domain.StaticToken }{ { name: "successful retrieval", keyHash: keyHash, setupMock: func(mock sqlmock.Sqlmock) { rows := sqlmock.NewRows([]string{ "id", "app_id", "owner_type", "owner_name", "owner_owner", "key_hash", "type", "created_at", "updated_at", }).AddRow( tokenID, "test-app", "individual", "test-user", "test-owner", keyHash, "user", now, now, ) mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`). WithArgs(keyHash). WillReturnRows(rows) }, expectError: false, expectedToken: &domain.StaticToken{ ID: tokenID, AppID: "test-app", Owner: domain.Owner{ Type: domain.OwnerTypeIndividual, Name: "test-user", Owner: "test-owner", }, KeyHash: keyHash, Type: string(domain.TokenTypeUser), CreatedAt: now, UpdatedAt: now, }, }, { name: "token not found", keyHash: keyHash, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`). WithArgs(keyHash). WillReturnError(sql.ErrNoRows) }, expectError: true, errorMsg: "not found", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { repo, mock, cleanup := setupTokenRepositoryTest(t) defer cleanup() tt.setupMock(mock) ctx := context.Background() token, err := repo.GetByKeyHash(ctx, tt.keyHash) if tt.expectError { assert.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) assert.Nil(t, token) } else { assert.NoError(t, err) assert.NotNil(t, token) assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash) } assert.NoError(t, mock.ExpectationsWereMet()) }) } } func TestStaticTokenRepository_GetByAppID(t *testing.T) { tokenID1 := uuid.New() tokenID2 := uuid.New() now := time.Now() appID := "test-app" tests := []struct { name string appID string setupMock func(sqlmock.Sqlmock) expectError bool errorMsg string expectedCount int }{ { name: "successful retrieval with multiple tokens", appID: appID, setupMock: func(mock sqlmock.Sqlmock) { rows := sqlmock.NewRows([]string{ "id", "app_id", "owner_type", "owner_name", "owner_owner", "key_hash", "type", "created_at", "updated_at", }).AddRow( tokenID1, appID, "user", "test-user1", "test-owner1", "test-hash1", "user", now, now, ).AddRow( tokenID2, appID, "user", "test-user2", "test-owner2", "test-hash2", "user", now, now, ) mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`). WithArgs(appID). WillReturnRows(rows) }, expectError: false, expectedCount: 2, }, { name: "no tokens found", appID: appID, setupMock: func(mock sqlmock.Sqlmock) { rows := sqlmock.NewRows([]string{ "id", "app_id", "owner_type", "owner_name", "owner_owner", "key_hash", "type", "created_at", "updated_at", }) mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`). WithArgs(appID). WillReturnRows(rows) }, expectError: false, expectedCount: 0, }, { name: "database error", appID: appID, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`). WithArgs(appID). WillReturnError(sql.ErrConnDone) }, expectError: true, errorMsg: "failed to query static tokens", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { repo, mock, cleanup := setupTokenRepositoryTest(t) defer cleanup() tt.setupMock(mock) ctx := context.Background() tokens, err := repo.GetByAppID(ctx, tt.appID) if tt.expectError { assert.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) assert.Nil(t, tokens) } else { assert.NoError(t, err) assert.Len(t, tokens, tt.expectedCount) } assert.NoError(t, mock.ExpectationsWereMet()) }) } } func TestStaticTokenRepository_List(t *testing.T) { tokenID := uuid.New() now := time.Now() tests := []struct { name string limit int offset int setupMock func(sqlmock.Sqlmock) expectError bool errorMsg string expectedCount int }{ { name: "successful list with pagination", limit: 10, offset: 0, setupMock: func(mock sqlmock.Sqlmock) { rows := sqlmock.NewRows([]string{ "id", "app_id", "owner_type", "owner_name", "owner_owner", "key_hash", "type", "created_at", "updated_at", }).AddRow( tokenID, "test-app", "user", "test-user", "test-owner", "test-hash", "user", now, now, ) mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`). WithArgs(10, 0). WillReturnRows(rows) }, expectError: false, expectedCount: 1, }, { name: "database error", limit: 10, offset: 0, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`). WithArgs(10, 0). WillReturnError(sql.ErrConnDone) }, expectError: true, errorMsg: "failed to query static tokens", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { repo, mock, cleanup := setupTokenRepositoryTest(t) defer cleanup() tt.setupMock(mock) ctx := context.Background() tokens, err := repo.List(ctx, tt.limit, tt.offset) if tt.expectError { assert.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) assert.Nil(t, tokens) } else { assert.NoError(t, err) assert.Len(t, tokens, tt.expectedCount) } assert.NoError(t, mock.ExpectationsWereMet()) }) } } func TestStaticTokenRepository_Delete(t *testing.T) { tokenID := uuid.New() tests := []struct { name string tokenID uuid.UUID setupMock func(sqlmock.Sqlmock) expectError bool errorMsg string }{ { name: "successful deletion", tokenID: tokenID, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`). WithArgs(tokenID). WillReturnResult(sqlmock.NewResult(0, 1)) }, expectError: false, }, { name: "token not found", tokenID: tokenID, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`). WithArgs(tokenID). WillReturnResult(sqlmock.NewResult(0, 0)) }, expectError: true, errorMsg: "not found", }, { name: "database error", tokenID: tokenID, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`). WithArgs(tokenID). WillReturnError(sql.ErrConnDone) }, expectError: true, errorMsg: "failed to delete static token", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { repo, mock, cleanup := setupTokenRepositoryTest(t) defer cleanup() tt.setupMock(mock) ctx := context.Background() err := repo.Delete(ctx, tt.tokenID) if tt.expectError { assert.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) } else { assert.NoError(t, err) } assert.NoError(t, mock.ExpectationsWereMet()) }) } } func TestStaticTokenRepository_Exists(t *testing.T) { tokenID := uuid.New() tests := []struct { name string tokenID uuid.UUID setupMock func(sqlmock.Sqlmock) expectError bool errorMsg string expectedExists bool }{ { name: "token exists", tokenID: tokenID, setupMock: func(mock sqlmock.Sqlmock) { rows := sqlmock.NewRows([]string{"exists"}).AddRow(1) mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`). WithArgs(tokenID). WillReturnRows(rows) }, expectError: false, expectedExists: true, }, { name: "token does not exist", tokenID: tokenID, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`). WithArgs(tokenID). WillReturnError(sql.ErrNoRows) }, expectError: false, expectedExists: false, }, { name: "database error", tokenID: tokenID, setupMock: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`). WithArgs(tokenID). WillReturnError(sql.ErrConnDone) }, expectError: true, errorMsg: "failed to check static token existence", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { repo, mock, cleanup := setupTokenRepositoryTest(t) defer cleanup() tt.setupMock(mock) ctx := context.Background() exists, err := repo.Exists(ctx, tt.tokenID) if tt.expectError { assert.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) } else { assert.NoError(t, err) assert.Equal(t, tt.expectedExists, exists) } assert.NoError(t, mock.ExpectationsWereMet()) }) } } // Benchmark tests for repository operations func BenchmarkStaticTokenRepository_Create(b *testing.B) { repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b) defer cleanup() token := &domain.StaticToken{ ID: uuid.New(), AppID: "test-app", Owner: domain.Owner{ Type: domain.OwnerTypeIndividual, Name: "test-user", Owner: "test-owner", }, KeyHash: "test-hash", Type: string(domain.TokenTypeUser), } // Setup mock expectations for all iterations for i := 0; i < b.N; i++ { mock.ExpectExec(`INSERT INTO static_tokens`). WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "user", sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(1, 1)) } ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { token.ID = uuid.New() // Generate new ID for each iteration err := repo.Create(ctx, token) if err != nil { b.Fatal(err) } } } func BenchmarkStaticTokenRepository_GetByID(b *testing.B) { repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b) defer cleanup() tokenID := uuid.New() now := time.Now() // Setup mock expectations for all iterations for i := 0; i < b.N; i++ { rows := sqlmock.NewRows([]string{ "id", "app_id", "owner_type", "owner_name", "owner_owner", "key_hash", "type", "created_at", "updated_at", }).AddRow( tokenID, "test-app", "user", "test-user", "test-owner", "test-hash", "user", now, now, ) mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`). WithArgs(tokenID). WillReturnRows(rows) } ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { _, err := repo.GetByID(ctx, tokenID) if err != nil { b.Fatal(err) } } }