This commit is contained in:
2025-08-22 18:57:40 -04:00
parent d648a55c0c
commit df567983c1
20 changed files with 4519 additions and 8 deletions

532
test/saml_test.go Normal file
View File

@ -0,0 +1,532 @@
package test
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"github.com/kms/api-key-service/internal/auth"
"github.com/kms/api-key-service/internal/domain"
)
// mockSAMLMetadata returns a mock SAML IdP metadata XML
func mockSAMLMetadata() string {
return `<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="https://idp.example.com">
<md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:KeyDescriptor use="signing">
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
<ds:X509Data>
<ds:X509Certificate>MIICertificateData</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/slo"/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>`
}
// mockSAMLResponse returns a mock SAML response XML with current timestamps
func mockSAMLResponse() string {
now := time.Now().UTC()
issueInstant := now.Format(time.RFC3339)
notBefore := now.Add(-5 * time.Minute).Format(time.RFC3339)
notOnOrAfter := now.Add(60 * time.Minute).Format(time.RFC3339)
return fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
ID="_response_id" Version="2.0" IssueInstant="%s"
Destination="https://sp.example.com/acs" InResponseTo="_request_id">
<saml:Issuer>https://idp.example.com</saml:Issuer>
<samlp:Status>
<samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
</samlp:Status>
<saml:Assertion ID="_assertion_id" Version="2.0" IssueInstant="%s">
<saml:Issuer>https://idp.example.com</saml:Issuer>
<saml:Subject>
<saml:NameID Format="urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress">user@example.com</saml:NameID>
<saml:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
<saml:SubjectConfirmationData InResponseTo="_request_id" NotOnOrAfter="%s" Recipient="https://sp.example.com/acs"/>
</saml:SubjectConfirmation>
</saml:Subject>
<saml:Conditions NotBefore="%s" NotOnOrAfter="%s">
<saml:AudienceRestriction>
<saml:Audience>https://sp.example.com</saml:Audience>
</saml:AudienceRestriction>
</saml:Conditions>
<saml:AttributeStatement>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress">
<saml:AttributeValue>user@example.com</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name">
<saml:AttributeValue>Test User</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname">
<saml:AttributeValue>Test</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname">
<saml:AttributeValue>User</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="http://schemas.microsoft.com/ws/2008/06/identity/claims/role">
<saml:AttributeValue>admin,user</saml:AttributeValue>
</saml:Attribute>
</saml:AttributeStatement>
<saml:AuthnStatement AuthnInstant="%s" SessionIndex="_session_index">
<saml:AuthnContext>
<saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</saml:AuthnContextClassRef>
</saml:AuthnContext>
</saml:AuthnStatement>
</saml:Assertion>
</samlp:Response>`, issueInstant, issueInstant, notOnOrAfter, notBefore, notOnOrAfter, issueInstant)
}
func TestSAMLProvider_GetMetadata(t *testing.T) {
tests := []struct {
name string
metadataURL string
serverResponse string
serverStatus int
expectError bool
errorContains string
}{
{
name: "successful metadata fetch",
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
serverResponse: mockSAMLMetadata(),
serverStatus: http.StatusOK,
expectError: false,
},
{
name: "missing metadata URL",
metadataURL: "",
expectError: true,
errorContains: "SAML_IDP_METADATA_URL not configured",
},
{
name: "server error",
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
serverStatus: http.StatusInternalServerError,
expectError: true,
errorContains: "returned status 500",
},
{
name: "invalid XML",
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
serverResponse: "invalid xml",
serverStatus: http.StatusOK,
expectError: true,
errorContains: "Failed to parse SAML metadata",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock HTTP server
var server *httptest.Server
if tt.metadataURL != "" && tt.serverStatus > 0 {
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.serverStatus)
if tt.serverResponse != "" {
w.Write([]byte(tt.serverResponse))
}
}))
defer server.Close()
tt.metadataURL = server.URL
}
// Create config
cfg := NewTestConfig()
cfg.values["SAML_IDP_METADATA_URL"] = tt.metadataURL
// Create SAML provider
logger := zaptest.NewLogger(t)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(t, err)
// Test GetMetadata
ctx := context.Background()
metadata, err := provider.GetMetadata(ctx)
if tt.expectError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Nil(t, metadata)
} else {
assert.NoError(t, err)
assert.NotNil(t, metadata)
assert.Equal(t, "https://idp.example.com", metadata.EntityID)
assert.NotEmpty(t, metadata.IDPSSODescriptor.SingleSignOnService)
}
})
}
}
func TestSAMLProvider_GenerateAuthRequest(t *testing.T) {
tests := []struct {
name string
spEntityID string
acsURL string
relayState string
expectError bool
errorContains string
}{
{
name: "successful auth request generation",
spEntityID: "https://sp.example.com",
acsURL: "https://sp.example.com/acs",
relayState: "test-relay-state",
},
{
name: "missing SP entity ID",
spEntityID: "",
acsURL: "https://sp.example.com/acs",
expectError: true,
errorContains: "SAML_SP_ENTITY_ID not configured",
},
{
name: "missing ACS URL",
spEntityID: "https://sp.example.com",
acsURL: "",
expectError: true,
errorContains: "SAML_SP_ACS_URL not configured",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock HTTP server for metadata
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockSAMLMetadata()))
}))
defer server.Close()
// Create config
cfg := NewTestConfig()
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
// Create SAML provider
logger := zaptest.NewLogger(t)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(t, err)
// Test GenerateAuthRequest
ctx := context.Background()
authURL, requestID, err := provider.GenerateAuthRequest(ctx, tt.relayState)
if tt.expectError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Empty(t, authURL)
assert.Empty(t, requestID)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, authURL)
assert.NotEmpty(t, requestID)
assert.Contains(t, authURL, "https://idp.example.com/sso")
assert.Contains(t, authURL, "SAMLRequest=")
if tt.relayState != "" {
assert.Contains(t, authURL, "RelayState="+tt.relayState)
}
}
})
}
}
func TestSAMLProvider_ProcessSAMLResponse(t *testing.T) {
tests := []struct {
name string
samlResponse string
expectedRequestID string
spEntityID string
expectError bool
errorContains string
expectedUserID string
expectedEmail string
expectedName string
expectedRoles []string
}{
{
name: "successful SAML response processing",
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
expectedRequestID: "_request_id",
spEntityID: "https://sp.example.com",
expectedUserID: "user@example.com",
expectedEmail: "user@example.com",
expectedName: "Test User",
expectedRoles: []string{"admin", "user"},
},
{
name: "invalid base64 encoding",
samlResponse: "invalid-base64",
expectError: true,
errorContains: "Failed to decode SAML response",
},
{
name: "invalid XML",
samlResponse: base64.StdEncoding.EncodeToString([]byte("invalid xml")),
expectError: true,
errorContains: "Failed to parse SAML response",
},
{
name: "audience mismatch",
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
spEntityID: "https://wrong-sp.example.com",
expectError: true,
errorContains: "audience mismatch",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create config
cfg := NewTestConfig()
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
// Create SAML provider
logger := zaptest.NewLogger(t)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(t, err)
// Test ProcessSAMLResponse
ctx := context.Background()
authContext, err := provider.ProcessSAMLResponse(ctx, tt.samlResponse, tt.expectedRequestID)
if tt.expectError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Nil(t, authContext)
} else {
assert.NoError(t, err)
assert.NotNil(t, authContext)
assert.Equal(t, tt.expectedUserID, authContext.UserID)
assert.Equal(t, domain.TokenTypeUser, authContext.TokenType)
// Check claims
if tt.expectedEmail != "" {
assert.Equal(t, tt.expectedEmail, authContext.Claims["email"])
}
if tt.expectedName != "" {
assert.Equal(t, tt.expectedName, authContext.Claims["name"])
}
// Check permissions/roles
if len(tt.expectedRoles) > 0 {
assert.Equal(t, tt.expectedRoles, authContext.Permissions)
}
}
})
}
}
func TestSAMLProvider_GenerateServiceProviderMetadata(t *testing.T) {
tests := []struct {
name string
spEntityID string
acsURL string
expectError bool
errorContains string
}{
{
name: "successful SP metadata generation",
spEntityID: "https://sp.example.com",
acsURL: "https://sp.example.com/acs",
},
{
name: "missing SP entity ID",
spEntityID: "",
acsURL: "https://sp.example.com/acs",
expectError: true,
errorContains: "SAML_SP_ENTITY_ID not configured",
},
{
name: "missing ACS URL",
spEntityID: "https://sp.example.com",
acsURL: "",
expectError: true,
errorContains: "SAML_SP_ACS_URL not configured",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create config
cfg := NewTestConfig()
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
// Create SAML provider
logger := zaptest.NewLogger(t)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(t, err)
// Test GenerateServiceProviderMetadata
metadata, err := provider.GenerateServiceProviderMetadata()
if tt.expectError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Empty(t, metadata)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, metadata)
assert.Contains(t, metadata, tt.spEntityID)
assert.Contains(t, metadata, tt.acsURL)
assert.Contains(t, metadata, "EntityDescriptor")
assert.Contains(t, metadata, "SPSSODescriptor")
}
})
}
}
// Benchmark tests for SAML operations
func BenchmarkSAMLProvider_ProcessSAMLResponse(b *testing.B) {
// Create config
cfg := NewTestConfig()
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
// Create SAML provider
logger := zaptest.NewLogger(b)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(b, err)
// Prepare SAML response
samlResponse := base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse()))
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSAMLProvider_GenerateAuthRequest(b *testing.B) {
// Create mock HTTP server for metadata
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockSAMLMetadata()))
}))
defer server.Close()
// Create config
cfg := NewTestConfig()
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
cfg.values["SAML_SP_ACS_URL"] = "https://sp.example.com/acs"
// Create SAML provider
logger := zaptest.NewLogger(b)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(b, err)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, err := provider.GenerateAuthRequest(ctx, "test-relay-state")
if err != nil {
b.Fatal(err)
}
}
}
// Test helper functions
func TestSAMLResponseValidation(t *testing.T) {
// Test various SAML response validation scenarios
tests := []struct {
name string
modifyXML func(string) string
expectError bool
errorContains string
}{
{
name: "expired assertion",
modifyXML: func(xml string) string {
// Replace all NotOnOrAfter timestamps with past time
pastTime := "2020-01-01T13:00:00Z"
re := regexp.MustCompile(`NotOnOrAfter="[^"]*"`)
return re.ReplaceAllString(xml, `NotOnOrAfter="`+pastTime+`"`)
},
expectError: true,
errorContains: "assertion has expired",
},
{
name: "assertion not yet valid",
modifyXML: func(xml string) string {
// Replace all NotBefore timestamps with future time
futureTime := "2030-01-01T11:55:00Z"
re := regexp.MustCompile(`NotBefore="[^"]*"`)
return re.ReplaceAllString(xml, `NotBefore="`+futureTime+`"`)
},
expectError: true,
errorContains: "assertion not yet valid",
},
{
name: "failed status",
modifyXML: func(xml string) string {
return strings.ReplaceAll(xml,
"urn:oasis:names:tc:SAML:2.0:status:Success",
"urn:oasis:names:tc:SAML:2.0:status:AuthnFailed")
},
expectError: true,
errorContains: "SAML authentication failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create config
cfg := NewTestConfig()
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
// Create SAML provider
logger := zaptest.NewLogger(t)
provider, err := auth.NewSAMLProvider(cfg, logger)
require.NoError(t, err)
// Modify SAML response
modifiedXML := tt.modifyXML(mockSAMLResponse())
samlResponse := base64.StdEncoding.EncodeToString([]byte(modifiedXML))
// Test ProcessSAMLResponse
ctx := context.Background()
authContext, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
if tt.expectError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Nil(t, authContext)
} else {
assert.NoError(t, err)
assert.NotNil(t, authContext)
}
})
}
}

