-
This commit is contained in:
@ -14,231 +14,7 @@ import (
|
||||
"github.com/kms/api-key-service/internal/services"
|
||||
)
|
||||
|
||||
// MockConfig implements ConfigProvider for testing
|
||||
type MockConfig struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func NewMockConfig() *MockConfig {
|
||||
return &MockConfig{
|
||||
values: map[string]string{
|
||||
"JWT_SECRET": "test-jwt-secret-for-testing-only",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetString(key string) string {
|
||||
return m.values[key]
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetInt(key string) int { return 0 }
|
||||
func (m *MockConfig) GetBool(key string) bool { return false }
|
||||
func (m *MockConfig) GetDuration(key string) time.Duration { return 0 }
|
||||
func (m *MockConfig) GetStringSlice(key string) []string { return nil }
|
||||
func (m *MockConfig) IsSet(key string) bool { return m.values[key] != "" }
|
||||
func (m *MockConfig) Validate() error { return nil }
|
||||
func (m *MockConfig) GetDatabaseDSN() string { return "" }
|
||||
func (m *MockConfig) GetServerAddress() string { return "" }
|
||||
func (m *MockConfig) GetMetricsAddress() string { return "" }
|
||||
func (m *MockConfig) GetJWTSecret() string { return m.GetString("JWT_SECRET") }
|
||||
func (m *MockConfig) IsDevelopment() bool { return true }
|
||||
func (m *MockConfig) IsProduction() bool { return false }
|
||||
|
||||
func TestJWTManager_GenerateToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, tokenString)
|
||||
|
||||
// Verify the token can be validated
|
||||
claims, err := jwtManager.ValidateToken(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
assert.Equal(t, userToken.Permissions, claims.Permissions)
|
||||
assert.Equal(t, userToken.TokenType, claims.TokenType)
|
||||
assert.Equal(t, userToken.Claims, claims.Claims)
|
||||
}
|
||||
|
||||
func TestJWTManager_ValidateToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test valid token
|
||||
claims, err := jwtManager.ValidateToken(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
|
||||
// Test invalid token
|
||||
_, err = jwtManager.ValidateToken("invalid-token")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test empty token
|
||||
_, err = jwtManager.ValidateToken("")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestJWTManager_ExpiredToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
// Create an expired token
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(-time.Hour), // Expired 1 hour ago
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validation should fail for expired token
|
||||
_, err = jwtManager.ValidateToken(tokenString)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestJWTManager_MaxValidAtExpired(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
// Create a token that's past max valid time
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(-time.Hour), // Max valid time expired
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validation should fail for token past max valid time
|
||||
_, err = jwtManager.ValidateToken(tokenString)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestJWTManager_RefreshToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
originalToken, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Refresh the token
|
||||
newExpiration := time.Now().Add(2 * time.Hour)
|
||||
refreshedToken, err := jwtManager.RefreshToken(originalToken, newExpiration)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, refreshedToken)
|
||||
assert.NotEqual(t, originalToken, refreshedToken)
|
||||
|
||||
// Validate the refreshed token
|
||||
claims, err := jwtManager.ValidateToken(refreshedToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
}
|
||||
|
||||
func TestJWTManager_ExtractClaims(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(-time.Hour), // Expired token
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract claims from expired token (should work)
|
||||
claims, err := jwtManager.ExtractClaims(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
}
|
||||
|
||||
func TestJWTManager_GetTokenInfo(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
info := jwtManager.GetTokenInfo(tokenString)
|
||||
assert.Equal(t, userToken.UserID, info["user_id"])
|
||||
assert.Equal(t, userToken.AppID, info["app_id"])
|
||||
assert.Equal(t, userToken.Permissions, info["permissions"])
|
||||
assert.Equal(t, userToken.TokenType, info["token_type"])
|
||||
}
|
||||
|
||||
func TestAuthenticationService_ValidateJWTToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
@ -330,7 +106,8 @@ func TestAuthenticationService_RefreshJWTToken(t *testing.T) {
|
||||
|
||||
func TestJWTManager_InvalidSecret(t *testing.T) {
|
||||
// Test with empty JWT secret
|
||||
config := &MockConfig{values: map[string]string{"JWT_SECRET": ""}}
|
||||
config := NewTestConfig()
|
||||
config.values["JWT_SECRET"] = ""
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
|
||||
@ -12,48 +12,6 @@ import (
|
||||
"github.com/kms/api-key-service/internal/cache"
|
||||
)
|
||||
|
||||
// MockConfig implements ConfigProvider for testing
|
||||
type MockConfig struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func NewMockConfig() *MockConfig {
|
||||
return &MockConfig{
|
||||
values: map[string]string{
|
||||
"CACHE_ENABLED": "true",
|
||||
"CACHE_TTL": "1h",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetString(key string) string {
|
||||
return m.values[key]
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetInt(key string) int { return 0 }
|
||||
func (m *MockConfig) GetBool(key string) bool {
|
||||
if key == "CACHE_ENABLED" {
|
||||
return m.values[key] == "true"
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (m *MockConfig) GetDuration(key string) time.Duration {
|
||||
if key == "CACHE_TTL" {
|
||||
if d, err := time.ParseDuration(m.values[key]); err == nil {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
func (m *MockConfig) GetStringSlice(key string) []string { return nil }
|
||||
func (m *MockConfig) IsSet(key string) bool { return m.values[key] != "" }
|
||||
func (m *MockConfig) Validate() error { return nil }
|
||||
func (m *MockConfig) GetDatabaseDSN() string { return "" }
|
||||
func (m *MockConfig) GetServerAddress() string { return "" }
|
||||
func (m *MockConfig) GetMetricsAddress() string { return "" }
|
||||
func (m *MockConfig) GetJWTSecret() string { return m.GetString("JWT_SECRET") }
|
||||
func (m *MockConfig) IsDevelopment() bool { return true }
|
||||
func (m *MockConfig) IsProduction() bool { return false }
|
||||
|
||||
func TestMemoryCache_SetAndGet(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
@ -315,12 +273,9 @@ func TestCacheKeyPrefixes(t *testing.T) {
|
||||
|
||||
func TestCacheManager_ConfigMethods(t *testing.T) {
|
||||
// Create mock config with cache settings
|
||||
config := &MockConfig{
|
||||
values: map[string]string{
|
||||
"CACHE_ENABLED": "true",
|
||||
"CACHE_TTL": "1h",
|
||||
},
|
||||
}
|
||||
config := NewMockConfig()
|
||||
config.values["CACHE_ENABLED"] = "true"
|
||||
config.values["CACHE_TTL"] = "1h"
|
||||
logger := zap.NewNop()
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
defer cacheManager.Close()
|
||||
|
||||
382
test/jwt_test.go
Normal file
382
test/jwt_test.go
Normal file
@ -0,0 +1,382 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
)
|
||||
|
||||
func TestJWTManager_GenerateToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"department": "engineering",
|
||||
"role": "developer",
|
||||
},
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, tokenString)
|
||||
|
||||
// Verify token structure (should have 3 parts separated by dots)
|
||||
parts := len(tokenString)
|
||||
assert.Greater(t, parts, 100) // JWT tokens are typically longer than 100 chars
|
||||
}
|
||||
|
||||
func TestJWTManager_ValidateToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"department": "engineering",
|
||||
},
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate token
|
||||
claims, err := jwtManager.ValidateToken(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
assert.Equal(t, userToken.Permissions, claims.Permissions)
|
||||
assert.Equal(t, userToken.TokenType, claims.TokenType)
|
||||
assert.Equal(t, userToken.Claims, claims.Claims)
|
||||
}
|
||||
|
||||
func TestJWTManager_ValidateToken_InvalidToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
// Test with invalid token
|
||||
_, err := jwtManager.ValidateToken("invalid.token.here")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid token")
|
||||
}
|
||||
|
||||
func TestJWTManager_ValidateToken_ExpiredToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired 1 hour ago
|
||||
MaxValidAt: time.Now().Add(-30 * time.Minute), // Max valid also expired
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate token (this should work even with past dates)
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate token (this should fail due to expiration)
|
||||
_, err = jwtManager.ValidateToken(tokenString)
|
||||
assert.Error(t, err)
|
||||
// The error could be either JWT expiration or our custom max valid check
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "expired beyond maximum validity") ||
|
||||
strings.Contains(err.Error(), "token is expired"),
|
||||
"Expected expiration error, got: %s", err.Error())
|
||||
}
|
||||
|
||||
func TestJWTManager_RefreshToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate original token
|
||||
originalToken, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Refresh token with new expiration
|
||||
newExpiration := time.Now().Add(2 * time.Hour)
|
||||
refreshedToken, err := jwtManager.RefreshToken(originalToken, newExpiration)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, refreshedToken)
|
||||
assert.NotEqual(t, originalToken, refreshedToken)
|
||||
|
||||
// Validate refreshed token
|
||||
claims, err := jwtManager.ValidateToken(refreshedToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
}
|
||||
|
||||
func TestJWTManager_RefreshToken_ExpiredMaxValid(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(-30 * time.Minute), // Max valid expired
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to refresh (should fail due to max valid expiration)
|
||||
newExpiration := time.Now().Add(2 * time.Hour)
|
||||
_, err = jwtManager.RefreshToken(tokenString, newExpiration)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expired beyond maximum validity")
|
||||
}
|
||||
|
||||
func TestJWTManager_ExtractClaims(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired token
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate expired token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract claims (should work even for expired tokens)
|
||||
claims, err := jwtManager.ExtractClaims(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
assert.Equal(t, userToken.Permissions, claims.Permissions)
|
||||
}
|
||||
|
||||
func TestJWTManager_RevokeToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Revoke token
|
||||
err = jwtManager.RevokeToken(tokenString)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check if token is revoked
|
||||
revoked, err := jwtManager.IsTokenRevoked(tokenString)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, revoked)
|
||||
}
|
||||
|
||||
func TestJWTManager_RevokeToken_AlreadyExpired(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour), // Already expired
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate expired token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Revoke expired token (should succeed but not add to blacklist)
|
||||
err = jwtManager.RevokeToken(tokenString)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check if token is revoked (should be false since it was already expired)
|
||||
revoked, err := jwtManager.IsTokenRevoked(tokenString)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, revoked)
|
||||
}
|
||||
|
||||
func TestJWTManager_IsTokenRevoked_NotRevoked(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check if token is revoked (should be false)
|
||||
revoked, err := jwtManager.IsTokenRevoked(tokenString)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, revoked)
|
||||
}
|
||||
|
||||
func TestJWTManager_GetTokenInfo(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"department": "engineering",
|
||||
},
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get token info
|
||||
info := jwtManager.GetTokenInfo(tokenString)
|
||||
assert.Equal(t, userToken.UserID, info["user_id"])
|
||||
assert.Equal(t, userToken.AppID, info["app_id"])
|
||||
assert.Equal(t, userToken.Permissions, info["permissions"])
|
||||
assert.Equal(t, userToken.TokenType, info["token_type"])
|
||||
assert.NotNil(t, info["issued_at"])
|
||||
assert.NotNil(t, info["expires_at"])
|
||||
assert.NotNil(t, info["max_valid_at"])
|
||||
assert.NotNil(t, info["jti"])
|
||||
}
|
||||
|
||||
func TestJWTManager_GetTokenInfo_InvalidToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
// Get info for invalid token
|
||||
info := jwtManager.GetTokenInfo("invalid.token.here")
|
||||
assert.Contains(t, info["error"], "Invalid token format")
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkJWTManager_GenerateToken(b *testing.B) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := jwtManager.GenerateToken(userToken)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkJWTManager_ValidateToken(b *testing.B) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := jwtManager.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
552
test/oauth2_test.go
Normal file
552
test/oauth2_test.go
Normal file
@ -0,0 +1,552 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
)
|
||||
|
||||
func TestOAuth2Provider_GetDiscoveryDocument(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
providerURL string
|
||||
mockResponse string
|
||||
mockStatusCode int
|
||||
expectError bool
|
||||
expectedIssuer string
|
||||
}{
|
||||
{
|
||||
name: "successful discovery",
|
||||
providerURL: "https://example.com",
|
||||
mockResponse: `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo",
|
||||
"jwks_uri": "https://example.com/jwks"
|
||||
}`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: false,
|
||||
expectedIssuer: "https://example.com",
|
||||
},
|
||||
{
|
||||
name: "missing provider URL",
|
||||
providerURL: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid response status",
|
||||
providerURL: "https://example.com",
|
||||
mockResponse: `{"error": "not found"}`,
|
||||
mockStatusCode: http.StatusNotFound,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON response",
|
||||
providerURL: "https://example.com",
|
||||
mockResponse: `invalid json`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock server if needed
|
||||
var server *httptest.Server
|
||||
if tt.providerURL != "" && !tt.expectError {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/.well-known/openid_configuration", r.URL.Path)
|
||||
w.WriteHeader(tt.mockStatusCode)
|
||||
w.Write([]byte(tt.mockResponse))
|
||||
}))
|
||||
defer server.Close()
|
||||
tt.providerURL = server.URL
|
||||
}
|
||||
|
||||
// Create config mock
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = tt.providerURL
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
discovery, err := provider.GetDiscoveryDocument(ctx)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, discovery)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, discovery)
|
||||
assert.Equal(t, tt.expectedIssuer, discovery.Issuer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Provider_GenerateAuthURL(t *testing.T) {
|
||||
// Create mock discovery server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
state string
|
||||
redirectURI string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful URL generation",
|
||||
clientID: "test-client-id",
|
||||
state: "test-state",
|
||||
redirectURI: "https://app.example.com/callback",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
clientID: "",
|
||||
state: "test-state",
|
||||
redirectURI: "https://app.example.com/callback",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = server.URL
|
||||
configMock.values["SSO_CLIENT_ID"] = tt.clientID
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
authURL, err := provider.GenerateAuthURL(ctx, tt.state, tt.redirectURI)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, authURL)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, authURL)
|
||||
assert.Contains(t, authURL, "https://example.com/auth")
|
||||
assert.Contains(t, authURL, "client_id="+tt.clientID)
|
||||
assert.Contains(t, authURL, "state="+tt.state)
|
||||
assert.Contains(t, authURL, "redirect_uri=")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Provider_ExchangeCodeForToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
redirectURI string
|
||||
codeVerifier string
|
||||
clientID string
|
||||
clientSecret string
|
||||
mockResponse string
|
||||
mockStatusCode int
|
||||
expectError bool
|
||||
expectedToken string
|
||||
}{
|
||||
{
|
||||
name: "successful token exchange",
|
||||
code: "test-code",
|
||||
redirectURI: "https://app.example.com/callback",
|
||||
codeVerifier: "test-verifier",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
mockResponse: `{
|
||||
"access_token": "test-access-token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "test-refresh-token"
|
||||
}`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: false,
|
||||
expectedToken: "test-access-token",
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
code: "test-code",
|
||||
redirectURI: "https://app.example.com/callback",
|
||||
codeVerifier: "test-verifier",
|
||||
clientID: "",
|
||||
clientSecret: "test-client-secret",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "token endpoint error",
|
||||
code: "test-code",
|
||||
redirectURI: "https://app.example.com/callback",
|
||||
codeVerifier: "test-verifier",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
mockResponse: `{"error": "invalid_grant"}`,
|
||||
mockStatusCode: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock servers
|
||||
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer discoveryServer.Close()
|
||||
|
||||
var tokenServer *httptest.Server
|
||||
if !tt.expectError {
|
||||
tokenServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
|
||||
|
||||
w.WriteHeader(tt.mockStatusCode)
|
||||
w.Write([]byte(tt.mockResponse))
|
||||
}))
|
||||
defer tokenServer.Close()
|
||||
|
||||
// Update discovery server to return the token server URL
|
||||
discoveryServer.Close()
|
||||
discoveryServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "` + tokenServer.URL + `",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
}
|
||||
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
|
||||
configMock.values["SSO_CLIENT_ID"] = tt.clientID
|
||||
configMock.values["SSO_CLIENT_SECRET"] = tt.clientSecret
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
tokenResp, err := provider.ExchangeCodeForToken(ctx, tt.code, tt.redirectURI, tt.codeVerifier)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tokenResp)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tokenResp)
|
||||
assert.Equal(t, tt.expectedToken, tokenResp.AccessToken)
|
||||
assert.Equal(t, "Bearer", tokenResp.TokenType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Provider_GetUserInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
mockResponse string
|
||||
mockStatusCode int
|
||||
expectError bool
|
||||
expectedSub string
|
||||
expectedEmail string
|
||||
}{
|
||||
{
|
||||
name: "successful user info retrieval",
|
||||
accessToken: "test-access-token",
|
||||
mockResponse: `{
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
"name": "Test User",
|
||||
"email_verified": true
|
||||
}`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: false,
|
||||
expectedSub: "user123",
|
||||
expectedEmail: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "unauthorized access token",
|
||||
accessToken: "invalid-token",
|
||||
mockResponse: `{"error": "invalid_token"}`,
|
||||
mockStatusCode: http.StatusUnauthorized,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON response",
|
||||
accessToken: "test-access-token",
|
||||
mockResponse: `invalid json`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock servers
|
||||
userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
assert.Equal(t, "Bearer "+tt.accessToken, r.Header.Get("Authorization"))
|
||||
|
||||
w.WriteHeader(tt.mockStatusCode)
|
||||
w.Write([]byte(tt.mockResponse))
|
||||
}))
|
||||
defer userInfoServer.Close()
|
||||
|
||||
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "` + userInfoServer.URL + `"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer discoveryServer.Close()
|
||||
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
userInfo, err := provider.GetUserInfo(ctx, tt.accessToken)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, userInfo)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, userInfo)
|
||||
assert.Equal(t, tt.expectedSub, userInfo.Sub)
|
||||
assert.Equal(t, tt.expectedEmail, userInfo.Email)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Provider_ValidateIDToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
idToken string
|
||||
expectError bool
|
||||
expectedSub string
|
||||
}{
|
||||
{
|
||||
name: "valid ID token",
|
||||
// This is a mock JWT token with payload: {"sub": "user123", "email": "user@example.com", "name": "Test User"}
|
||||
idToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIiwiZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwibmFtZSI6IlRlc3QgVXNlciJ9.invalid-signature",
|
||||
expectError: false,
|
||||
expectedSub: "user123",
|
||||
},
|
||||
{
|
||||
name: "invalid token format",
|
||||
idToken: "invalid.token",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
idToken: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configMock := NewMockConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
authContext, err := provider.ValidateIDToken(ctx, tt.idToken)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, authContext)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, authContext)
|
||||
assert.Equal(t, tt.expectedSub, authContext.UserID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Provider_RefreshAccessToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
clientID string
|
||||
clientSecret string
|
||||
mockResponse string
|
||||
mockStatusCode int
|
||||
expectError bool
|
||||
expectedToken string
|
||||
}{
|
||||
{
|
||||
name: "successful token refresh",
|
||||
refreshToken: "test-refresh-token",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
mockResponse: `{
|
||||
"access_token": "new-access-token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "new-refresh-token"
|
||||
}`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: false,
|
||||
expectedToken: "new-access-token",
|
||||
},
|
||||
{
|
||||
name: "invalid refresh token",
|
||||
refreshToken: "invalid-refresh-token",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
mockResponse: `{"error": "invalid_grant"}`,
|
||||
mockStatusCode: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock servers
|
||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
|
||||
|
||||
w.WriteHeader(tt.mockStatusCode)
|
||||
w.Write([]byte(tt.mockResponse))
|
||||
}))
|
||||
defer tokenServer.Close()
|
||||
|
||||
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "` + tokenServer.URL + `",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer discoveryServer.Close()
|
||||
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
|
||||
configMock.values["SSO_CLIENT_ID"] = tt.clientID
|
||||
configMock.values["SSO_CLIENT_SECRET"] = tt.clientSecret
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
tokenResp, err := provider.RefreshAccessToken(ctx, tt.refreshToken)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tokenResp)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tokenResp)
|
||||
assert.Equal(t, tt.expectedToken, tokenResp.AccessToken)
|
||||
assert.Equal(t, "Bearer", tokenResp.TokenType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for OAuth2 operations
|
||||
func BenchmarkOAuth2Provider_GetDiscoveryDocument(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = server.URL
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.GetDiscoveryDocument(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOAuth2Provider_GenerateAuthURL(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = server.URL
|
||||
configMock.values["SSO_CLIENT_ID"] = "test-client-id"
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.GenerateAuthURL(ctx, "test-state", "https://app.example.com/callback")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
594
test/permissions_test.go
Normal file
594
test/permissions_test.go
Normal file
@ -0,0 +1,594 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
)
|
||||
|
||||
func TestPermissionHierarchy_InitializeDefaultPermissions(t *testing.T) {
|
||||
hierarchy := auth.NewPermissionHierarchy()
|
||||
|
||||
// Test that default permissions are created
|
||||
permissions := hierarchy.ListPermissions()
|
||||
assert.NotEmpty(t, permissions)
|
||||
|
||||
// Test specific permissions exist
|
||||
permissionNames := make(map[string]bool)
|
||||
for _, perm := range permissions {
|
||||
permissionNames[perm.Name] = true
|
||||
}
|
||||
|
||||
expectedPermissions := []string{
|
||||
"admin", "read", "write",
|
||||
"app.admin", "app.read", "app.write", "app.create", "app.update", "app.delete",
|
||||
"token.admin", "token.read", "token.write", "token.create", "token.revoke", "token.verify",
|
||||
"permission.admin", "permission.read", "permission.write", "permission.grant", "permission.revoke",
|
||||
"user.admin", "user.read", "user.write",
|
||||
}
|
||||
|
||||
for _, expected := range expectedPermissions {
|
||||
assert.True(t, permissionNames[expected], "Permission %s should exist", expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionHierarchy_InitializeDefaultRoles(t *testing.T) {
|
||||
hierarchy := auth.NewPermissionHierarchy()
|
||||
|
||||
// Test that default roles are created
|
||||
roles := hierarchy.ListRoles()
|
||||
assert.NotEmpty(t, roles)
|
||||
|
||||
// Test specific roles exist
|
||||
roleNames := make(map[string]bool)
|
||||
for _, role := range roles {
|
||||
roleNames[role.Name] = true
|
||||
}
|
||||
|
||||
expectedRoles := []string{
|
||||
"super_admin", "app_admin", "developer", "viewer", "token_manager",
|
||||
}
|
||||
|
||||
for _, expected := range expectedRoles {
|
||||
assert.True(t, roleNames[expected], "Role %s should exist", expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_HasPermission(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
configMock.values["CACHE_ENABLED"] = "false" // Disable cache for testing
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
appID string
|
||||
permission string
|
||||
expectedResult bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "admin user has admin permission",
|
||||
userID: "admin@example.com",
|
||||
appID: "test-app",
|
||||
permission: "admin",
|
||||
expectedResult: true,
|
||||
description: "Admin users should have admin permissions",
|
||||
},
|
||||
{
|
||||
name: "developer user has token.create permission",
|
||||
userID: "dev@example.com",
|
||||
appID: "test-app",
|
||||
permission: "token.create",
|
||||
expectedResult: true,
|
||||
description: "Developer users should have token creation permissions",
|
||||
},
|
||||
{
|
||||
name: "viewer user has read permission",
|
||||
userID: "viewer@example.com",
|
||||
appID: "test-app",
|
||||
permission: "app.read",
|
||||
expectedResult: true,
|
||||
description: "Viewer users should have read permissions",
|
||||
},
|
||||
{
|
||||
name: "viewer user denied write permission",
|
||||
userID: "viewer@example.com",
|
||||
appID: "test-app",
|
||||
permission: "app.write",
|
||||
expectedResult: false,
|
||||
description: "Viewer users should not have write permissions",
|
||||
},
|
||||
{
|
||||
name: "non-existent permission",
|
||||
userID: "admin@example.com",
|
||||
appID: "test-app",
|
||||
permission: "non.existent",
|
||||
expectedResult: false,
|
||||
description: "Non-existent permissions should be denied",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
evaluation, err := pm.HasPermission(ctx, tt.userID, tt.appID, tt.permission)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, evaluation)
|
||||
assert.Equal(t, tt.expectedResult, evaluation.Granted, tt.description)
|
||||
assert.Equal(t, tt.permission, evaluation.Permission)
|
||||
assert.NotZero(t, evaluation.EvaluatedAt)
|
||||
|
||||
if evaluation.Granted {
|
||||
assert.NotEmpty(t, evaluation.GrantedBy, "Granted permissions should have GrantedBy information")
|
||||
} else {
|
||||
assert.NotEmpty(t, evaluation.DeniedReason, "Denied permissions should have a reason")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_EvaluateBulkPermissions(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
configMock.values["CACHE_ENABLED"] = "false"
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
req := &auth.BulkPermissionRequest{
|
||||
UserID: "dev@example.com",
|
||||
AppID: "test-app",
|
||||
Permissions: []string{
|
||||
"app.read",
|
||||
"token.create",
|
||||
"token.read",
|
||||
"app.delete", // Should be denied for developer
|
||||
"admin", // Should be denied for developer
|
||||
},
|
||||
}
|
||||
|
||||
response, err := pm.EvaluateBulkPermissions(ctx, req)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, response)
|
||||
assert.Equal(t, req.UserID, response.UserID)
|
||||
assert.Equal(t, req.AppID, response.AppID)
|
||||
assert.Len(t, response.Results, len(req.Permissions))
|
||||
|
||||
// Check specific results
|
||||
assert.True(t, response.Results["app.read"].Granted, "Developer should have app.read permission")
|
||||
assert.True(t, response.Results["token.create"].Granted, "Developer should have token.create permission")
|
||||
assert.True(t, response.Results["token.read"].Granted, "Developer should have token.read permission")
|
||||
assert.False(t, response.Results["app.delete"].Granted, "Developer should not have app.delete permission")
|
||||
assert.False(t, response.Results["admin"].Granted, "Developer should not have admin permission")
|
||||
}
|
||||
|
||||
func TestPermissionManager_AddPermission(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
permission *auth.Permission
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "add valid permission",
|
||||
permission: &auth.Permission{
|
||||
Name: "custom.permission",
|
||||
Description: "Custom permission for testing",
|
||||
Parent: "read",
|
||||
Level: 2,
|
||||
Resource: "custom",
|
||||
Action: "test",
|
||||
},
|
||||
expectError: false,
|
||||
description: "Valid permissions should be added successfully",
|
||||
},
|
||||
{
|
||||
name: "add permission without name",
|
||||
permission: &auth.Permission{
|
||||
Description: "Permission without name",
|
||||
Parent: "read",
|
||||
Level: 2,
|
||||
},
|
||||
expectError: true,
|
||||
description: "Permissions without names should be rejected",
|
||||
},
|
||||
{
|
||||
name: "add permission with non-existent parent",
|
||||
permission: &auth.Permission{
|
||||
Name: "invalid.permission",
|
||||
Description: "Permission with invalid parent",
|
||||
Parent: "non.existent",
|
||||
Level: 2,
|
||||
},
|
||||
expectError: true,
|
||||
description: "Permissions with non-existent parents should be rejected",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := pm.AddPermission(tt.permission)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
} else {
|
||||
assert.NoError(t, err, tt.description)
|
||||
|
||||
// Verify permission was added
|
||||
permissions := pm.ListPermissions()
|
||||
found := false
|
||||
for _, perm := range permissions {
|
||||
if perm.Name == tt.permission.Name {
|
||||
found = true
|
||||
assert.Equal(t, tt.permission.Description, perm.Description)
|
||||
assert.Equal(t, tt.permission.Parent, perm.Parent)
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Added permission should be found in the list")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_AddRole(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
role *auth.Role
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "add valid role",
|
||||
role: &auth.Role{
|
||||
Name: "custom_role",
|
||||
Description: "Custom role for testing",
|
||||
Permissions: []string{"read", "app.read"},
|
||||
Metadata: map[string]string{"level": "custom"},
|
||||
},
|
||||
expectError: false,
|
||||
description: "Valid roles should be added successfully",
|
||||
},
|
||||
{
|
||||
name: "add role without name",
|
||||
role: &auth.Role{
|
||||
Description: "Role without name",
|
||||
Permissions: []string{"read"},
|
||||
},
|
||||
expectError: true,
|
||||
description: "Roles without names should be rejected",
|
||||
},
|
||||
{
|
||||
name: "add role with non-existent permission",
|
||||
role: &auth.Role{
|
||||
Name: "invalid_role",
|
||||
Description: "Role with invalid permission",
|
||||
Permissions: []string{"non.existent.permission"},
|
||||
},
|
||||
expectError: true,
|
||||
description: "Roles with non-existent permissions should be rejected",
|
||||
},
|
||||
{
|
||||
name: "add role with non-existent inherited role",
|
||||
role: &auth.Role{
|
||||
Name: "invalid_inherited_role",
|
||||
Description: "Role with invalid inheritance",
|
||||
Permissions: []string{"read"},
|
||||
Inherits: []string{"non_existent_role"},
|
||||
},
|
||||
expectError: true,
|
||||
description: "Roles with non-existent inherited roles should be rejected",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := pm.AddRole(tt.role)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
} else {
|
||||
assert.NoError(t, err, tt.description)
|
||||
|
||||
// Verify role was added
|
||||
roles := pm.ListRoles()
|
||||
found := false
|
||||
for _, role := range roles {
|
||||
if role.Name == tt.role.Name {
|
||||
found = true
|
||||
assert.Equal(t, tt.role.Description, role.Description)
|
||||
assert.Equal(t, tt.role.Permissions, role.Permissions)
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Added role should be found in the list")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_ListPermissions(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
permissions := pm.ListPermissions()
|
||||
|
||||
// Should have default permissions
|
||||
assert.NotEmpty(t, permissions)
|
||||
|
||||
// Should be sorted by level and name
|
||||
for i := 1; i < len(permissions); i++ {
|
||||
prev := permissions[i-1]
|
||||
curr := permissions[i]
|
||||
|
||||
if prev.Level == curr.Level {
|
||||
assert.True(t, prev.Name <= curr.Name, "Permissions at same level should be sorted by name")
|
||||
} else {
|
||||
assert.True(t, prev.Level <= curr.Level, "Permissions should be sorted by level")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify hierarchy structure
|
||||
for _, perm := range permissions {
|
||||
if perm.Parent != "" {
|
||||
// Find parent permission
|
||||
parentFound := false
|
||||
for _, parent := range permissions {
|
||||
if parent.Name == perm.Parent {
|
||||
parentFound = true
|
||||
assert.True(t, parent.Level < perm.Level, "Parent should have lower level than child")
|
||||
assert.Contains(t, parent.Children, perm.Name, "Parent should contain child in children list")
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, parentFound, "Parent permission should exist for %s", perm.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_ListRoles(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
roles := pm.ListRoles()
|
||||
|
||||
// Should have default roles
|
||||
assert.NotEmpty(t, roles)
|
||||
|
||||
// Should be sorted by name
|
||||
for i := 1; i < len(roles); i++ {
|
||||
assert.True(t, roles[i-1].Name <= roles[i].Name, "Roles should be sorted by name")
|
||||
}
|
||||
|
||||
// Verify all permissions in roles exist
|
||||
allPermissions := pm.ListPermissions()
|
||||
permissionNames := make(map[string]bool)
|
||||
for _, perm := range allPermissions {
|
||||
permissionNames[perm.Name] = true
|
||||
}
|
||||
|
||||
for _, role := range roles {
|
||||
for _, perm := range role.Permissions {
|
||||
assert.True(t, permissionNames[perm], "Role %s references non-existent permission %s", role.Name, perm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_InvalidatePermissionCache(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
err := pm.InvalidatePermissionCache(ctx, "user123", "app123")
|
||||
|
||||
// Should not error (currently just logs)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPermissionHierarchy_BuildHierarchy(t *testing.T) {
|
||||
hierarchy := auth.NewPermissionHierarchy()
|
||||
|
||||
// Test that parent-child relationships are built correctly
|
||||
permissions := hierarchy.ListPermissions()
|
||||
|
||||
// Find admin permission
|
||||
var adminPerm *auth.Permission
|
||||
for _, perm := range permissions {
|
||||
if perm.Name == "admin" {
|
||||
adminPerm = perm
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, adminPerm, "Admin permission should exist")
|
||||
|
||||
// Admin should have children
|
||||
assert.NotEmpty(t, adminPerm.Children, "Admin permission should have children")
|
||||
|
||||
// Check that app.admin is a child of admin
|
||||
assert.Contains(t, adminPerm.Children, "app.admin", "app.admin should be a child of admin")
|
||||
|
||||
// Find app.write permission
|
||||
var appWritePerm *auth.Permission
|
||||
for _, perm := range permissions {
|
||||
if perm.Name == "app.write" {
|
||||
appWritePerm = perm
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, appWritePerm, "app.write permission should exist")
|
||||
|
||||
// app.write should have children
|
||||
assert.NotEmpty(t, appWritePerm.Children, "app.write permission should have children")
|
||||
assert.Contains(t, appWritePerm.Children, "app.create", "app.create should be a child of app.write")
|
||||
assert.Contains(t, appWritePerm.Children, "app.update", "app.update should be a child of app.write")
|
||||
assert.Contains(t, appWritePerm.Children, "app.delete", "app.delete should be a child of app.write")
|
||||
}
|
||||
|
||||
// Benchmark tests for permission operations
|
||||
func BenchmarkPermissionManager_HasPermission(b *testing.B) {
|
||||
configMock := NewTestConfig()
|
||||
configMock.values["CACHE_ENABLED"] = "false"
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := pm.HasPermission(ctx, "dev@example.com", "test-app", "token.create")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPermissionManager_EvaluateBulkPermissions(b *testing.B) {
|
||||
configMock := NewTestConfig()
|
||||
configMock.values["CACHE_ENABLED"] = "false"
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &auth.BulkPermissionRequest{
|
||||
UserID: "dev@example.com",
|
||||
AppID: "test-app",
|
||||
Permissions: []string{
|
||||
"app.read", "token.create", "token.read", "app.delete", "admin",
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := pm.EvaluateBulkPermissions(ctx, req)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPermissionManager_ListPermissions(b *testing.B) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
permissions := pm.ListPermissions()
|
||||
if len(permissions) == 0 {
|
||||
b.Fatal("No permissions returned")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPermissionManager_ListRoles(b *testing.B) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
roles := pm.ListRoles()
|
||||
if len(roles) == 0 {
|
||||
b.Fatal("No roles returned")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test permission hierarchy traversal
|
||||
func TestPermissionHierarchy_PermissionInheritance(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
configMock.values["CACHE_ENABLED"] = "false"
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
// Test that admin users get hierarchical permissions
|
||||
ctx := context.Background()
|
||||
|
||||
// Admin should have all permissions through hierarchy
|
||||
adminPermissions := []string{
|
||||
"admin",
|
||||
"app.admin",
|
||||
"token.admin",
|
||||
"permission.admin",
|
||||
"user.admin",
|
||||
}
|
||||
|
||||
for _, perm := range adminPermissions {
|
||||
evaluation, err := pm.HasPermission(ctx, "admin@example.com", "test-app", perm)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, evaluation.Granted, "Admin should have %s permission", perm)
|
||||
}
|
||||
}
|
||||
|
||||
// Test role inheritance
|
||||
func TestPermissionManager_RoleInheritance(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
// Add a role that inherits from another role
|
||||
parentRole := &auth.Role{
|
||||
Name: "base_role",
|
||||
Description: "Base role with basic permissions",
|
||||
Permissions: []string{"read", "app.read"},
|
||||
Metadata: map[string]string{"level": "base"},
|
||||
}
|
||||
|
||||
childRole := &auth.Role{
|
||||
Name: "extended_role",
|
||||
Description: "Extended role that inherits from base",
|
||||
Permissions: []string{"write"},
|
||||
Inherits: []string{"base_role"},
|
||||
Metadata: map[string]string{"level": "extended"},
|
||||
}
|
||||
|
||||
err := pm.AddRole(parentRole)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pm.AddRole(childRole)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify roles were added
|
||||
roles := pm.ListRoles()
|
||||
roleNames := make(map[string]*auth.Role)
|
||||
for _, role := range roles {
|
||||
roleNames[role.Name] = role
|
||||
}
|
||||
|
||||
assert.Contains(t, roleNames, "base_role")
|
||||
assert.Contains(t, roleNames, "extended_role")
|
||||
assert.Equal(t, []string{"base_role"}, roleNames["extended_role"].Inherits)
|
||||
}
|
||||
@ -29,6 +29,10 @@ func (c *TestConfig) GetBool(key string) bool {
|
||||
return boolVal
|
||||
}
|
||||
}
|
||||
// Special handling for cache enabled
|
||||
if key == "CACHE_ENABLED" {
|
||||
return c.values[key] == "true"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@ -86,6 +90,10 @@ func (c *TestConfig) IsProduction() bool {
|
||||
return c.GetString("APP_ENV") == "production"
|
||||
}
|
||||
|
||||
func (c *TestConfig) GetJWTSecret() string {
|
||||
return c.GetString("JWT_SECRET")
|
||||
}
|
||||
|
||||
// NewTestConfig creates a test configuration with default values
|
||||
func NewTestConfig() *TestConfig {
|
||||
return &TestConfig{
|
||||
@ -99,6 +107,12 @@ func NewTestConfig() *TestConfig {
|
||||
"SERVER_HOST": "localhost",
|
||||
"SERVER_PORT": "8080",
|
||||
"APP_ENV": "test",
|
||||
"JWT_SECRET": "test-jwt-secret-for-testing-only",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMockConfig creates a mock configuration (alias for NewTestConfig for backward compatibility)
|
||||
func NewMockConfig() *TestConfig {
|
||||
return NewTestConfig()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user