Files
skybridge/test/token_repository_test.go
2025-08-22 18:57:40 -04:00

706 lines
17 KiB
Go

package test
import (
"context"
"database/sql"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/repository"
"github.com/kms/api-key-service/internal/repository/postgres"
)
// SQLMockDatabaseProvider adapts a sqlmock-backed *sql.DB to the
// repository.DatabaseProvider interface so repositories can be exercised
// without a real PostgreSQL server.
type SQLMockDatabaseProvider struct {
	db *sql.DB
}

// GetDB exposes the wrapped *sql.DB as an untyped value.
func (p *SQLMockDatabaseProvider) GetDB() interface{} {
	return p.db
}

// Ping forwards to (*sql.DB).PingContext on the wrapped handle.
func (p *SQLMockDatabaseProvider) Ping(ctx context.Context) error {
	return p.db.PingContext(ctx)
}

// Close closes the wrapped *sql.DB and returns its error unchanged.
func (p *SQLMockDatabaseProvider) Close() error {
	return p.db.Close()
}

// BeginTx opens a transaction with default options on the wrapped *sql.DB
// and wraps it in a SQLMockTransactionProvider.
func (p *SQLMockDatabaseProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) {
	sqlTx, beginErr := p.db.BeginTx(ctx, nil)
	if beginErr != nil {
		return nil, beginErr
	}
	return &SQLMockTransactionProvider{tx: sqlTx}, nil
}

// Migrate is a no-op for the mock; it ignores its arguments and always succeeds.
func (p *SQLMockDatabaseProvider) Migrate(ctx context.Context, migrationPath string) error {
	return nil
}
// SQLMockTransactionProvider adapts a *sql.Tx (typically one opened on a
// sqlmock connection) to the repository.TransactionProvider interface.
type SQLMockTransactionProvider struct {
	tx *sql.Tx
}

// Commit forwards to (*sql.Tx).Commit.
func (p *SQLMockTransactionProvider) Commit() error {
	return p.tx.Commit()
}

// Rollback forwards to (*sql.Tx).Rollback.
func (p *SQLMockTransactionProvider) Rollback() error {
	return p.tx.Rollback()
}

// GetTx exposes the wrapped *sql.Tx as an untyped value.
func (p *SQLMockTransactionProvider) GetTx() interface{} {
	return p.tx
}
// newSQLMockTokenRepository builds a StaticTokenRepository on top of a fresh
// sqlmock connection. It returns the repository, the mock (for registering
// expectations) and a cleanup func that closes the underlying *sql.DB.
// It takes testing.TB so tests and benchmarks share one construction path.
func newSQLMockTokenRepository(tb testing.TB) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
	tb.Helper()

	db, mock, err := sqlmock.New()
	if err != nil {
		tb.Fatalf("creating sqlmock database: %v", err)
	}

	provider := &SQLMockDatabaseProvider{db: db}
	built := postgres.NewStaticTokenRepository(provider)
	// Checked assertion: report a clear test failure instead of panicking if the
	// constructor ever returns a different implementation.
	repo, ok := built.(*postgres.StaticTokenRepository)
	if !ok {
		_ = db.Close()
		tb.Fatalf("NewStaticTokenRepository returned %T, want *postgres.StaticTokenRepository", built)
	}

	cleanup := func() {
		// The Close error is intentionally discarded. No ExpectClose expectation
		// is registered, so sqlmock may report the close as unexpected.
		// NOTE(review): revisit if tests start asserting on Close.
		_ = db.Close()
	}
	return repo, mock, cleanup
}

// setupTokenRepositoryTest returns a sqlmock-backed StaticTokenRepository for
// use in unit tests; see newSQLMockTokenRepository.
func setupTokenRepositoryTest(t *testing.T) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
	t.Helper()
	return newSQLMockTokenRepository(t)
}