View File

@ -0,0 +1,705 @@
package test
import (
"context"
"database/sql"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/repository"
"github.com/kms/api-key-service/internal/repository/postgres"
)
// SQLMockDatabaseProvider implements repository.DatabaseProvider for SQL testing
type SQLMockDatabaseProvider struct {
db *sql.DB
}
func (m *SQLMockDatabaseProvider) GetDB() interface{} {
return m.db
}
func (m *SQLMockDatabaseProvider) Ping(ctx context.Context) error {
return m.db.PingContext(ctx)
}
func (m *SQLMockDatabaseProvider) Close() error {
return m.db.Close()
}
func (m *SQLMockDatabaseProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) {
tx, err := m.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
return &SQLMockTransactionProvider{tx: tx}, nil
}
func (m *SQLMockDatabaseProvider) Migrate(ctx context.Context, migrationPath string) error {
return nil
}
// SQLMockTransactionProvider implements repository.TransactionProvider for SQL testing
type SQLMockTransactionProvider struct {
tx *sql.Tx
}
func (m *SQLMockTransactionProvider) Commit() error {
return m.tx.Commit()
}
func (m *SQLMockTransactionProvider) Rollback() error {
return m.tx.Rollback()
}
func (m *SQLMockTransactionProvider) GetTx() interface{} {
return m.tx
}
func setupTokenRepositoryTest(t *testing.T) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
mockDB := &SQLMockDatabaseProvider{db: db}
repo := postgres.NewStaticTokenRepository(mockDB)
cleanup := func() {
db.Close()
}
return repo.(*postgres.StaticTokenRepository), mock, cleanup
}
func setupTokenRepositoryTestBenchmark(b *testing.B) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New()
if err != nil {
b.Fatal(err)
}
mockDB := &SQLMockDatabaseProvider{db: db}
repo := postgres.NewStaticTokenRepository(mockDB)
cleanup := func() {
db.Close()
}
return repo.(*postgres.StaticTokenRepository), mock, cleanup
}
func TestStaticTokenRepository_Create(t *testing.T) {
tests := []struct {
name string
token *domain.StaticToken
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
}{
{
name: "successful creation",
token: &domain.StaticToken{
ID: uuid.New(),
AppID: "test-app",
Owner: domain.Owner{
Type: domain.OwnerTypeIndividual,
Name: "test-user",
Owner: "test-owner",
},
KeyHash: "test-hash",
Type: "hmac",
},
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(`INSERT INTO static_tokens`).
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
},
expectError: false,
},
{
name: "database error",
token: &domain.StaticToken{
ID: uuid.New(),
AppID: "test-app",
Owner: domain.Owner{
Type: domain.OwnerTypeIndividual,
Name: "test-user",
Owner: "test-owner",
},
KeyHash: "test-hash",
Type: "hmac",
},
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(`INSERT INTO static_tokens`).
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to create static token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
err := repo.Create(ctx, tt.token)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
assert.NoError(t, err)
assert.NotZero(t, tt.token.CreatedAt)
assert.NotZero(t, tt.token.UpdatedAt)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_GetByID(t *testing.T) {
tokenID := uuid.New()
now := time.Now()
tests := []struct {
name string
tokenID uuid.UUID
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
expectedToken *domain.StaticToken
}{
{
name: "successful retrieval",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
}).AddRow(
tokenID, "test-app", "individual", "test-user", "test-owner",
"test-hash", "user", now, now,
)
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnRows(rows)
},
expectError: false,
expectedToken: &domain.StaticToken{
ID: tokenID,
AppID: "test-app",
Owner: domain.Owner{
Type: domain.OwnerTypeIndividual,
Name: "test-user",
Owner: "test-owner",
},
KeyHash: "test-hash",
Type: string(domain.TokenTypeUser),
CreatedAt: now,
UpdatedAt: now,
},
},
{
name: "token not found",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnError(sql.ErrNoRows)
},
expectError: true,
errorMsg: "not found",
},
{
name: "database error",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to get static token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
token, err := repo.GetByID(ctx, tt.tokenID)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
assert.Nil(t, token)
} else {
assert.NoError(t, err)
assert.NotNil(t, token)
assert.Equal(t, tt.expectedToken.ID, token.ID)
assert.Equal(t, tt.expectedToken.AppID, token.AppID)
assert.Equal(t, tt.expectedToken.Owner, token.Owner)
assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
assert.Equal(t, tt.expectedToken.Type, token.Type)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_GetByKeyHash(t *testing.T) {
tokenID := uuid.New()
now := time.Now()
keyHash := "test-hash"
tests := []struct {
name string
keyHash string
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
expectedToken *domain.StaticToken
}{
{
name: "successful retrieval",
keyHash: keyHash,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
}).AddRow(
tokenID, "test-app", "individual", "test-user", "test-owner",
keyHash, "user", now, now,
)
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
WithArgs(keyHash).
WillReturnRows(rows)
},
expectError: false,
expectedToken: &domain.StaticToken{
ID: tokenID,
AppID: "test-app",
Owner: domain.Owner{
Type: domain.OwnerTypeIndividual,
Name: "test-user",
Owner: "test-owner",
},
KeyHash: keyHash,
Type: string(domain.TokenTypeUser),
CreatedAt: now,
UpdatedAt: now,
},
},
{
name: "token not found",
keyHash: keyHash,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
WithArgs(keyHash).
WillReturnError(sql.ErrNoRows)
},
expectError: true,
errorMsg: "not found",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
token, err := repo.GetByKeyHash(ctx, tt.keyHash)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
assert.Nil(t, token)
} else {
assert.NoError(t, err)
assert.NotNil(t, token)
assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_GetByAppID(t *testing.T) {
tokenID1 := uuid.New()
tokenID2 := uuid.New()
now := time.Now()
appID := "test-app"
tests := []struct {
name string
appID string
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
expectedCount int
}{
{
name: "successful retrieval with multiple tokens",
appID: appID,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
}).AddRow(
tokenID1, appID, "user", "test-user1", "test-owner1",
"test-hash1", "user", now, now,
).AddRow(
tokenID2, appID, "user", "test-user2", "test-owner2",
"test-hash2", "user", now, now,
)
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
WithArgs(appID).
WillReturnRows(rows)
},
expectError: false,
expectedCount: 2,
},
{
name: "no tokens found",
appID: appID,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
})
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
WithArgs(appID).
WillReturnRows(rows)
},
expectError: false,
expectedCount: 0,
},
{
name: "database error",
appID: appID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
WithArgs(appID).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to query static tokens",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
tokens, err := repo.GetByAppID(ctx, tt.appID)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
assert.Nil(t, tokens)
} else {
assert.NoError(t, err)
assert.Len(t, tokens, tt.expectedCount)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_List(t *testing.T) {
tokenID := uuid.New()
now := time.Now()
tests := []struct {
name string
limit int
offset int
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
expectedCount int
}{
{
name: "successful list with pagination",
limit: 10,
offset: 0,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
}).AddRow(
tokenID, "test-app", "user", "test-user", "test-owner",
"test-hash", "user", now, now,
)
mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
WithArgs(10, 0).
WillReturnRows(rows)
},
expectError: false,
expectedCount: 1,
},
{
name: "database error",
limit: 10,
offset: 0,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
WithArgs(10, 0).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to query static tokens",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
tokens, err := repo.List(ctx, tt.limit, tt.offset)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
assert.Nil(t, tokens)
} else {
assert.NoError(t, err)
assert.Len(t, tokens, tt.expectedCount)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_Delete(t *testing.T) {
tokenID := uuid.New()
tests := []struct {
name string
tokenID uuid.UUID
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
}{
{
name: "successful deletion",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnResult(sqlmock.NewResult(0, 1))
},
expectError: false,
},
{
name: "token not found",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnResult(sqlmock.NewResult(0, 0))
},
expectError: true,
errorMsg: "not found",
},
{
name: "database error",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to delete static token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
err := repo.Delete(ctx, tt.tokenID)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
assert.NoError(t, err)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
func TestStaticTokenRepository_Exists(t *testing.T) {
tokenID := uuid.New()
tests := []struct {
name string
tokenID uuid.UUID
setupMock func(sqlmock.Sqlmock)
expectError bool
errorMsg string
expectedExists bool
}{
{
name: "token exists",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"exists"}).AddRow(1)
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnRows(rows)
},
expectError: false,
expectedExists: true,
},
{
name: "token does not exist",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnError(sql.ErrNoRows)
},
expectError: false,
expectedExists: false,
},
{
name: "database error",
tokenID: tokenID,
setupMock: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnError(sql.ErrConnDone)
},
expectError: true,
errorMsg: "failed to check static token existence",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, mock, cleanup := setupTokenRepositoryTest(t)
defer cleanup()
tt.setupMock(mock)
ctx := context.Background()
exists, err := repo.Exists(ctx, tt.tokenID)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedExists, exists)
}
assert.NoError(t, mock.ExpectationsWereMet())
})
}
}
// Benchmark tests for repository operations
func BenchmarkStaticTokenRepository_Create(b *testing.B) {
repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
defer cleanup()
token := &domain.StaticToken{
ID: uuid.New(),
AppID: "test-app",
Owner: domain.Owner{
Type: domain.OwnerTypeIndividual,
Name: "test-user",
Owner: "test-owner",
},
KeyHash: "test-hash",
Type: string(domain.TokenTypeUser),
}
// Setup mock expectations for all iterations
for i := 0; i < b.N; i++ {
mock.ExpectExec(`INSERT INTO static_tokens`).
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "user", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
}
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
token.ID = uuid.New() // Generate new ID for each iteration
err := repo.Create(ctx, token)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkStaticTokenRepository_GetByID(b *testing.B) {
repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
defer cleanup()
tokenID := uuid.New()
now := time.Now()
// Setup mock expectations for all iterations
for i := 0; i < b.N; i++ {
rows := sqlmock.NewRows([]string{
"id", "app_id", "owner_type", "owner_name", "owner_owner",
"key_hash", "type", "created_at", "updated_at",
}).AddRow(
tokenID, "test-app", "user", "test-user", "test-owner",
"test-hash", "user", now, now,
)
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
WithArgs(tokenID).
WillReturnRows(rows)
}
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := repo.GetByID(ctx, tokenID)
if err != nil {
b.Fatal(err)
}
}
}