v2
This commit is contained in:
532
test/saml_test.go
Normal file
532
test/saml_test.go
Normal file
@ -0,0 +1,532 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
)
|
||||
|
||||
// mockSAMLMetadata returns a mock SAML IdP metadata XML
|
||||
func mockSAMLMetadata() string {
|
||||
return `<?xml version="1.0" encoding="UTF-8"?>
|
||||
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="https://idp.example.com">
|
||||
<md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<md:KeyDescriptor use="signing">
|
||||
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
|
||||
<ds:X509Data>
|
||||
<ds:X509Certificate>MIICertificateData</ds:X509Certificate>
|
||||
</ds:X509Data>
|
||||
</ds:KeyInfo>
|
||||
</md:KeyDescriptor>
|
||||
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
|
||||
<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/slo"/>
|
||||
</md:IDPSSODescriptor>
|
||||
</md:EntityDescriptor>`
|
||||
}
|
||||
|
||||
// mockSAMLResponse returns a mock SAML response XML with current timestamps
|
||||
func mockSAMLResponse() string {
|
||||
now := time.Now().UTC()
|
||||
issueInstant := now.Format(time.RFC3339)
|
||||
notBefore := now.Add(-5 * time.Minute).Format(time.RFC3339)
|
||||
notOnOrAfter := now.Add(60 * time.Minute).Format(time.RFC3339)
|
||||
|
||||
return fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
|
||||
<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||
xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
|
||||
ID="_response_id" Version="2.0" IssueInstant="%s"
|
||||
Destination="https://sp.example.com/acs" InResponseTo="_request_id">
|
||||
<saml:Issuer>https://idp.example.com</saml:Issuer>
|
||||
<samlp:Status>
|
||||
<samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
|
||||
</samlp:Status>
|
||||
<saml:Assertion ID="_assertion_id" Version="2.0" IssueInstant="%s">
|
||||
<saml:Issuer>https://idp.example.com</saml:Issuer>
|
||||
<saml:Subject>
|
||||
<saml:NameID Format="urn:oasis:names:tc:SAML:2.0:nameid-format:emailAddress">user@example.com</saml:NameID>
|
||||
<saml:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
|
||||
<saml:SubjectConfirmationData InResponseTo="_request_id" NotOnOrAfter="%s" Recipient="https://sp.example.com/acs"/>
|
||||
</saml:SubjectConfirmation>
|
||||
</saml:Subject>
|
||||
<saml:Conditions NotBefore="%s" NotOnOrAfter="%s">
|
||||
<saml:AudienceRestriction>
|
||||
<saml:Audience>https://sp.example.com</saml:Audience>
|
||||
</saml:AudienceRestriction>
|
||||
</saml:Conditions>
|
||||
<saml:AttributeStatement>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress">
|
||||
<saml:AttributeValue>user@example.com</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name">
|
||||
<saml:AttributeValue>Test User</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname">
|
||||
<saml:AttributeValue>Test</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.xmlsoap.org/ws/2005/05/identity/claims/surname">
|
||||
<saml:AttributeValue>User</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
<saml:Attribute Name="http://schemas.microsoft.com/ws/2008/06/identity/claims/role">
|
||||
<saml:AttributeValue>admin,user</saml:AttributeValue>
|
||||
</saml:Attribute>
|
||||
</saml:AttributeStatement>
|
||||
<saml:AuthnStatement AuthnInstant="%s" SessionIndex="_session_index">
|
||||
<saml:AuthnContext>
|
||||
<saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</saml:AuthnContextClassRef>
|
||||
</saml:AuthnContext>
|
||||
</saml:AuthnStatement>
|
||||
</saml:Assertion>
|
||||
</samlp:Response>`, issueInstant, issueInstant, notOnOrAfter, notBefore, notOnOrAfter, issueInstant)
|
||||
}
|
||||
|
||||
func TestSAMLProvider_GetMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadataURL string
|
||||
serverResponse string
|
||||
serverStatus int
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "successful metadata fetch",
|
||||
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
||||
serverResponse: mockSAMLMetadata(),
|
||||
serverStatus: http.StatusOK,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing metadata URL",
|
||||
metadataURL: "",
|
||||
expectError: true,
|
||||
errorContains: "SAML_IDP_METADATA_URL not configured",
|
||||
},
|
||||
{
|
||||
name: "server error",
|
||||
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
||||
serverStatus: http.StatusInternalServerError,
|
||||
expectError: true,
|
||||
errorContains: "returned status 500",
|
||||
},
|
||||
{
|
||||
name: "invalid XML",
|
||||
metadataURL: "https://idp.example.com/.well-known/saml-metadata",
|
||||
serverResponse: "invalid xml",
|
||||
serverStatus: http.StatusOK,
|
||||
expectError: true,
|
||||
errorContains: "Failed to parse SAML metadata",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock HTTP server
|
||||
var server *httptest.Server
|
||||
if tt.metadataURL != "" && tt.serverStatus > 0 {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tt.serverStatus)
|
||||
if tt.serverResponse != "" {
|
||||
w.Write([]byte(tt.serverResponse))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
tt.metadataURL = server.URL
|
||||
}
|
||||
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_IDP_METADATA_URL"] = tt.metadataURL
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test GetMetadata
|
||||
ctx := context.Background()
|
||||
metadata, err := provider.GetMetadata(ctx)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Nil(t, metadata)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, metadata)
|
||||
assert.Equal(t, "https://idp.example.com", metadata.EntityID)
|
||||
assert.NotEmpty(t, metadata.IDPSSODescriptor.SingleSignOnService)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAMLProvider_GenerateAuthRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spEntityID string
|
||||
acsURL string
|
||||
relayState string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "successful auth request generation",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
relayState: "test-relay-state",
|
||||
},
|
||||
{
|
||||
name: "missing SP entity ID",
|
||||
spEntityID: "",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ENTITY_ID not configured",
|
||||
},
|
||||
{
|
||||
name: "missing ACS URL",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ACS_URL not configured",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock HTTP server for metadata
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockSAMLMetadata()))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
||||
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test GenerateAuthRequest
|
||||
ctx := context.Background()
|
||||
authURL, requestID, err := provider.GenerateAuthRequest(ctx, tt.relayState)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Empty(t, authURL)
|
||||
assert.Empty(t, requestID)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, authURL)
|
||||
assert.NotEmpty(t, requestID)
|
||||
assert.Contains(t, authURL, "https://idp.example.com/sso")
|
||||
assert.Contains(t, authURL, "SAMLRequest=")
|
||||
if tt.relayState != "" {
|
||||
assert.Contains(t, authURL, "RelayState="+tt.relayState)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAMLProvider_ProcessSAMLResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
samlResponse string
|
||||
expectedRequestID string
|
||||
spEntityID string
|
||||
expectError bool
|
||||
errorContains string
|
||||
expectedUserID string
|
||||
expectedEmail string
|
||||
expectedName string
|
||||
expectedRoles []string
|
||||
}{
|
||||
{
|
||||
name: "successful SAML response processing",
|
||||
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
|
||||
expectedRequestID: "_request_id",
|
||||
spEntityID: "https://sp.example.com",
|
||||
expectedUserID: "user@example.com",
|
||||
expectedEmail: "user@example.com",
|
||||
expectedName: "Test User",
|
||||
expectedRoles: []string{"admin", "user"},
|
||||
},
|
||||
{
|
||||
name: "invalid base64 encoding",
|
||||
samlResponse: "invalid-base64",
|
||||
expectError: true,
|
||||
errorContains: "Failed to decode SAML response",
|
||||
},
|
||||
{
|
||||
name: "invalid XML",
|
||||
samlResponse: base64.StdEncoding.EncodeToString([]byte("invalid xml")),
|
||||
expectError: true,
|
||||
errorContains: "Failed to parse SAML response",
|
||||
},
|
||||
{
|
||||
name: "audience mismatch",
|
||||
samlResponse: base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse())),
|
||||
spEntityID: "https://wrong-sp.example.com",
|
||||
expectError: true,
|
||||
errorContains: "audience mismatch",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test ProcessSAMLResponse
|
||||
ctx := context.Background()
|
||||
authContext, err := provider.ProcessSAMLResponse(ctx, tt.samlResponse, tt.expectedRequestID)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Nil(t, authContext)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, authContext)
|
||||
assert.Equal(t, tt.expectedUserID, authContext.UserID)
|
||||
assert.Equal(t, domain.TokenTypeUser, authContext.TokenType)
|
||||
|
||||
// Check claims
|
||||
if tt.expectedEmail != "" {
|
||||
assert.Equal(t, tt.expectedEmail, authContext.Claims["email"])
|
||||
}
|
||||
if tt.expectedName != "" {
|
||||
assert.Equal(t, tt.expectedName, authContext.Claims["name"])
|
||||
}
|
||||
|
||||
// Check permissions/roles
|
||||
if len(tt.expectedRoles) > 0 {
|
||||
assert.Equal(t, tt.expectedRoles, authContext.Permissions)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAMLProvider_GenerateServiceProviderMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spEntityID string
|
||||
acsURL string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "successful SP metadata generation",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
},
|
||||
{
|
||||
name: "missing SP entity ID",
|
||||
spEntityID: "",
|
||||
acsURL: "https://sp.example.com/acs",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ENTITY_ID not configured",
|
||||
},
|
||||
{
|
||||
name: "missing ACS URL",
|
||||
spEntityID: "https://sp.example.com",
|
||||
acsURL: "",
|
||||
expectError: true,
|
||||
errorContains: "SAML_SP_ACS_URL not configured",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = tt.spEntityID
|
||||
cfg.values["SAML_SP_ACS_URL"] = tt.acsURL
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test GenerateServiceProviderMetadata
|
||||
metadata, err := provider.GenerateServiceProviderMetadata()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Empty(t, metadata)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, metadata)
|
||||
assert.Contains(t, metadata, tt.spEntityID)
|
||||
assert.Contains(t, metadata, tt.acsURL)
|
||||
assert.Contains(t, metadata, "EntityDescriptor")
|
||||
assert.Contains(t, metadata, "SPSSODescriptor")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for SAML operations
|
||||
func BenchmarkSAMLProvider_ProcessSAMLResponse(b *testing.B) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(b)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(b, err)
|
||||
|
||||
// Prepare SAML response
|
||||
samlResponse := base64.StdEncoding.EncodeToString([]byte(mockSAMLResponse()))
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSAMLProvider_GenerateAuthRequest(b *testing.B) {
|
||||
// Create mock HTTP server for metadata
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockSAMLMetadata()))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_IDP_METADATA_URL"] = server.URL
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
||||
cfg.values["SAML_SP_ACS_URL"] = "https://sp.example.com/acs"
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(b)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(b, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, err := provider.GenerateAuthRequest(ctx, "test-relay-state")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
func TestSAMLResponseValidation(t *testing.T) {
|
||||
// Test various SAML response validation scenarios
|
||||
tests := []struct {
|
||||
name string
|
||||
modifyXML func(string) string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "expired assertion",
|
||||
modifyXML: func(xml string) string {
|
||||
// Replace all NotOnOrAfter timestamps with past time
|
||||
pastTime := "2020-01-01T13:00:00Z"
|
||||
re := regexp.MustCompile(`NotOnOrAfter="[^"]*"`)
|
||||
return re.ReplaceAllString(xml, `NotOnOrAfter="`+pastTime+`"`)
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "assertion has expired",
|
||||
},
|
||||
{
|
||||
name: "assertion not yet valid",
|
||||
modifyXML: func(xml string) string {
|
||||
// Replace all NotBefore timestamps with future time
|
||||
futureTime := "2030-01-01T11:55:00Z"
|
||||
re := regexp.MustCompile(`NotBefore="[^"]*"`)
|
||||
return re.ReplaceAllString(xml, `NotBefore="`+futureTime+`"`)
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "assertion not yet valid",
|
||||
},
|
||||
{
|
||||
name: "failed status",
|
||||
modifyXML: func(xml string) string {
|
||||
return strings.ReplaceAll(xml,
|
||||
"urn:oasis:names:tc:SAML:2.0:status:Success",
|
||||
"urn:oasis:names:tc:SAML:2.0:status:AuthnFailed")
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "SAML authentication failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create config
|
||||
cfg := NewTestConfig()
|
||||
cfg.values["SAML_SP_ENTITY_ID"] = "https://sp.example.com"
|
||||
|
||||
// Create SAML provider
|
||||
logger := zaptest.NewLogger(t)
|
||||
provider, err := auth.NewSAMLProvider(cfg, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Modify SAML response
|
||||
modifiedXML := tt.modifyXML(mockSAMLResponse())
|
||||
samlResponse := base64.StdEncoding.EncodeToString([]byte(modifiedXML))
|
||||
|
||||
// Test ProcessSAMLResponse
|
||||
ctx := context.Background()
|
||||
authContext, err := provider.ProcessSAMLResponse(ctx, samlResponse, "_request_id")
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Nil(t, authContext)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, authContext)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
705
test/token_repository_test.go
Normal file
705
test/token_repository_test.go
Normal file
@ -0,0 +1,705 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
"github.com/kms/api-key-service/internal/repository/postgres"
|
||||
)
|
||||
|
||||
// SQLMockDatabaseProvider implements repository.DatabaseProvider for SQL testing
|
||||
type SQLMockDatabaseProvider struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (m *SQLMockDatabaseProvider) GetDB() interface{} {
|
||||
return m.db
|
||||
}
|
||||
|
||||
func (m *SQLMockDatabaseProvider) Ping(ctx context.Context) error {
|
||||
return m.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (m *SQLMockDatabaseProvider) Close() error {
|
||||
return m.db.Close()
|
||||
}
|
||||
|
||||
func (m *SQLMockDatabaseProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) {
|
||||
tx, err := m.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SQLMockTransactionProvider{tx: tx}, nil
|
||||
}
|
||||
|
||||
func (m *SQLMockDatabaseProvider) Migrate(ctx context.Context, migrationPath string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SQLMockTransactionProvider implements repository.TransactionProvider for SQL testing
|
||||
type SQLMockTransactionProvider struct {
|
||||
tx *sql.Tx
|
||||
}
|
||||
|
||||
func (m *SQLMockTransactionProvider) Commit() error {
|
||||
return m.tx.Commit()
|
||||
}
|
||||
|
||||
func (m *SQLMockTransactionProvider) Rollback() error {
|
||||
return m.tx.Rollback()
|
||||
}
|
||||
|
||||
func (m *SQLMockTransactionProvider) GetTx() interface{} {
|
||||
return m.tx
|
||||
}
|
||||
|
||||
func setupTokenRepositoryTest(t *testing.T) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDB := &SQLMockDatabaseProvider{db: db}
|
||||
repo := postgres.NewStaticTokenRepository(mockDB)
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
}
|
||||
|
||||
return repo.(*postgres.StaticTokenRepository), mock, cleanup
|
||||
}
|
||||
|
||||
func setupTokenRepositoryTestBenchmark(b *testing.B) (*postgres.StaticTokenRepository, sqlmock.Sqlmock, func()) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
mockDB := &SQLMockDatabaseProvider{db: db}
|
||||
repo := postgres.NewStaticTokenRepository(mockDB)
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
}
|
||||
|
||||
return repo.(*postgres.StaticTokenRepository), mock, cleanup
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_Create(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token *domain.StaticToken
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful creation",
|
||||
token: &domain.StaticToken{
|
||||
ID: uuid.New(),
|
||||
AppID: "test-app",
|
||||
Owner: domain.Owner{
|
||||
Type: domain.OwnerTypeIndividual,
|
||||
Name: "test-user",
|
||||
Owner: "test-owner",
|
||||
},
|
||||
KeyHash: "test-hash",
|
||||
Type: "hmac",
|
||||
},
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(`INSERT INTO static_tokens`).
|
||||
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
token: &domain.StaticToken{
|
||||
ID: uuid.New(),
|
||||
AppID: "test-app",
|
||||
Owner: domain.Owner{
|
||||
Type: domain.OwnerTypeIndividual,
|
||||
Name: "test-user",
|
||||
Owner: "test-owner",
|
||||
},
|
||||
KeyHash: "test-hash",
|
||||
Type: "hmac",
|
||||
},
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(`INSERT INTO static_tokens`).
|
||||
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "hmac", sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to create static token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
err := repo.Create(ctx, tt.token)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, tt.token.CreatedAt)
|
||||
assert.NotZero(t, tt.token.UpdatedAt)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_GetByID(t *testing.T) {
|
||||
tokenID := uuid.New()
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenID uuid.UUID
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
expectedToken *domain.StaticToken
|
||||
}{
|
||||
{
|
||||
name: "successful retrieval",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
tokenID, "test-app", "individual", "test-user", "test-owner",
|
||||
"test-hash", "user", now, now,
|
||||
)
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedToken: &domain.StaticToken{
|
||||
ID: tokenID,
|
||||
AppID: "test-app",
|
||||
Owner: domain.Owner{
|
||||
Type: domain.OwnerTypeIndividual,
|
||||
Name: "test-user",
|
||||
Owner: "test-owner",
|
||||
},
|
||||
KeyHash: "test-hash",
|
||||
Type: string(domain.TokenTypeUser),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token not found",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "not found",
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to get static token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
token, err := repo.GetByID(ctx, tt.tokenID)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
assert.Nil(t, token)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, token)
|
||||
assert.Equal(t, tt.expectedToken.ID, token.ID)
|
||||
assert.Equal(t, tt.expectedToken.AppID, token.AppID)
|
||||
assert.Equal(t, tt.expectedToken.Owner, token.Owner)
|
||||
assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
|
||||
assert.Equal(t, tt.expectedToken.Type, token.Type)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_GetByKeyHash(t *testing.T) {
|
||||
tokenID := uuid.New()
|
||||
now := time.Now()
|
||||
keyHash := "test-hash"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyHash string
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
expectedToken *domain.StaticToken
|
||||
}{
|
||||
{
|
||||
name: "successful retrieval",
|
||||
keyHash: keyHash,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
tokenID, "test-app", "individual", "test-user", "test-owner",
|
||||
keyHash, "user", now, now,
|
||||
)
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
|
||||
WithArgs(keyHash).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedToken: &domain.StaticToken{
|
||||
ID: tokenID,
|
||||
AppID: "test-app",
|
||||
Owner: domain.Owner{
|
||||
Type: domain.OwnerTypeIndividual,
|
||||
Name: "test-user",
|
||||
Owner: "test-owner",
|
||||
},
|
||||
KeyHash: keyHash,
|
||||
Type: string(domain.TokenTypeUser),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token not found",
|
||||
keyHash: keyHash,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE key_hash = \$1`).
|
||||
WithArgs(keyHash).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "not found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
token, err := repo.GetByKeyHash(ctx, tt.keyHash)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
assert.Nil(t, token)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, token)
|
||||
assert.Equal(t, tt.expectedToken.KeyHash, token.KeyHash)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_GetByAppID(t *testing.T) {
|
||||
tokenID1 := uuid.New()
|
||||
tokenID2 := uuid.New()
|
||||
now := time.Now()
|
||||
appID := "test-app"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
appID string
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "successful retrieval with multiple tokens",
|
||||
appID: appID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
tokenID1, appID, "user", "test-user1", "test-owner1",
|
||||
"test-hash1", "user", now, now,
|
||||
).AddRow(
|
||||
tokenID2, appID, "user", "test-user2", "test-owner2",
|
||||
"test-hash2", "user", now, now,
|
||||
)
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
|
||||
WithArgs(appID).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "no tokens found",
|
||||
appID: appID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
})
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
|
||||
WithArgs(appID).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
appID: appID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE app_id = \$1 ORDER BY created_at DESC`).
|
||||
WithArgs(appID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to query static tokens",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
tokens, err := repo.GetByAppID(ctx, tt.appID)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
assert.Nil(t, tokens)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, tokens, tt.expectedCount)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_List(t *testing.T) {
|
||||
tokenID := uuid.New()
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
limit int
|
||||
offset int
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "successful list with pagination",
|
||||
limit: 10,
|
||||
offset: 0,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
tokenID, "test-app", "user", "test-user", "test-owner",
|
||||
"test-hash", "user", now, now,
|
||||
)
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
|
||||
WithArgs(10, 0).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
limit: 10,
|
||||
offset: 0,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens ORDER BY created_at DESC LIMIT \$1 OFFSET \$2`).
|
||||
WithArgs(10, 0).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to query static tokens",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
tokens, err := repo.List(ctx, tt.limit, tt.offset)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
assert.Nil(t, tokens)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, tokens, tt.expectedCount)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_Delete(t *testing.T) {
|
||||
tokenID := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenID uuid.UUID
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful deletion",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "token not found",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "not found",
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(`DELETE FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to delete static token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
err := repo.Delete(ctx, tt.tokenID)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenRepository_Exists(t *testing.T) {
|
||||
tokenID := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenID uuid.UUID
|
||||
setupMock func(sqlmock.Sqlmock)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
expectedExists bool
|
||||
}{
|
||||
{
|
||||
name: "token exists",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{"exists"}).AddRow(1)
|
||||
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnRows(rows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedExists: true,
|
||||
},
|
||||
{
|
||||
name: "token does not exist",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
},
|
||||
expectError: false,
|
||||
expectedExists: false,
|
||||
},
|
||||
{
|
||||
name: "database error",
|
||||
tokenID: tokenID,
|
||||
setupMock: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery(`SELECT 1 FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to check static token existence",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTest(t)
|
||||
defer cleanup()
|
||||
|
||||
tt.setupMock(mock)
|
||||
|
||||
ctx := context.Background()
|
||||
exists, err := repo.Exists(ctx, tt.tokenID)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedExists, exists)
|
||||
}
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for repository operations
|
||||
func BenchmarkStaticTokenRepository_Create(b *testing.B) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
|
||||
defer cleanup()
|
||||
|
||||
token := &domain.StaticToken{
|
||||
ID: uuid.New(),
|
||||
AppID: "test-app",
|
||||
Owner: domain.Owner{
|
||||
Type: domain.OwnerTypeIndividual,
|
||||
Name: "test-user",
|
||||
Owner: "test-owner",
|
||||
},
|
||||
KeyHash: "test-hash",
|
||||
Type: string(domain.TokenTypeUser),
|
||||
}
|
||||
|
||||
// Setup mock expectations for all iterations
|
||||
for i := 0; i < b.N; i++ {
|
||||
mock.ExpectExec(`INSERT INTO static_tokens`).
|
||||
WithArgs(sqlmock.AnyArg(), "test-app", "individual", "test-user", "test-owner", "test-hash", "user", sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
token.ID = uuid.New() // Generate new ID for each iteration
|
||||
err := repo.Create(ctx, token)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStaticTokenRepository_GetByID(b *testing.B) {
|
||||
repo, mock, cleanup := setupTokenRepositoryTestBenchmark(b)
|
||||
defer cleanup()
|
||||
|
||||
tokenID := uuid.New()
|
||||
now := time.Now()
|
||||
|
||||
// Setup mock expectations for all iterations
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "app_id", "owner_type", "owner_name", "owner_owner",
|
||||
"key_hash", "type", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
tokenID, "test-app", "user", "test-user", "test-owner",
|
||||
"test-hash", "user", now, now,
|
||||
)
|
||||
mock.ExpectQuery(`SELECT (.+) FROM static_tokens WHERE id = \$1`).
|
||||
WithArgs(tokenID).
|
||||
WillReturnRows(rows)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := repo.GetByID(ctx, tokenID)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user