// setupTokenRepositoryTestBenchmark returns a sqlmock-backed
// StaticTokenRepository for use in benchmarks; see newSQLMockTokenRepository.
func setupTokenRepositoryTestBenchmark(b *testing.B) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
	b.Helper()
	return newSQLMockTokenRepository(b)
}
// TestStaticTokenRepository_Create checks three things:
//   - Create sends an INSERT into static_tokens carrying the token's fields.
//   - On success, Create populates CreatedAt and UpdatedAt on the token.
//   - A driver failure surfaces as an error containing
//     "failed to create static token".
//
// The ID and the two trailing INSERT arguments are matched with AnyArg
// (presumably the generated timestamps; confirm against the repository).
func TestStaticTokenRepository_Create(t *testing.T) {
	tests := []struct {
		name        string
		token       *domain.StaticToken
		setupMock   func(sqlmock.Sqlmock)
		expectError bool
		errorMsg    string
	}{
		{
			name: "successful creation",
			token: &domain.StaticToken{
				ID:    uuid.New(),
				AppID: "test-app",
				Owner: domain.Owner{
					Type:  domain.OwnerTypeIndividual,
					Name:  "test-user",
					Owner: "test-owner",
				},
				KeyHash: "test-hash",
				Type:    "hmac",
			},
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectExec(`INSERT INTO static_tokens`).
					WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
					WillReturnResult(sqlmock.NewResult(1, 1))
			},
			expectError: false,
		},
		{
			name: "database error",
			token: &domain.StaticToken{
				ID:    uuid.New(),
				AppID: "test-app",
				Owner: domain.Owner{
					Type:  domain.OwnerTypeIndividual,
					Name:  "test-user",
					Owner: "test-owner",
				},
				KeyHash: "test-hash",
				Type:    "hmac",
			},
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectExec(`INSERT INTO static_tokens`).
					WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
					WillReturnError(sql.ErrConnDone)
			},
			expectError: true,
			errorMsg:    "failed to create static token",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			repo, mock, cleanup := setupTokenRepositoryTest(t)
			defer cleanup()
			tt.setupMock(mock)

			ctx := context.Background()
			err := repo.Create(ctx, tt.token)
			if tt.expectError {
				// require, not assert: a nil error must stop the subtest here
				// instead of panicking on err.Error() below.
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.errorMsg)
			} else {
				assert.NoError(t, err)
				assert.NotZero(t, tt.token.CreatedAt)
				assert.NotZero(t, tt.token.UpdatedAt)
			}
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}
// TestStaticTokenRepository_GetByID checks three things:
//   - GetByID maps a single static_tokens row, selected by id, onto a
//     domain.StaticToken, including its timestamps.
//   - sql.ErrNoRows is reported as a "not found" error.
//   - Other driver failures are wrapped with "failed to get static token".
func TestStaticTokenRepository_GetByID(t *testing.T) {
	tokenID := uuid.New()
	now := time.Now()
	tests := []struct {
		name          string
		tokenID       uuid.UUID
		setupMock     func(sqlmock.Sqlmock)
		expectError   bool
		errorMsg      string
		expectedToken *domain.StaticToken
	}{
		{
			name:    "successful retrieval",
			tokenID: tokenID,
			setupMock: func(mock sqlmock.Sqlmock) {
				rows := sqlmock.NewRows([]string{
					"id", "app_id", "owner_type", "owner_name", "owner_owner",
					"key_hash", "type", "created_at", "updated_at",
				}).AddRow(
					tokenID, "test-app", "individual", "test-user", "test-owner",
					"test-hash", "user", now, now,
				)
				mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
					WithArgs(tokenID).
					WillReturnRows(rows)
			},
			expectError: false,
			expectedToken: &domain.StaticToken{
				ID:    tokenID,
				AppID: "test-app",
				Owner: domain.Owner{
					Type:  domain.OwnerTypeIndividual,
					Name:  "test-user",
					Owner: "test-owner",
				},
				KeyHash:   "test-hash",
				Type:      string(domain.TokenTypeUser),
				CreatedAt: now,
				UpdatedAt: now,
			},
		},
		{
			name:    "token not found",
			tokenID: tokenID,
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
					WithArgs(tokenID).
					WillReturnError(sql.ErrNoRows)
			},
			expectError: true,
			errorMsg:    "not found",
		},
		{
			name:    "database error",
			tokenID: tokenID,
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
					WithArgs(tokenID).
					WillReturnError(sql.ErrConnDone)
			},
			expectError: true,
			errorMsg:    "failed to get static token",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			repo, mock, cleanup := setupTokenRepositoryTest(t)
			defer cleanup()
			tt.setupMock(mock)

			ctx := context.Background()
			token, err := repo.GetByID(ctx, tt.tokenID)
			if tt.expectError {
				// require: a nil error would otherwise panic on err.Error().
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.errorMsg)
				assert.Nil(t, token)
			} else {
				assert.NoError(t, err)
				// require: the field checks below dereference token.
				require.NotNil(t, token)
				assert.Equal(t, tt.expectedToken.ID, token.ID)
				assert.Equal(t, tt.expectedToken.AppID, token.AppID)
				assert.Equal(t, tt.expectedToken.Owner, token.Owner)
				assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
				assert.Equal(t, tt.expectedToken.Type, token.Type)
				// Compare instants with Equal rather than ==/DeepEqual so that
				// location or monotonic-clock differences do not cause failures.
				assert.True(t, tt.expectedToken.CreatedAt.Equal(token.CreatedAt),
					"CreatedAt: want %v, got %v", tt.expectedToken.CreatedAt, token.CreatedAt)
				assert.True(t, tt.expectedToken.UpdatedAt.Equal(token.UpdatedAt),
					"UpdatedAt: want %v, got %v", tt.expectedToken.UpdatedAt, token.UpdatedAt)
			}
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}
// TestStaticTokenRepository_GetByKeyHash checks two things:
//   - GetByKeyHash maps the static_tokens row selected by key_hash onto a
//     domain.StaticToken.
//   - sql.ErrNoRows is reported as a "not found" error.
//
// The success case asserts the same fields as TestStaticTokenRepository_GetByID.
func TestStaticTokenRepository_GetByKeyHash(t *testing.T) {
	tokenID := uuid.New()
	now := time.Now()
	keyHash := "test-hash"
	tests := []struct {
		name          string
		keyHash       string
		setupMock     func(sqlmock.Sqlmock)
		expectError   bool
		errorMsg      string
		expectedToken *domain.StaticToken
	}{
		{
			name:    "successful retrieval",
			keyHash: keyHash,
			setupMock: func(mock sqlmock.Sqlmock) {
				rows := sqlmock.NewRows([]string{
					"id", "app_id", "owner_type", "owner_name", "owner_owner",
					"key_hash", "type", "created_at", "updated_at",
				}).AddRow(
					tokenID, "test-app", "individual", "test-user", "test-owner",
					keyHash, "user", now, now,
				)
				mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
					WithArgs(keyHash).
					WillReturnRows(rows)
			},
			expectError: false,
			expectedToken: &domain.StaticToken{
				ID:    tokenID,
				AppID: "test-app",
				Owner: domain.Owner{
					Type:  domain.OwnerTypeIndividual,
					Name:  "test-user",
					Owner: "test-owner",
				},
				KeyHash:   keyHash,
				Type:      string(domain.TokenTypeUser),
				CreatedAt: now,
				UpdatedAt: now,
			},
		},
		{
			name:    "token not found",
			keyHash: keyHash,
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
					WithArgs(keyHash).
					WillReturnError(sql.ErrNoRows)
			},
			expectError: true,
			errorMsg:    "not found",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			repo, mock, cleanup := setupTokenRepositoryTest(t)
			defer cleanup()
			tt.setupMock(mock)

			ctx := context.Background()
			token, err := repo.GetByKeyHash(ctx, tt.keyHash)
			if tt.expectError {
				// require: a nil error would otherwise panic on err.Error().
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.errorMsg)
				assert.Nil(t, token)
			} else {
				assert.NoError(t, err)
				// require: the field checks below dereference token.
				require.NotNil(t, token)
				assert.Equal(t, tt.expectedToken.ID, token.ID)
				assert.Equal(t, tt.expectedToken.AppID, token.AppID)
				assert.Equal(t, tt.expectedToken.Owner, token.Owner)
				assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
				assert.Equal(t, tt.expectedToken.Type, token.Type)
			}
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}
// TestStaticTokenRepository_GetByAppID checks three things:
//   - GetByAppID returns every row for the app (query ordered by
//     created_at DESC), each carrying that app ID.
//   - An empty result set yields zero tokens without error.
//   - Driver failures are wrapped with "failed to query static tokens".
func TestStaticTokenRepository_GetByAppID(t *testing.T) {
	tokenID1 := uuid.New()
	tokenID2 := uuid.New()
	now := time.Now()
	appID := "test-app"
	tests := []struct {
		name          string
		appID         string
		setupMock     func(sqlmock.Sqlmock)
		expectError   bool
		errorMsg      string
		expectedCount int
	}{
		{
			name:  "successful retrieval with multiple tokens",
			appID: appID,
			setupMock: func(mock sqlmock.Sqlmock) {
				rows := sqlmock.NewRows([]string{
					"id", "app_id", "owner_type", "owner_name", "owner_owner",
					"key_hash", "type", "created_at", "updated_at",
				}).AddRow(
					tokenID1, appID, "user", "test-user1", "test-owner1",
					"test-hash1", "user", now, now,
				).AddRow(
					tokenID2, appID, "user", "test-user2", "test-owner2",
					"test-hash2", "user", now, now,
				)
				mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
					WithArgs(appID).
					WillReturnRows(rows)
			},
			expectError:   false,
			expectedCount: 2,
		},
		{
			name:  "no tokens found",
			appID: appID,
			setupMock: func(mock sqlmock.Sqlmock) {
				rows := sqlmock.NewRows([]string{
					"id", "app_id", "owner_type", "owner_name", "owner_owner",
					"key_hash", "type", "created_at", "updated_at",
				})
				mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
					WithArgs(appID).
					WillReturnRows(rows)
			},
			expectError:   false,
			expectedCount: 0,
		},
		{
			name:  "database error",
			appID: appID,
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
					WithArgs(appID).
					WillReturnError(sql.ErrConnDone)
			},
			expectError: true,
			errorMsg:    "failed to query static tokens",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			repo, mock, cleanup := setupTokenRepositoryTest(t)
			defer cleanup()
			tt.setupMock(mock)

			ctx := context.Background()
			tokens, err := repo.GetByAppID(ctx, tt.appID)
			if tt.expectError {
				// require: a nil error would otherwise panic on err.Error().
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.errorMsg)
				assert.Nil(t, tokens)
			} else {
				assert.NoError(t, err)
				assert.Len(t, tokens, tt.expectedCount)
				// Every mocked row uses the queried app ID, so each mapped token must too.
				for _, tok := range tokens {
					assert.Equal(t, tt.appID, tok.AppID)
				}
			}
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}
// TestStaticTokenRepository_List checks two things:
//   - List passes limit and offset as the LIMIT/OFFSET parameters of a
//     created_at DESC query and maps the returned rows.
//   - Driver failures are wrapped with "failed to query static tokens".
func TestStaticTokenRepository_List(t *testing.T) {
	tokenID := uuid.New()
	now := time.Now()
	tests := []struct {
		name          string
		limit         int
		offset        int
		setupMock     func(sqlmock.Sqlmock)
		expectError   bool
		errorMsg      string
		expectedCount int
	}{
		{
			name:   "successful list with pagination",
			limit:  10,
			offset: 0,
			setupMock: func(mock sqlmock.Sqlmock) {
				rows := sqlmock.NewRows([]string{
					"id", "app_id", "owner_type", "owner_name", "owner_owner",
					"key_hash", "type", "created_at", "updated_at",
				}).AddRow(
					tokenID, "test-app", "user", "test-user", "test-owner",
					"test-hash", "user", now, now,
				)
				mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
					WithArgs(10, 0).
					WillReturnRows(rows)
			},
			expectError:   false,
			expectedCount: 1,
		},
		{
			name:   "database error",
			limit:  10,
			offset: 0,
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
					WithArgs(10, 0).
					WillReturnError(sql.ErrConnDone)
			},
			expectError: true,
			errorMsg:    "failed to query static tokens",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			repo, mock, cleanup := setupTokenRepositoryTest(t)
			defer cleanup()
			tt.setupMock(mock)

			ctx := context.Background()
			tokens, err := repo.List(ctx, tt.limit, tt.offset)
			if tt.expectError {
				// require: a nil error would otherwise panic on err.Error().
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.errorMsg)
				assert.Nil(t, tokens)
			} else {
				assert.NoError(t, err)
				assert.Len(t, tokens, tt.expectedCount)
			}
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}
// TestStaticTokenRepository_Delete checks three things:
//   - Delete issues a DELETE by id.
//   - A result with zero rows affected is reported as "not found".
//   - Driver failures are wrapped with "failed to delete static token".
func TestStaticTokenRepository_Delete(t *testing.T) {
	tokenID := uuid.New()
	tests := []struct {
		name        string
		tokenID     uuid.UUID
		setupMock   func(sqlmock.Sqlmock)
		expectError bool
		errorMsg    string
	}{
		{
			name:    "successful deletion",
			tokenID: tokenID,
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
					WithArgs(tokenID).
					WillReturnResult(sqlmock.NewResult(0, 1))
			},
			expectError: false,
		},
		{
			name:    "token not found",
			tokenID: tokenID,
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
					WithArgs(tokenID).
					WillReturnResult(sqlmock.NewResult(0, 0))
			},
			expectError: true,
			errorMsg:    "not found",
		},
		{
			name:    "database error",
			tokenID: tokenID,
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
					WithArgs(tokenID).
					WillReturnError(sql.ErrConnDone)
			},
			expectError: true,
			errorMsg:    "failed to delete static token",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			repo, mock, cleanup := setupTokenRepositoryTest(t)
			defer cleanup()
			tt.setupMock(mock)

			ctx := context.Background()
			err := repo.Delete(ctx, tt.tokenID)
			if tt.expectError {
				// require: a nil error would otherwise panic on err.Error().
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.errorMsg)
			} else {
				assert.NoError(t, err)
			}
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}
// TestStaticTokenRepository_Exists checks three things:
//   - Exists reports true when the `SELECT 1 ... WHERE id = $1` probe returns a row.
//   - Exists reports false with no error when the probe yields sql.ErrNoRows.
//   - Other driver failures are wrapped with
//     "failed to check static token existence".
func TestStaticTokenRepository_Exists(t *testing.T) {
	tokenID := uuid.New()
	tests := []struct {
		name           string
		tokenID        uuid.UUID
		setupMock      func(sqlmock.Sqlmock)
		expectError    bool
		errorMsg       string
		expectedExists bool
	}{
		{
			name:    "token exists",
			tokenID: tokenID,
			setupMock: func(mock sqlmock.Sqlmock) {
				rows := sqlmock.NewRows([]string{"exists"}).AddRow(1)
				mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
					WithArgs(tokenID).
					WillReturnRows(rows)
			},
			expectError:    false,
			expectedExists: true,
		},
		{
			name:    "token does not exist",
			tokenID: tokenID,
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
					WithArgs(tokenID).
					WillReturnError(sql.ErrNoRows)
			},
			expectError:    false,
			expectedExists: false,
		},
		{
			name:    "database error",
			tokenID: tokenID,
			setupMock: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
					WithArgs(tokenID).
					WillReturnError(sql.ErrConnDone)
			},
			expectError: true,
			errorMsg:    "failed to check static token existence",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			repo, mock, cleanup := setupTokenRepositoryTest(t)
			defer cleanup()
			tt.setupMock(mock)

			ctx := context.Background()
			exists, err := repo.Exists(ctx, tt.tokenID)
			if tt.expectError {
				// require: a nil error would otherwise panic on err.Error().
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.errorMsg)
			} else {
				assert.NoError(t, err)
				assert.Equal(t, tt.expectedExists, exists)
			}
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}
// Benchmark tests for repository operations.

