v1
This commit is contained in:
1
go.mod
1
go.mod
@ -23,6 +23,7 @@ require (
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.16.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@ -39,6 +39,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang-migrate/migrate/v4 v4.16.2 h1:8coYbMKUyInrFk1lfGfRovTLAW7PhWp8qQDT2iKfuoA=
|
||||
github.com/golang-migrate/migrate/v4 v4.16.2/go.mod h1:pfcJX4nPHaVdc5nmdCikFBWtm+UBpiZjRNNsyBbp0/o=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
|
||||
258
internal/auth/jwt.go
Normal file
258
internal/auth/jwt.go
Normal file
@ -0,0 +1,258 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// JWTManager handles JWT token operations
|
||||
type JWTManager struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewJWTManager creates a new JWT manager
|
||||
func NewJWTManager(config config.ConfigProvider, logger *zap.Logger) *JWTManager {
|
||||
return &JWTManager{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CustomClaims represents the custom claims in our JWT tokens
|
||||
type CustomClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
AppID string `json:"app_id"`
|
||||
Permissions []string `json:"permissions"`
|
||||
TokenType domain.TokenType `json:"token_type"`
|
||||
MaxValidAt int64 `json:"max_valid_at"`
|
||||
Claims map[string]string `json:"claims,omitempty"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateToken generates a new JWT token for a user
|
||||
func (j *JWTManager) GenerateToken(userToken *domain.UserToken) (string, error) {
|
||||
j.logger.Debug("Generating JWT token",
|
||||
zap.String("user_id", userToken.UserID),
|
||||
zap.String("app_id", userToken.AppID),
|
||||
zap.Strings("permissions", userToken.Permissions))
|
||||
|
||||
// Get JWT secret from config
|
||||
jwtSecret := j.config.GetJWTSecret()
|
||||
if jwtSecret == "" {
|
||||
return "", errors.NewValidationError("JWT secret not configured")
|
||||
}
|
||||
|
||||
// Create custom claims
|
||||
claims := CustomClaims{
|
||||
UserID: userToken.UserID,
|
||||
AppID: userToken.AppID,
|
||||
Permissions: userToken.Permissions,
|
||||
TokenType: userToken.TokenType,
|
||||
MaxValidAt: userToken.MaxValidAt.Unix(),
|
||||
Claims: userToken.Claims,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "kms-api-service",
|
||||
Subject: userToken.UserID,
|
||||
Audience: []string{userToken.AppID},
|
||||
ExpiresAt: jwt.NewNumericDate(userToken.ExpiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(userToken.IssuedAt),
|
||||
NotBefore: jwt.NewNumericDate(userToken.IssuedAt),
|
||||
ID: j.generateJTI(),
|
||||
},
|
||||
}
|
||||
|
||||
// Create token with claims
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
// Sign token with secret
|
||||
tokenString, err := token.SignedString([]byte(jwtSecret))
|
||||
if err != nil {
|
||||
j.logger.Error("Failed to sign JWT token", zap.Error(err))
|
||||
return "", errors.NewInternalError("Failed to generate token").WithInternal(err)
|
||||
}
|
||||
|
||||
j.logger.Debug("JWT token generated successfully",
|
||||
zap.String("user_id", userToken.UserID),
|
||||
zap.String("app_id", userToken.AppID))
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates and parses a JWT token
|
||||
func (j *JWTManager) ValidateToken(tokenString string) (*CustomClaims, error) {
|
||||
j.logger.Debug("Validating JWT token")
|
||||
|
||||
// Get JWT secret from config
|
||||
jwtSecret := j.config.GetJWTSecret()
|
||||
if jwtSecret == "" {
|
||||
return nil, errors.NewValidationError("JWT secret not configured")
|
||||
}
|
||||
|
||||
// Parse token with custom claims
|
||||
token, err := jwt.ParseWithClaims(tokenString, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Validate signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
j.logger.Warn("Failed to parse JWT token", zap.Error(err))
|
||||
return nil, errors.NewAuthenticationError("Invalid token").WithInternal(err)
|
||||
}
|
||||
|
||||
// Extract custom claims
|
||||
claims, ok := token.Claims.(*CustomClaims)
|
||||
if !ok || !token.Valid {
|
||||
j.logger.Warn("Invalid JWT token claims")
|
||||
return nil, errors.NewAuthenticationError("Invalid token claims")
|
||||
}
|
||||
|
||||
// Check if token is expired beyond max valid time
|
||||
if time.Now().Unix() > claims.MaxValidAt {
|
||||
j.logger.Warn("JWT token expired beyond max valid time",
|
||||
zap.Int64("max_valid_at", claims.MaxValidAt),
|
||||
zap.Int64("current_time", time.Now().Unix()))
|
||||
return nil, errors.NewAuthenticationError("Token expired beyond maximum validity")
|
||||
}
|
||||
|
||||
j.logger.Debug("JWT token validated successfully",
|
||||
zap.String("user_id", claims.UserID),
|
||||
zap.String("app_id", claims.AppID))
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// RefreshToken generates a new token with updated expiration
|
||||
func (j *JWTManager) RefreshToken(oldTokenString string, newExpiration time.Time) (string, error) {
|
||||
j.logger.Debug("Refreshing JWT token")
|
||||
|
||||
// Validate the old token first
|
||||
claims, err := j.ValidateToken(oldTokenString)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Check if we can still refresh (not past max valid time)
|
||||
if time.Now().Unix() > claims.MaxValidAt {
|
||||
return "", errors.NewAuthenticationError("Token cannot be refreshed - past maximum validity")
|
||||
}
|
||||
|
||||
// Create new user token with updated expiration
|
||||
userToken := &domain.UserToken{
|
||||
AppID: claims.AppID,
|
||||
UserID: claims.UserID,
|
||||
Permissions: claims.Permissions,
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: newExpiration,
|
||||
MaxValidAt: time.Unix(claims.MaxValidAt, 0),
|
||||
TokenType: claims.TokenType,
|
||||
Claims: claims.Claims,
|
||||
}
|
||||
|
||||
// Generate new token
|
||||
return j.GenerateToken(userToken)
|
||||
}
|
||||
|
||||
// ExtractClaims extracts claims from a token without full validation (for expired tokens)
|
||||
func (j *JWTManager) ExtractClaims(tokenString string) (*CustomClaims, error) {
|
||||
j.logger.Debug("Extracting claims from JWT token")
|
||||
|
||||
// Parse token without validation to extract claims
|
||||
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, &CustomClaims{})
|
||||
if err != nil {
|
||||
j.logger.Warn("Failed to parse JWT token for claims extraction", zap.Error(err))
|
||||
return nil, errors.NewValidationError("Invalid token format").WithInternal(err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*CustomClaims)
|
||||
if !ok {
|
||||
j.logger.Warn("Invalid JWT token claims format")
|
||||
return nil, errors.NewValidationError("Invalid token claims format")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// RevokeToken adds a token to the revocation list (blacklist)
|
||||
func (j *JWTManager) RevokeToken(tokenString string) error {
|
||||
j.logger.Debug("Revoking JWT token")
|
||||
|
||||
// Extract claims to get token ID
|
||||
claims, err := j.ExtractClaims(tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Implement token blacklisting mechanism
|
||||
// This could be implemented using Redis or database storage
|
||||
// For now, we'll just log the revocation
|
||||
j.logger.Info("Token revoked",
|
||||
zap.String("jti", claims.ID),
|
||||
zap.String("user_id", claims.UserID),
|
||||
zap.String("app_id", claims.AppID))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsTokenRevoked checks if a token has been revoked
|
||||
func (j *JWTManager) IsTokenRevoked(tokenString string) (bool, error) {
|
||||
j.logger.Debug("Checking if JWT token is revoked")
|
||||
|
||||
// Extract claims to get token ID
|
||||
claims, err := j.ExtractClaims(tokenString)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// TODO: Implement token blacklist checking
|
||||
// This could be implemented using Redis or database storage
|
||||
// For now, we'll assume no tokens are revoked
|
||||
j.logger.Debug("Token revocation check completed",
|
||||
zap.String("jti", claims.ID),
|
||||
zap.Bool("revoked", false))
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// generateJTI generates a unique JWT ID
|
||||
func (j *JWTManager) generateJTI() string {
|
||||
bytes := make([]byte, 16)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
// Fallback to timestamp-based ID if random generation fails
|
||||
return fmt.Sprintf("jti_%d", time.Now().UnixNano())
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// GetTokenInfo extracts token information for debugging/logging
|
||||
func (j *JWTManager) GetTokenInfo(tokenString string) map[string]interface{} {
|
||||
claims, err := j.ExtractClaims(tokenString)
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"user_id": claims.UserID,
|
||||
"app_id": claims.AppID,
|
||||
"permissions": claims.Permissions,
|
||||
"token_type": claims.TokenType,
|
||||
"issued_at": time.Unix(claims.IssuedAt.Unix(), 0),
|
||||
"expires_at": time.Unix(claims.ExpiresAt.Unix(), 0),
|
||||
"max_valid_at": time.Unix(claims.MaxValidAt, 0),
|
||||
"jti": claims.ID,
|
||||
}
|
||||
}
|
||||
250
internal/cache/cache.go
vendored
Normal file
250
internal/cache/cache.go
vendored
Normal file
@ -0,0 +1,250 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// CacheProvider defines the interface for cache operations
|
||||
type CacheProvider interface {
|
||||
// Get retrieves a value from cache
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
|
||||
// Set stores a value in cache with TTL
|
||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
|
||||
// Delete removes a value from cache
|
||||
Delete(ctx context.Context, key string) error
|
||||
|
||||
// Exists checks if a key exists in cache
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Clear removes all cached values (use with caution)
|
||||
Clear(ctx context.Context) error
|
||||
|
||||
// Close closes the cache connection
|
||||
Close() error
|
||||
}
|
||||
|
||||
// MemoryCache implements CacheProvider using in-memory storage
|
||||
type MemoryCache struct {
|
||||
data map[string]cacheItem
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
type cacheItem struct {
|
||||
Value []byte
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// NewMemoryCache creates a new in-memory cache
|
||||
func NewMemoryCache(config config.ConfigProvider, logger *zap.Logger) CacheProvider {
|
||||
cache := &MemoryCache{
|
||||
data: make(map[string]cacheItem),
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go cache.cleanup()
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
// Get retrieves a value from memory cache
|
||||
func (m *MemoryCache) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
m.logger.Debug("Getting value from memory cache", zap.String("key", key))
|
||||
|
||||
item, exists := m.data[key]
|
||||
if !exists {
|
||||
return nil, errors.NewNotFoundError("cache key")
|
||||
}
|
||||
|
||||
// Check if expired
|
||||
if time.Now().After(item.ExpiresAt) {
|
||||
delete(m.data, key)
|
||||
return nil, errors.NewNotFoundError("cache key")
|
||||
}
|
||||
|
||||
return item.Value, nil
|
||||
}
|
||||
|
||||
// Set stores a value in memory cache
|
||||
func (m *MemoryCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
m.logger.Debug("Setting value in memory cache",
|
||||
zap.String("key", key),
|
||||
zap.Duration("ttl", ttl))
|
||||
|
||||
m.data[key] = cacheItem{
|
||||
Value: value,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a value from memory cache
|
||||
func (m *MemoryCache) Delete(ctx context.Context, key string) error {
|
||||
m.logger.Debug("Deleting value from memory cache", zap.String("key", key))
|
||||
|
||||
delete(m.data, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in memory cache
|
||||
func (m *MemoryCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
item, exists := m.data[key]
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if expired
|
||||
if time.Now().After(item.ExpiresAt) {
|
||||
delete(m.data, key)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Clear removes all values from memory cache
|
||||
func (m *MemoryCache) Clear(ctx context.Context) error {
|
||||
m.logger.Debug("Clearing memory cache")
|
||||
|
||||
m.data = make(map[string]cacheItem)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the memory cache (no-op for memory cache)
|
||||
func (m *MemoryCache) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanup removes expired items from memory cache
|
||||
func (m *MemoryCache) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
for key, item := range m.data {
|
||||
if now.After(item.ExpiresAt) {
|
||||
delete(m.data, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CacheManager provides high-level caching operations with JSON serialization
|
||||
type CacheManager struct {
|
||||
provider CacheProvider
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCacheManager creates a new cache manager
|
||||
func NewCacheManager(config config.ConfigProvider, logger *zap.Logger) *CacheManager {
|
||||
var provider CacheProvider
|
||||
|
||||
// For now, we'll use memory cache. In production, this could be Redis
|
||||
provider = NewMemoryCache(config, logger)
|
||||
|
||||
return &CacheManager{
|
||||
provider: provider,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetJSON retrieves and unmarshals a JSON value from cache
|
||||
func (c *CacheManager) GetJSON(ctx context.Context, key string, dest interface{}) error {
|
||||
c.logger.Debug("Getting JSON from cache", zap.String("key", key))
|
||||
|
||||
data, err := c.provider.Get(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, dest); err != nil {
|
||||
c.logger.Error("Failed to unmarshal cached JSON", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to unmarshal cached data").WithInternal(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetJSON marshals and stores a JSON value in cache
|
||||
func (c *CacheManager) SetJSON(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
||||
c.logger.Debug("Setting JSON in cache",
|
||||
zap.String("key", key),
|
||||
zap.Duration("ttl", ttl))
|
||||
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to marshal JSON for cache", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to marshal data for cache").WithInternal(err)
|
||||
}
|
||||
|
||||
return c.provider.Set(ctx, key, data, ttl)
|
||||
}
|
||||
|
||||
// Get retrieves raw bytes from cache
|
||||
func (c *CacheManager) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
return c.provider.Get(ctx, key)
|
||||
}
|
||||
|
||||
// Set stores raw bytes in cache
|
||||
func (c *CacheManager) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
return c.provider.Set(ctx, key, value, ttl)
|
||||
}
|
||||
|
||||
// Delete removes a value from cache
|
||||
func (c *CacheManager) Delete(ctx context.Context, key string) error {
|
||||
return c.provider.Delete(ctx, key)
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in cache
|
||||
func (c *CacheManager) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return c.provider.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Clear removes all cached values
|
||||
func (c *CacheManager) Clear(ctx context.Context) error {
|
||||
return c.provider.Clear(ctx)
|
||||
}
|
||||
|
||||
// Close closes the cache connection
|
||||
func (c *CacheManager) Close() error {
|
||||
return c.provider.Close()
|
||||
}
|
||||
|
||||
// GetDefaultTTL returns the default TTL from config
|
||||
func (c *CacheManager) GetDefaultTTL() time.Duration {
|
||||
return c.config.GetDuration("CACHE_TTL")
|
||||
}
|
||||
|
||||
// IsEnabled returns whether caching is enabled
|
||||
func (c *CacheManager) IsEnabled() bool {
|
||||
return c.config.GetBool("CACHE_ENABLED")
|
||||
}
|
||||
|
||||
// CacheKey generates a cache key with prefix
|
||||
func CacheKey(prefix, key string) string {
|
||||
return prefix + ":" + key
|
||||
}
|
||||
|
||||
// Common cache key prefixes
|
||||
const (
|
||||
KeyPrefixPermission = "perm"
|
||||
KeyPrefixApplication = "app"
|
||||
KeyPrefixToken = "token"
|
||||
KeyPrefixUserClaims = "user_claims"
|
||||
KeyPrefixTokenRevoked = "token_revoked"
|
||||
)
|
||||
@ -42,6 +42,9 @@ type ConfigProvider interface {
|
||||
// GetMetricsAddress returns the metrics server address in host:port format
|
||||
GetMetricsAddress() string
|
||||
|
||||
// GetJWTSecret returns the JWT signing secret
|
||||
GetJWTSecret() string
|
||||
|
||||
// IsDevelopment returns true if the environment is development
|
||||
IsDevelopment() bool
|
||||
|
||||
@ -104,6 +107,7 @@ func (c *Config) setDefaults() {
|
||||
"CACHE_ENABLED": "false",
|
||||
"CACHE_TTL": "1h",
|
||||
"JWT_ISSUER": "api-key-service",
|
||||
"JWT_SECRET": "bootstrap-jwt-secret-change-in-production",
|
||||
"AUTH_PROVIDER": "header", // header or sso
|
||||
"AUTH_HEADER_USER_EMAIL": "X-User-Email",
|
||||
"SSO_PROVIDER_URL": "",
|
||||
@ -186,6 +190,7 @@ func (c *Config) Validate() error {
|
||||
"SERVER_PORT",
|
||||
"INTERNAL_APP_ID",
|
||||
"INTERNAL_HMAC_KEY",
|
||||
"JWT_SECRET",
|
||||
}
|
||||
|
||||
var missing []string
|
||||
@ -262,6 +267,11 @@ func (c *Config) GetMetricsAddress() string {
|
||||
return fmt.Sprintf("%s:%d", c.GetString("SERVER_HOST"), c.GetInt("METRICS_PORT"))
|
||||
}
|
||||
|
||||
// GetJWTSecret returns the JWT signing secret
|
||||
func (c *Config) GetJWTSecret() string {
|
||||
return c.GetString("JWT_SECRET")
|
||||
}
|
||||
|
||||
// IsDevelopment returns true if the environment is development
|
||||
func (c *Config) IsDevelopment() bool {
|
||||
env := c.GetString("APP_ENV")
|
||||
|
||||
@ -197,6 +197,11 @@ func NewPermissionError(message string) *AppError {
|
||||
return New(ErrInsufficientPermissions, message)
|
||||
}
|
||||
|
||||
// NewAuthenticationError creates an authentication error
|
||||
func NewAuthenticationError(message string) *AppError {
|
||||
return New(ErrUnauthorized, message)
|
||||
}
|
||||
|
||||
// ErrorResponse represents the JSON error response format
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
|
||||
@ -3,23 +3,30 @@ package services
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// authenticationService implements the AuthenticationService interface
|
||||
type authenticationService struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
jwtManager *auth.JWTManager
|
||||
}
|
||||
|
||||
// NewAuthenticationService creates a new authentication service
|
||||
func NewAuthenticationService(config config.ConfigProvider, logger *zap.Logger) AuthenticationService {
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
return &authenticationService{
|
||||
config: config,
|
||||
logger: logger,
|
||||
jwtManager: jwtManager,
|
||||
}
|
||||
}
|
||||
|
||||
@ -63,3 +70,78 @@ func (s *authenticationService) GetUserClaims(ctx context.Context, userID string
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// ValidateJWTToken validates a JWT token and returns claims
|
||||
func (s *authenticationService) ValidateJWTToken(ctx context.Context, tokenString string) (*domain.AuthContext, error) {
|
||||
s.logger.Debug("Validating JWT token")
|
||||
|
||||
// Validate the token using JWT manager
|
||||
claims, err := s.jwtManager.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
s.logger.Warn("JWT token validation failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if token is revoked
|
||||
revoked, err := s.jwtManager.IsTokenRevoked(tokenString)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to check token revocation status", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to validate token").WithInternal(err)
|
||||
}
|
||||
|
||||
if revoked {
|
||||
s.logger.Warn("JWT token is revoked", zap.String("user_id", claims.UserID))
|
||||
return nil, errors.NewAuthenticationError("Token has been revoked")
|
||||
}
|
||||
|
||||
// Convert JWT claims to AuthContext
|
||||
authContext := &domain.AuthContext{
|
||||
UserID: claims.UserID,
|
||||
TokenType: claims.TokenType,
|
||||
Permissions: claims.Permissions,
|
||||
Claims: claims.Claims,
|
||||
AppID: claims.AppID,
|
||||
}
|
||||
|
||||
s.logger.Debug("JWT token validated successfully",
|
||||
zap.String("user_id", claims.UserID),
|
||||
zap.String("app_id", claims.AppID))
|
||||
|
||||
return authContext, nil
|
||||
}
|
||||
|
||||
// GenerateJWTToken generates a new JWT token for a user
|
||||
func (s *authenticationService) GenerateJWTToken(ctx context.Context, userToken *domain.UserToken) (string, error) {
|
||||
s.logger.Debug("Generating JWT token",
|
||||
zap.String("user_id", userToken.UserID),
|
||||
zap.String("app_id", userToken.AppID))
|
||||
|
||||
// Generate the token using JWT manager
|
||||
tokenString, err := s.jwtManager.GenerateToken(userToken)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to generate JWT token", zap.Error(err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.logger.Debug("JWT token generated successfully",
|
||||
zap.String("user_id", userToken.UserID),
|
||||
zap.String("app_id", userToken.AppID))
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// RefreshJWTToken refreshes an existing JWT token
|
||||
func (s *authenticationService) RefreshJWTToken(ctx context.Context, tokenString string, newExpiration time.Time) (string, error) {
|
||||
s.logger.Debug("Refreshing JWT token")
|
||||
|
||||
// Refresh the token using JWT manager
|
||||
newTokenString, err := s.jwtManager.RefreshToken(tokenString, newExpiration)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to refresh JWT token", zap.Error(err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.logger.Debug("JWT token refreshed successfully")
|
||||
|
||||
return newTokenString, nil
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@ package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
@ -56,4 +57,13 @@ type AuthenticationService interface {
|
||||
|
||||
// GetUserClaims retrieves user claims
|
||||
GetUserClaims(ctx context.Context, userID string) (map[string]string, error)
|
||||
|
||||
// ValidateJWTToken validates a JWT token and returns claims
|
||||
ValidateJWTToken(ctx context.Context, tokenString string) (*domain.AuthContext, error)
|
||||
|
||||
// GenerateJWTToken generates a new JWT token for a user
|
||||
GenerateJWTToken(ctx context.Context, userToken *domain.UserToken) (string, error)
|
||||
|
||||
// RefreshJWTToken refreshes an existing JWT token
|
||||
RefreshJWTToken(ctx context.Context, tokenString string, newExpiration time.Time) (string, error)
|
||||
}
|
||||
|
||||
380
test/auth_test.go
Normal file
380
test/auth_test.go
Normal file
@ -0,0 +1,380 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/services"
|
||||
)
|
||||
|
||||
// MockConfig implements ConfigProvider for testing
|
||||
type MockConfig struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func NewMockConfig() *MockConfig {
|
||||
return &MockConfig{
|
||||
values: map[string]string{
|
||||
"JWT_SECRET": "test-jwt-secret-for-testing-only",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetString(key string) string {
|
||||
return m.values[key]
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetInt(key string) int { return 0 }
|
||||
func (m *MockConfig) GetBool(key string) bool { return false }
|
||||
func (m *MockConfig) GetDuration(key string) time.Duration { return 0 }
|
||||
func (m *MockConfig) GetStringSlice(key string) []string { return nil }
|
||||
func (m *MockConfig) IsSet(key string) bool { return m.values[key] != "" }
|
||||
func (m *MockConfig) Validate() error { return nil }
|
||||
func (m *MockConfig) GetDatabaseDSN() string { return "" }
|
||||
func (m *MockConfig) GetServerAddress() string { return "" }
|
||||
func (m *MockConfig) GetMetricsAddress() string { return "" }
|
||||
func (m *MockConfig) GetJWTSecret() string { return m.GetString("JWT_SECRET") }
|
||||
func (m *MockConfig) IsDevelopment() bool { return true }
|
||||
func (m *MockConfig) IsProduction() bool { return false }
|
||||
|
||||
func TestJWTManager_GenerateToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, tokenString)
|
||||
|
||||
// Verify the token can be validated
|
||||
claims, err := jwtManager.ValidateToken(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
assert.Equal(t, userToken.Permissions, claims.Permissions)
|
||||
assert.Equal(t, userToken.TokenType, claims.TokenType)
|
||||
assert.Equal(t, userToken.Claims, claims.Claims)
|
||||
}
|
||||
|
||||
func TestJWTManager_ValidateToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test valid token
|
||||
claims, err := jwtManager.ValidateToken(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
|
||||
// Test invalid token
|
||||
_, err = jwtManager.ValidateToken("invalid-token")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test empty token
|
||||
_, err = jwtManager.ValidateToken("")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestJWTManager_ExpiredToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
// Create an expired token
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(-time.Hour), // Expired 1 hour ago
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validation should fail for expired token
|
||||
_, err = jwtManager.ValidateToken(tokenString)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestJWTManager_MaxValidAtExpired(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
// Create a token that's past max valid time
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(-time.Hour), // Max valid time expired
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validation should fail for token past max valid time
|
||||
_, err = jwtManager.ValidateToken(tokenString)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestJWTManager_RefreshToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
originalToken, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Refresh the token
|
||||
newExpiration := time.Now().Add(2 * time.Hour)
|
||||
refreshedToken, err := jwtManager.RefreshToken(originalToken, newExpiration)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, refreshedToken)
|
||||
assert.NotEqual(t, originalToken, refreshedToken)
|
||||
|
||||
// Validate the refreshed token
|
||||
claims, err := jwtManager.ValidateToken(refreshedToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
}
|
||||
|
||||
func TestJWTManager_ExtractClaims(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(-time.Hour), // Expired token
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract claims from expired token (should work)
|
||||
claims, err := jwtManager.ExtractClaims(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
}
|
||||
|
||||
func TestJWTManager_GetTokenInfo(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
info := jwtManager.GetTokenInfo(tokenString)
|
||||
assert.Equal(t, userToken.UserID, info["user_id"])
|
||||
assert.Equal(t, userToken.AppID, info["app_id"])
|
||||
assert.Equal(t, userToken.Permissions, info["permissions"])
|
||||
assert.Equal(t, userToken.TokenType, info["token_type"])
|
||||
}
|
||||
|
||||
func TestAuthenticationService_ValidateJWTToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
authService := services.NewAuthenticationService(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"email": "test@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenString, err := authService.GenerateJWTToken(context.Background(), userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate token
|
||||
authContext, err := authService.ValidateJWTToken(context.Background(), tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, authContext.UserID)
|
||||
assert.Equal(t, userToken.AppID, authContext.AppID)
|
||||
assert.Equal(t, userToken.Permissions, authContext.Permissions)
|
||||
assert.Equal(t, userToken.TokenType, authContext.TokenType)
|
||||
assert.Equal(t, userToken.Claims, authContext.Claims)
|
||||
}
|
||||
|
||||
func TestAuthenticationService_GenerateJWTToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
authService := services.NewAuthenticationService(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := authService.GenerateJWTToken(context.Background(), userToken)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, tokenString)
|
||||
|
||||
// Verify token can be validated
|
||||
authContext, err := authService.ValidateJWTToken(context.Background(), tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, authContext.UserID)
|
||||
}
|
||||
|
||||
func TestAuthenticationService_RefreshJWTToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
authService := services.NewAuthenticationService(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
originalToken, err := authService.GenerateJWTToken(context.Background(), userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Refresh token
|
||||
newExpiration := time.Now().Add(2 * time.Hour)
|
||||
refreshedToken, err := authService.RefreshJWTToken(context.Background(), originalToken, newExpiration)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, refreshedToken)
|
||||
assert.NotEqual(t, originalToken, refreshedToken)
|
||||
|
||||
// Validate refreshed token
|
||||
authContext, err := authService.ValidateJWTToken(context.Background(), refreshedToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, authContext.UserID)
|
||||
}
|
||||
|
||||
func TestJWTManager_InvalidSecret(t *testing.T) {
|
||||
// Test with empty JWT secret
|
||||
config := &MockConfig{values: map[string]string{"JWT_SECRET": ""}}
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
_, err := jwtManager.GenerateToken(userToken)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestJWTManager_TokenRevocation(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check revocation status (should be false initially)
|
||||
revoked, err := jwtManager.IsTokenRevoked(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, revoked)
|
||||
|
||||
// Revoke token (currently just logs, doesn't actually revoke)
|
||||
err = jwtManager.RevokeToken(tokenString)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Note: Current implementation doesn't actually implement blacklisting,
|
||||
// so this test just verifies the methods don't error
|
||||
}
|
||||
453
test/cache_test.go
Normal file
453
test/cache_test.go
Normal file
@ -0,0 +1,453 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/cache"
|
||||
)
|
||||
|
||||
// MockConfig implements ConfigProvider for testing
|
||||
type MockConfig struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func NewMockConfig() *MockConfig {
|
||||
return &MockConfig{
|
||||
values: map[string]string{
|
||||
"CACHE_ENABLED": "true",
|
||||
"CACHE_TTL": "1h",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetString(key string) string {
|
||||
return m.values[key]
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetInt(key string) int { return 0 }
|
||||
func (m *MockConfig) GetBool(key string) bool {
|
||||
if key == "CACHE_ENABLED" {
|
||||
return m.values[key] == "true"
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (m *MockConfig) GetDuration(key string) time.Duration {
|
||||
if key == "CACHE_TTL" {
|
||||
if d, err := time.ParseDuration(m.values[key]); err == nil {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
func (m *MockConfig) GetStringSlice(key string) []string { return nil }
|
||||
func (m *MockConfig) IsSet(key string) bool { return m.values[key] != "" }
|
||||
func (m *MockConfig) Validate() error { return nil }
|
||||
func (m *MockConfig) GetDatabaseDSN() string { return "" }
|
||||
func (m *MockConfig) GetServerAddress() string { return "" }
|
||||
func (m *MockConfig) GetMetricsAddress() string { return "" }
|
||||
func (m *MockConfig) GetJWTSecret() string { return m.GetString("JWT_SECRET") }
|
||||
func (m *MockConfig) IsDevelopment() bool { return true }
|
||||
func (m *MockConfig) IsProduction() bool { return false }
|
||||
|
||||
func TestMemoryCache_SetAndGet(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
memCache := cache.NewMemoryCache(config, logger)
|
||||
defer memCache.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "test-key"
|
||||
value := []byte("test-value")
|
||||
ttl := time.Hour
|
||||
|
||||
// Set value
|
||||
err := memCache.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get value
|
||||
retrieved, err := memCache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
func TestMemoryCache_GetNonExistent(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
memCache := cache.NewMemoryCache(config, logger)
|
||||
defer memCache.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "non-existent-key"
|
||||
|
||||
// Try to get non-existent key
|
||||
_, err := memCache.Get(ctx, key)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMemoryCache_Expiration(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
memCache := cache.NewMemoryCache(config, logger)
|
||||
defer memCache.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "expiring-key"
|
||||
value := []byte("expiring-value")
|
||||
ttl := 100 * time.Millisecond
|
||||
|
||||
// Set value with short TTL
|
||||
err := memCache.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get value immediately (should work)
|
||||
retrieved, err := memCache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, value, retrieved)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Try to get expired value
|
||||
_, err = memCache.Get(ctx, key)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMemoryCache_Delete(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
memCache := cache.NewMemoryCache(config, logger)
|
||||
defer memCache.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "delete-key"
|
||||
value := []byte("delete-value")
|
||||
ttl := time.Hour
|
||||
|
||||
// Set value
|
||||
err := memCache.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it exists
|
||||
exists, err := memCache.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Delete value
|
||||
err = memCache.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it no longer exists
|
||||
exists, err = memCache.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestMemoryCache_Exists(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
memCache := cache.NewMemoryCache(config, logger)
|
||||
defer memCache.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "exists-key"
|
||||
value := []byte("exists-value")
|
||||
ttl := time.Hour
|
||||
|
||||
// Check non-existent key
|
||||
exists, err := memCache.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
// Set value
|
||||
err = memCache.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check existing key
|
||||
exists, err = memCache.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestMemoryCache_Clear(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
memCache := cache.NewMemoryCache(config, logger)
|
||||
defer memCache.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
ttl := time.Hour
|
||||
|
||||
// Set multiple values
|
||||
err := memCache.Set(ctx, "key1", []byte("value1"), ttl)
|
||||
require.NoError(t, err)
|
||||
err = memCache.Set(ctx, "key2", []byte("value2"), ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify they exist
|
||||
exists, err := memCache.Exists(ctx, "key1")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
exists, err = memCache.Exists(ctx, "key2")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Clear cache
|
||||
err = memCache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify they no longer exist
|
||||
exists, err = memCache.Exists(ctx, "key1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
exists, err = memCache.Exists(ctx, "key2")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestCacheManager_SetAndGetJSON(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
defer cacheManager.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "json-key"
|
||||
ttl := time.Hour
|
||||
|
||||
// Test data
|
||||
originalData := map[string]interface{}{
|
||||
"name": "test",
|
||||
"value": 42,
|
||||
"items": []string{"a", "b", "c"},
|
||||
}
|
||||
|
||||
// Set JSON
|
||||
err := cacheManager.SetJSON(ctx, key, originalData, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get JSON
|
||||
var retrievedData map[string]interface{}
|
||||
err = cacheManager.GetJSON(ctx, key, &retrievedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare data
|
||||
assert.Equal(t, originalData["name"], retrievedData["name"])
|
||||
assert.Equal(t, float64(42), retrievedData["value"]) // JSON numbers are float64
|
||||
|
||||
// JSON arrays become []interface{}, so we need to compare differently
|
||||
retrievedItems := retrievedData["items"].([]interface{})
|
||||
expectedItems := []interface{}{"a", "b", "c"}
|
||||
assert.Equal(t, expectedItems, retrievedItems)
|
||||
}
|
||||
|
||||
func TestCacheManager_GetJSONNonExistent(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
defer cacheManager.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "non-existent-json-key"
|
||||
|
||||
var data map[string]interface{}
|
||||
err := cacheManager.GetJSON(ctx, key, &data)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCacheManager_RawBytesOperations(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
defer cacheManager.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "raw-key"
|
||||
value := []byte("raw-value")
|
||||
ttl := time.Hour
|
||||
|
||||
// Set raw bytes
|
||||
err := cacheManager.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get raw bytes
|
||||
retrieved, err := cacheManager.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, value, retrieved)
|
||||
|
||||
// Check exists
|
||||
exists, err := cacheManager.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Delete
|
||||
err = cacheManager.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deleted
|
||||
exists, err = cacheManager.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestCacheKey(t *testing.T) {
|
||||
prefix := "test"
|
||||
key := "key123"
|
||||
expected := "test:key123"
|
||||
|
||||
result := cache.CacheKey(prefix, key)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCacheKeyPrefixes(t *testing.T) {
|
||||
// Test that constants are defined
|
||||
assert.Equal(t, "perm", cache.KeyPrefixPermission)
|
||||
assert.Equal(t, "app", cache.KeyPrefixApplication)
|
||||
assert.Equal(t, "token", cache.KeyPrefixToken)
|
||||
assert.Equal(t, "user_claims", cache.KeyPrefixUserClaims)
|
||||
assert.Equal(t, "token_revoked", cache.KeyPrefixTokenRevoked)
|
||||
}
|
||||
|
||||
func TestCacheManager_ConfigMethods(t *testing.T) {
|
||||
// Create mock config with cache settings
|
||||
config := &MockConfig{
|
||||
values: map[string]string{
|
||||
"CACHE_ENABLED": "true",
|
||||
"CACHE_TTL": "1h",
|
||||
},
|
||||
}
|
||||
logger := zap.NewNop()
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
defer cacheManager.Close()
|
||||
|
||||
// Test IsEnabled
|
||||
assert.True(t, cacheManager.IsEnabled())
|
||||
|
||||
// Test GetDefaultTTL
|
||||
ttl := cacheManager.GetDefaultTTL()
|
||||
assert.Equal(t, time.Hour, ttl)
|
||||
}
|
||||
|
||||
func TestCacheManager_InvalidJSON(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
defer cacheManager.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "invalid-json-key"
|
||||
|
||||
// Set invalid JSON data manually
|
||||
invalidJSON := []byte("{invalid json}")
|
||||
err := cacheManager.Set(ctx, key, invalidJSON, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get as JSON (should fail)
|
||||
var data map[string]interface{}
|
||||
err = cacheManager.GetJSON(ctx, key, &data)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCacheManager_SetJSONMarshalError(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
defer cacheManager.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "marshal-error-key"
|
||||
|
||||
// Try to set data that can't be marshaled (function)
|
||||
invalidData := func() {}
|
||||
err := cacheManager.SetJSON(ctx, key, invalidData, time.Hour)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkMemoryCache_Set(b *testing.B) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
memCache := cache.NewMemoryCache(config, logger)
|
||||
defer memCache.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
value := []byte("benchmark-value")
|
||||
ttl := time.Hour
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "benchmark-key-" + string(rune(i))
|
||||
memCache.Set(ctx, key, value, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMemoryCache_Get(b *testing.B) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
memCache := cache.NewMemoryCache(config, logger)
|
||||
defer memCache.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "benchmark-get-key"
|
||||
value := []byte("benchmark-value")
|
||||
ttl := time.Hour
|
||||
|
||||
// Pre-populate cache
|
||||
memCache.Set(ctx, key, value, ttl)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
memCache.Get(ctx, key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheManager_SetJSON(b *testing.B) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
defer cacheManager.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
data := map[string]interface{}{
|
||||
"name": "benchmark",
|
||||
"value": 42,
|
||||
"items": []string{"a", "b", "c"},
|
||||
}
|
||||
ttl := time.Hour
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := "benchmark-json-key-" + string(rune(i))
|
||||
cacheManager.SetJSON(ctx, key, data, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheManager_GetJSON(b *testing.B) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
defer cacheManager.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "benchmark-json-get-key"
|
||||
data := map[string]interface{}{
|
||||
"name": "benchmark",
|
||||
"value": 42,
|
||||
"items": []string{"a", "b", "c"},
|
||||
}
|
||||
ttl := time.Hour
|
||||
|
||||
// Pre-populate cache
|
||||
cacheManager.SetJSON(ctx, key, data, ttl)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var retrieved map[string]interface{}
|
||||
cacheManager.GetJSON(ctx, key, &retrieved)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user