// BenchmarkStaticTokenRepository_Create measures repo.Create against a sqlmock
// connection. One INSERT expectation per iteration is queued up front. The
// timer is reset afterwards, so that setup is not measured.
//
// NOTE(review): sqlmock appears to walk its expectation list from the start on
// every call, skipping fulfilled entries. If so, per-op cost grows with b.N and
// absolute numbers are inflated. Confirm against the sqlmock version in use.
func BenchmarkStaticTokenRepository_Create(b *testing.B) {
	repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
	defer cleanup()

	token := &domain.StaticToken{
		ID:    uuid.New(),
		AppID: "test-app",
		Owner: domain.Owner{
			Type:  domain.OwnerTypeIndividual,
			Name:  "test-user",
			Owner: "test-owner",
		},
		KeyHash: "test-hash",
		Type:    string(domain.TokenTypeUser),
	}

	expectInsert := func() {
		mock.ExpectExec(`INSERT INTO static_tokens`).
			WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "user", sqlmock.AnyArg(), sqlmock.AnyArg()).
			WillReturnResult(sqlmock.NewResult(1, 1))
	}
	for remaining := b.N; remaining > 0; remaining-- {
		expectInsert()
	}

	ctx := context.Background()
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		// Fresh ID per insert, matching how real callers create distinct tokens.
		token.ID = uuid.New()
		if createErr := repo.Create(ctx, token); createErr != nil {
			b.Fatal(createErr)
		}
	}
}
// BenchmarkStaticTokenRepository_GetByID measures repo.GetByID for one fixed
// token ID against sqlmock. A one-row result set and a matching SELECT
// expectation are queued per iteration before the timer is reset.
//
// NOTE(review): as with the Create benchmark, sqlmock's per-call expectation
// lookup may grow with b.N. Confirm before comparing absolute numbers.
func BenchmarkStaticTokenRepository_GetByID(b *testing.B) {
	repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
	defer cleanup()

	tokenID := uuid.New()
	now := time.Now()

	for remaining := b.N; remaining > 0; remaining-- {
		columns := []string{
			"id", "app_id", "owner_type", "owner_name", "owner_owner",
			"key_hash", "type", "created_at", "updated_at",
		}
		row := sqlmock.NewRows(columns).AddRow(
			tokenID, "test-app", "user", "test-user", "test-owner",
			"test-hash", "user", now, now,
		)
		mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
			WithArgs(tokenID).
			WillReturnRows(row)
	}

	ctx := context.Background()
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		if _, getErr := repo.GetByID(ctx, tokenID); getErr != nil {
			b.Fatal(getErr)
		}
	}
}