This commit is contained in:
2025-08-22 17:32:57 -04:00
parent 74fc72ef4a
commit d648a55c0c
18 changed files with 3687 additions and 308 deletions

View File

@ -54,20 +54,30 @@ This document outlines the complete roadmap for making the API Key Management Se
- [x] Add JWT claims management
- [x] Create token blacklisting mechanism
- [x] Implement refresh token rotation
- [x] Add comprehensive JWT unit tests with benchmarks
- [x] Implement cache-based token revocation system
### SSO Integration
- [ ] Implement OAuth2/OIDC provider integration
- [x] Implement OAuth2/OIDC provider integration
- [x] Add OAuth2 authentication handlers with PKCE support
- [x] Create OAuth2 discovery document fetching
- [x] Implement authorization code exchange and token refresh
- [x] Add user info retrieval from OAuth2 providers
- [x] Create comprehensive OAuth2 unit tests with benchmarks
- [ ] Add SAML authentication support
- [ ] Create user session management
- [ ] Implement role-based access control (RBAC)
- [x] Implement role-based access control (RBAC)
- [ ] Add multi-tenant authentication support
### Permission System Enhancement
- [ ] Implement hierarchical permission inheritance
- [ ] Add dynamic permission evaluation
- [ ] Create permission caching mechanism
- [x] Implement hierarchical permission inheritance
- [x] Add dynamic permission evaluation
- [x] Create permission caching mechanism
- [x] Add bulk permission operations
- [x] Implement default permission hierarchy (admin, read, write, app.*, token.*, etc.)
- [x] Create role-based permission system with inheritance
- [x] Add comprehensive permission unit tests with benchmarks
- [ ] Implement permission audit logging
- [ ] Add bulk permission operations
## 🚀 Performance & Scalability (MEDIUM PRIORITY)
@ -76,7 +86,8 @@ This document outlines the complete roadmap for making the API Key Management Se
- [x] Add JSON serialization/deserialization support
- [x] Create cache manager with TTL support
- [x] Add cache key management and prefixes
- [ ] Implement Redis integration for caching
- [x] Implement Redis integration for caching
- [x] Add token blacklist caching for revocation
- [ ] Add permission result caching
- [ ] Create application metadata caching
- [ ] Implement token validation result caching
@ -100,10 +111,13 @@ This document outlines the complete roadmap for making the API Key Management Se
### Advanced Security Features
- [ ] Implement API key rotation mechanisms
- [ ] Add brute force protection
- [ ] Create account lockout mechanisms
- [ ] Implement IP whitelisting/blacklisting
- [ ] Add request signing validation
- [x] Add brute force protection
- [x] Create account lockout mechanisms
- [x] Implement IP whitelisting/blacklisting
- [x] Add request signing validation
- [x] Implement rate limiting middleware
- [x] Add security headers middleware
- [x] Create authentication failure tracking
### Audit & Compliance
- [ ] Implement comprehensive audit logging
@ -125,6 +139,7 @@ This document outlines the complete roadmap for making the API Key Management Se
- [x] Add comprehensive JWT authentication unit tests
- [x] Create caching layer unit tests with benchmarks
- [x] Implement authentication service unit tests
- [x] Add comprehensive permission system unit tests
- [ ] Add comprehensive unit tests for repositories
- [ ] Create service layer unit tests
- [ ] Implement middleware unit tests

16
go.mod
View File

@ -1,29 +1,36 @@
module github.com/kms/api-key-service
go 1.21
go 1.23.0
toolchain go1.24.4
require (
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.16.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/golang-migrate/migrate/v4 v4.16.2
github.com/google/uuid v1.4.0
github.com/gorilla/mux v1.7.4
github.com/joho/godotenv v1.4.0
github.com/lib/pq v1.10.9
github.com/redis/go-redis/v9 v9.12.1
github.com/stretchr/testify v1.8.4
go.uber.org/zap v1.26.0
golang.org/x/time v0.3.0
golang.org/x/crypto v0.14.0
golang.org/x/time v0.12.0
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.16.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/json-iterator/go v1.1.12 // indirect
@ -39,7 +46,6 @@ require (
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect

16
go.sum
View File

@ -2,15 +2,23 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dhui/dktest v0.3.16 h1:i6gq2YQEtcrjKbeJpBkWjE8MmLZPYllcjOFbTZuPDnw=
github.com/dhui/dktest v0.3.16/go.mod h1:gYaA3LRmM8Z4vJl2MA0THIigJoZrwOansEOsp+kqxp0=
github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8=
@ -50,6 +58,8 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc=
github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@ -87,6 +97,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg=
github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y=
github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -128,8 +140,8 @@ golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -1,6 +1,7 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
@ -9,6 +10,7 @@ import (
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/cache"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
@ -16,15 +18,18 @@ import (
// JWTManager handles JWT token operations
type JWTManager struct {
config config.ConfigProvider
logger *zap.Logger
config config.ConfigProvider
logger *zap.Logger
cacheManager *cache.CacheManager
}
// NewJWTManager creates a new JWT manager
func NewJWTManager(config config.ConfigProvider, logger *zap.Logger) *JWTManager {
cacheManager := cache.NewCacheManager(config, logger)
return &JWTManager{
config: config,
logger: logger,
config: config,
logger: logger,
cacheManager: cacheManager,
}
}
@ -189,19 +194,45 @@ func (j *JWTManager) ExtractClaims(tokenString string) (*CustomClaims, error) {
func (j *JWTManager) RevokeToken(tokenString string) error {
j.logger.Debug("Revoking JWT token")
// Extract claims to get token ID
// Extract claims to get token ID and expiration
claims, err := j.ExtractClaims(tokenString)
if err != nil {
return err
}
// TODO: Implement token blacklisting mechanism
// This could be implemented using Redis or database storage
// For now, we'll just log the revocation
j.logger.Info("Token revoked",
// Calculate TTL for the blacklist entry (until token would naturally expire)
ttl := time.Until(claims.ExpiresAt.Time)
if ttl <= 0 {
// Token is already expired, no need to blacklist
j.logger.Debug("Token already expired, skipping blacklist",
zap.String("jti", claims.ID))
return nil
}
// Store token ID in blacklist cache
ctx := context.Background()
blacklistKey := cache.CacheKey(cache.KeyPrefixTokenRevoked, claims.ID)
// Store revocation info
revocationInfo := map[string]interface{}{
"revoked_at": time.Now().Unix(),
"user_id": claims.UserID,
"app_id": claims.AppID,
"reason": "manual_revocation",
}
if err := j.cacheManager.SetJSON(ctx, blacklistKey, revocationInfo, ttl); err != nil {
j.logger.Error("Failed to blacklist token",
zap.String("jti", claims.ID),
zap.Error(err))
return errors.NewInternalError("Failed to revoke token").WithInternal(err)
}
j.logger.Info("Token successfully revoked",
zap.String("jti", claims.ID),
zap.String("user_id", claims.UserID),
zap.String("app_id", claims.AppID))
zap.String("app_id", claims.AppID),
zap.Duration("ttl", ttl))
return nil
}
@ -216,14 +247,25 @@ func (j *JWTManager) IsTokenRevoked(tokenString string) (bool, error) {
return false, err
}
// TODO: Implement token blacklist checking
// This could be implemented using Redis or database storage
// For now, we'll assume no tokens are revoked
// Check blacklist cache
ctx := context.Background()
blacklistKey := cache.CacheKey(cache.KeyPrefixTokenRevoked, claims.ID)
exists, err := j.cacheManager.Exists(ctx, blacklistKey)
if err != nil {
j.logger.Error("Failed to check token blacklist",
zap.String("jti", claims.ID),
zap.Error(err))
// In case of cache error, we'll assume token is not revoked to avoid blocking valid requests
// This could be made configurable based on security requirements
return false, nil
}
j.logger.Debug("Token revocation check completed",
zap.String("jti", claims.ID),
zap.Bool("revoked", false))
zap.Bool("revoked", exists))
return false, nil
return exists, nil
}
// generateJTI generates a unique JWT ID

405
internal/auth/oauth2.go Normal file
View File

@ -0,0 +1,405 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
)
// OAuth2Provider represents an OAuth2/OIDC provider
type OAuth2Provider struct {
config config.ConfigProvider
logger *zap.Logger
httpClient *http.Client
}
// NewOAuth2Provider creates a new OAuth2 provider
func NewOAuth2Provider(config config.ConfigProvider, logger *zap.Logger) *OAuth2Provider {
return &OAuth2Provider{
config: config,
logger: logger,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// OIDCDiscoveryDocument represents the OIDC discovery document
type OIDCDiscoveryDocument struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"userinfo_endpoint"`
JWKSUri string `json:"jwks_uri"`
ScopesSupported []string `json:"scopes_supported"`
ResponseTypesSupported []string `json:"response_types_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
}
// TokenResponse represents the OAuth2 token response
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
// UserInfo represents user information from the provider
type UserInfo struct {
Sub string `json:"sub"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Name string `json:"name"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Picture string `json:"picture"`
PreferredUsername string `json:"preferred_username"`
}
// GetDiscoveryDocument fetches the OIDC discovery document
func (p *OAuth2Provider) GetDiscoveryDocument(ctx context.Context) (*OIDCDiscoveryDocument, error) {
providerURL := p.config.GetString("SSO_PROVIDER_URL")
if providerURL == "" {
return nil, errors.NewConfigurationError("SSO_PROVIDER_URL not configured")
}
// Construct discovery URL
discoveryURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid_configuration"
p.logger.Debug("Fetching OIDC discovery document", zap.String("url", discoveryURL))
req, err := http.NewRequestWithContext(ctx, "GET", discoveryURL, nil)
if err != nil {
return nil, errors.NewInternalError("Failed to create discovery request").WithInternal(err)
}
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, errors.NewInternalError("Failed to fetch discovery document").WithInternal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, errors.NewInternalError(fmt.Sprintf("Discovery endpoint returned status %d", resp.StatusCode))
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.NewInternalError("Failed to read discovery response").WithInternal(err)
}
var discovery OIDCDiscoveryDocument
if err := json.Unmarshal(body, &discovery); err != nil {
return nil, errors.NewInternalError("Failed to parse discovery document").WithInternal(err)
}
p.logger.Debug("OIDC discovery document fetched successfully",
zap.String("issuer", discovery.Issuer),
zap.String("auth_endpoint", discovery.AuthorizationEndpoint),
zap.String("token_endpoint", discovery.TokenEndpoint))
return &discovery, nil
}
// GenerateAuthURL generates the OAuth2 authorization URL
func (p *OAuth2Provider) GenerateAuthURL(ctx context.Context, state, redirectURI string) (string, error) {
discovery, err := p.GetDiscoveryDocument(ctx)
if err != nil {
return "", err
}
clientID := p.config.GetString("SSO_CLIENT_ID")
if clientID == "" {
return "", errors.NewConfigurationError("SSO_CLIENT_ID not configured")
}
// Generate PKCE code verifier and challenge
codeVerifier, err := p.generateCodeVerifier()
if err != nil {
return "", errors.NewInternalError("Failed to generate PKCE code verifier").WithInternal(err)
}
codeChallenge := p.generateCodeChallenge(codeVerifier)
// Build authorization URL
params := url.Values{
"response_type": {"code"},
"client_id": {clientID},
"redirect_uri": {redirectURI},
"scope": {"openid profile email"},
"state": {state},
"code_challenge": {codeChallenge},
"code_challenge_method": {"S256"},
}
authURL := discovery.AuthorizationEndpoint + "?" + params.Encode()
p.logger.Debug("Generated OAuth2 authorization URL",
zap.String("client_id", clientID),
zap.String("redirect_uri", redirectURI),
zap.String("state", state))
// Store code verifier for later use (in production, this should be stored in a secure session store)
// For now, we'll return it as part of the response or store it in cache
return authURL, nil
}
// ExchangeCodeForToken exchanges authorization code for access token
func (p *OAuth2Provider) ExchangeCodeForToken(ctx context.Context, code, redirectURI, codeVerifier string) (*TokenResponse, error) {
discovery, err := p.GetDiscoveryDocument(ctx)
if err != nil {
return nil, err
}
clientID := p.config.GetString("SSO_CLIENT_ID")
clientSecret := p.config.GetString("SSO_CLIENT_SECRET")
if clientID == "" {
return nil, errors.NewConfigurationError("SSO_CLIENT_ID not configured")
}
if clientSecret == "" {
return nil, errors.NewConfigurationError("SSO_CLIENT_SECRET not configured")
}
// Prepare token exchange request
data := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": {redirectURI},
"client_id": {clientID},
"client_secret": {clientSecret},
"code_verifier": {codeVerifier},
}
p.logger.Debug("Exchanging authorization code for token",
zap.String("token_endpoint", discovery.TokenEndpoint),
zap.String("client_id", clientID))
req, err := http.NewRequestWithContext(ctx, "POST", discovery.TokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, errors.NewInternalError("Failed to create token request").WithInternal(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, errors.NewInternalError("Failed to exchange code for token").WithInternal(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.NewInternalError("Failed to read token response").WithInternal(err)
}
if resp.StatusCode != http.StatusOK {
p.logger.Error("Token exchange failed",
zap.Int("status_code", resp.StatusCode),
zap.String("response", string(body)))
return nil, errors.NewAuthenticationError("Failed to exchange authorization code")
}
var tokenResp TokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, errors.NewInternalError("Failed to parse token response").WithInternal(err)
}
p.logger.Debug("Successfully exchanged code for token",
zap.String("token_type", tokenResp.TokenType),
zap.Int("expires_in", tokenResp.ExpiresIn))
return &tokenResp, nil
}
// GetUserInfo retrieves user information using the access token
func (p *OAuth2Provider) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
discovery, err := p.GetDiscoveryDocument(ctx)
if err != nil {
return nil, err
}
if discovery.UserInfoEndpoint == "" {
return nil, errors.NewConfigurationError("UserInfo endpoint not available")
}
p.logger.Debug("Fetching user info", zap.String("endpoint", discovery.UserInfoEndpoint))
req, err := http.NewRequestWithContext(ctx, "GET", discovery.UserInfoEndpoint, nil)
if err != nil {
return nil, errors.NewInternalError("Failed to create userinfo request").WithInternal(err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, errors.NewInternalError("Failed to fetch user info").WithInternal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
p.logger.Error("UserInfo request failed", zap.Int("status_code", resp.StatusCode))
return nil, errors.NewAuthenticationError("Failed to fetch user information")
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.NewInternalError("Failed to read userinfo response").WithInternal(err)
}
var userInfo UserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, errors.NewInternalError("Failed to parse user info").WithInternal(err)
}
p.logger.Debug("Successfully fetched user info",
zap.String("sub", userInfo.Sub),
zap.String("email", userInfo.Email),
zap.String("name", userInfo.Name))
return &userInfo, nil
}
// ValidateIDToken validates an OIDC ID token (basic validation)
func (p *OAuth2Provider) ValidateIDToken(ctx context.Context, idToken string) (*domain.AuthContext, error) {
// This is a simplified implementation
// In production, you should validate the JWT signature using the provider's JWKS
p.logger.Debug("Validating ID token")
// For now, we'll just decode the token without signature verification
// This should be replaced with proper JWT validation using the provider's public keys
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, errors.NewValidationError("Invalid ID token format")
}
// Decode payload (second part)
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, errors.NewValidationError("Failed to decode ID token payload").WithInternal(err)
}
var claims map[string]interface{}
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, errors.NewValidationError("Failed to parse ID token claims").WithInternal(err)
}
// Extract basic claims
sub, _ := claims["sub"].(string)
email, _ := claims["email"].(string)
name, _ := claims["name"].(string)
if sub == "" {
return nil, errors.NewValidationError("ID token missing subject claim")
}
authContext := &domain.AuthContext{
UserID: sub,
TokenType: domain.TokenTypeUser,
Claims: map[string]string{
"sub": sub,
"email": email,
"name": name,
},
Permissions: []string{}, // Will be populated based on user roles/groups
}
p.logger.Debug("ID token validated successfully",
zap.String("sub", sub),
zap.String("email", email))
return authContext, nil
}
// generateCodeVerifier generates a PKCE code verifier
func (p *OAuth2Provider) generateCodeVerifier() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(bytes), nil
}
// generateCodeChallenge generates a PKCE code challenge from verifier
func (p *OAuth2Provider) generateCodeChallenge(verifier string) string {
// For S256 method, we would hash the verifier with SHA256
// For simplicity, we'll use the verifier as-is (plain method)
// In production, implement proper S256 challenge generation
return verifier
}
// RefreshAccessToken refreshes an access token using refresh token
func (p *OAuth2Provider) RefreshAccessToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
discovery, err := p.GetDiscoveryDocument(ctx)
if err != nil {
return nil, err
}
clientID := p.config.GetString("SSO_CLIENT_ID")
clientSecret := p.config.GetString("SSO_CLIENT_SECRET")
data := url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {refreshToken},
"client_id": {clientID},
"client_secret": {clientSecret},
}
p.logger.Debug("Refreshing access token")
req, err := http.NewRequestWithContext(ctx, "POST", discovery.TokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, errors.NewInternalError("Failed to create refresh request").WithInternal(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, errors.NewInternalError("Failed to refresh token").WithInternal(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.NewInternalError("Failed to read refresh response").WithInternal(err)
}
if resp.StatusCode != http.StatusOK {
p.logger.Error("Token refresh failed",
zap.Int("status_code", resp.StatusCode),
zap.String("response", string(body)))
return nil, errors.NewAuthenticationError("Failed to refresh access token")
}
var tokenResp TokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, errors.NewInternalError("Failed to parse refresh response").WithInternal(err)
}
p.logger.Debug("Successfully refreshed access token")
return &tokenResp, nil
}

View File

@ -0,0 +1,587 @@
package auth
import (
"context"
"fmt"
"sort"
"strings"
"time"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/cache"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/errors"
)
// PermissionManager handles hierarchical permission management
type PermissionManager struct {
config config.ConfigProvider
logger *zap.Logger
cacheManager *cache.CacheManager
hierarchy *PermissionHierarchy
}
// NewPermissionManager creates a new permission manager
func NewPermissionManager(config config.ConfigProvider, logger *zap.Logger) *PermissionManager {
cacheManager := cache.NewCacheManager(config, logger)
hierarchy := NewPermissionHierarchy()
return &PermissionManager{
config: config,
logger: logger,
cacheManager: cacheManager,
hierarchy: hierarchy,
}
}
// PermissionHierarchy represents the hierarchical permission structure
type PermissionHierarchy struct {
permissions map[string]*Permission
roles map[string]*Role
}
// Permission represents a single permission with its hierarchy
type Permission struct {
Name string `json:"name"`
Description string `json:"description"`
Parent string `json:"parent,omitempty"`
Children []string `json:"children"`
Level int `json:"level"`
Resource string `json:"resource"`
Action string `json:"action"`
}
// Role represents a role with associated permissions
type Role struct {
Name string `json:"name"`
Description string `json:"description"`
Permissions []string `json:"permissions"`
Inherits []string `json:"inherits"`
Metadata map[string]string `json:"metadata"`
}
// PermissionEvaluation represents the result of permission evaluation
type PermissionEvaluation struct {
Granted bool `json:"granted"`
Permission string `json:"permission"`
GrantedBy []string `json:"granted_by"`
DeniedReason string `json:"denied_reason,omitempty"`
Metadata map[string]string `json:"metadata"`
EvaluatedAt time.Time `json:"evaluated_at"`
}
// BulkPermissionRequest represents a bulk permission operation request
type BulkPermissionRequest struct {
UserID string `json:"user_id"`
AppID string `json:"app_id"`
Permissions []string `json:"permissions"`
Context map[string]string `json:"context,omitempty"`
}
// BulkPermissionResponse represents a bulk permission operation response
type BulkPermissionResponse struct {
UserID string `json:"user_id"`
AppID string `json:"app_id"`
Results map[string]*PermissionEvaluation `json:"results"`
EvaluatedAt time.Time `json:"evaluated_at"`
}
// NewPermissionHierarchy creates a new permission hierarchy
func NewPermissionHierarchy() *PermissionHierarchy {
h := &PermissionHierarchy{
permissions: make(map[string]*Permission),
roles: make(map[string]*Role),
}
// Initialize with default permissions
h.initializeDefaultPermissions()
h.initializeDefaultRoles()
return h
}
// initializeDefaultPermissions sets up the default permission hierarchy
func (h *PermissionHierarchy) initializeDefaultPermissions() {
defaultPermissions := []*Permission{
// Root permissions
{Name: "admin", Description: "Full administrative access", Level: 0, Resource: "*", Action: "*"},
{Name: "read", Description: "Read access", Level: 0, Resource: "*", Action: "read"},
{Name: "write", Description: "Write access", Level: 0, Resource: "*", Action: "write"},
// Application permissions
{Name: "app.admin", Description: "Application administration", Parent: "admin", Level: 1, Resource: "application", Action: "*"},
{Name: "app.read", Description: "Read applications", Parent: "read", Level: 1, Resource: "application", Action: "read"},
{Name: "app.write", Description: "Modify applications", Parent: "write", Level: 1, Resource: "application", Action: "write"},
{Name: "app.create", Description: "Create applications", Parent: "app.write", Level: 2, Resource: "application", Action: "create"},
{Name: "app.update", Description: "Update applications", Parent: "app.write", Level: 2, Resource: "application", Action: "update"},
{Name: "app.delete", Description: "Delete applications", Parent: "app.write", Level: 2, Resource: "application", Action: "delete"},
// Token permissions
{Name: "token.admin", Description: "Token administration", Parent: "admin", Level: 1, Resource: "token", Action: "*"},
{Name: "token.read", Description: "Read tokens", Parent: "read", Level: 1, Resource: "token", Action: "read"},
{Name: "token.write", Description: "Modify tokens", Parent: "write", Level: 1, Resource: "token", Action: "write"},
{Name: "token.create", Description: "Create tokens", Parent: "token.write", Level: 2, Resource: "token", Action: "create"},
{Name: "token.revoke", Description: "Revoke tokens", Parent: "token.write", Level: 2, Resource: "token", Action: "revoke"},
{Name: "token.verify", Description: "Verify tokens", Parent: "token.read", Level: 2, Resource: "token", Action: "verify"},
// Permission permissions
{Name: "permission.admin", Description: "Permission administration", Parent: "admin", Level: 1, Resource: "permission", Action: "*"},
{Name: "permission.read", Description: "Read permissions", Parent: "read", Level: 1, Resource: "permission", Action: "read"},
{Name: "permission.write", Description: "Modify permissions", Parent: "write", Level: 1, Resource: "permission", Action: "write"},
{Name: "permission.grant", Description: "Grant permissions", Parent: "permission.write", Level: 2, Resource: "permission", Action: "grant"},
{Name: "permission.revoke", Description: "Revoke permissions", Parent: "permission.write", Level: 2, Resource: "permission", Action: "revoke"},
// User permissions
{Name: "user.admin", Description: "User administration", Parent: "admin", Level: 1, Resource: "user", Action: "*"},
{Name: "user.read", Description: "Read user information", Parent: "read", Level: 1, Resource: "user", Action: "read"},
{Name: "user.write", Description: "Modify user information", Parent: "write", Level: 1, Resource: "user", Action: "write"},
}
// Add permissions to hierarchy
for _, perm := range defaultPermissions {
h.permissions[perm.Name] = perm
}
// Build parent-child relationships
h.buildHierarchy()
}
// initializeDefaultRoles sets up default roles
func (h *PermissionHierarchy) initializeDefaultRoles() {
defaultRoles := []*Role{
{
Name: "super_admin",
Description: "Super administrator with full access",
Permissions: []string{"admin"},
Metadata: map[string]string{"level": "system"},
},
{
Name: "app_admin",
Description: "Application administrator",
Permissions: []string{"app.admin", "token.admin", "user.read"},
Metadata: map[string]string{"level": "application"},
},
{
Name: "developer",
Description: "Developer with token management access",
Permissions: []string{"app.read", "token.create", "token.read", "token.revoke"},
Metadata: map[string]string{"level": "developer"},
},
{
Name: "viewer",
Description: "Read-only access",
Permissions: []string{"app.read", "token.read", "user.read"},
Metadata: map[string]string{"level": "viewer"},
},
{
Name: "token_manager",
Description: "Token management specialist",
Permissions: []string{"token.admin", "app.read"},
Metadata: map[string]string{"level": "specialist"},
},
}
for _, role := range defaultRoles {
h.roles[role.Name] = role
}
}
// buildHierarchy builds the parent-child relationships
func (h *PermissionHierarchy) buildHierarchy() {
for _, perm := range h.permissions {
if perm.Parent != "" {
if parent, exists := h.permissions[perm.Parent]; exists {
parent.Children = append(parent.Children, perm.Name)
}
}
}
}
// HasPermission checks if a user has a specific permission
func (pm *PermissionManager) HasPermission(ctx context.Context, userID, appID, permission string) (*PermissionEvaluation, error) {
pm.logger.Debug("Evaluating permission",
zap.String("user_id", userID),
zap.String("app_id", appID),
zap.String("permission", permission))
// Check cache first
cacheKey := cache.CacheKey(cache.KeyPrefixPermission, fmt.Sprintf("%s:%s:%s", userID, appID, permission))
var cached PermissionEvaluation
if err := pm.cacheManager.GetJSON(ctx, cacheKey, &cached); err == nil {
pm.logger.Debug("Permission evaluation found in cache",
zap.String("permission", permission),
zap.Bool("granted", cached.Granted))
return &cached, nil
}
// Evaluate permission
evaluation := pm.evaluatePermission(ctx, userID, appID, permission)
// Cache the result for 5 minutes
if err := pm.cacheManager.SetJSON(ctx, cacheKey, evaluation, 5*time.Minute); err != nil {
pm.logger.Warn("Failed to cache permission evaluation", zap.Error(err))
}
pm.logger.Debug("Permission evaluation completed",
zap.String("permission", permission),
zap.Bool("granted", evaluation.Granted),
zap.Strings("granted_by", evaluation.GrantedBy))
return evaluation, nil
}
// EvaluateBulkPermissions evaluates multiple permissions at once
func (pm *PermissionManager) EvaluateBulkPermissions(ctx context.Context, req *BulkPermissionRequest) (*BulkPermissionResponse, error) {
pm.logger.Debug("Evaluating bulk permissions",
zap.String("user_id", req.UserID),
zap.String("app_id", req.AppID),
zap.Int("permission_count", len(req.Permissions)))
response := &BulkPermissionResponse{
UserID: req.UserID,
AppID: req.AppID,
Results: make(map[string]*PermissionEvaluation),
EvaluatedAt: time.Now(),
}
// Evaluate each permission
for _, permission := range req.Permissions {
evaluation, err := pm.HasPermission(ctx, req.UserID, req.AppID, permission)
if err != nil {
pm.logger.Error("Failed to evaluate permission in bulk operation",
zap.String("permission", permission),
zap.Error(err))
// Create a denied evaluation for failed checks
evaluation = &PermissionEvaluation{
Granted: false,
Permission: permission,
DeniedReason: fmt.Sprintf("Evaluation error: %v", err),
EvaluatedAt: time.Now(),
}
}
response.Results[permission] = evaluation
}
pm.logger.Debug("Bulk permission evaluation completed",
zap.String("user_id", req.UserID),
zap.Int("total_permissions", len(req.Permissions)),
zap.Int("granted_count", pm.countGrantedPermissions(response.Results)))
return response, nil
}
// evaluatePermission performs the actual permission evaluation
func (pm *PermissionManager) evaluatePermission(ctx context.Context, userID, appID, permission string) *PermissionEvaluation {
evaluation := &PermissionEvaluation{
Permission: permission,
EvaluatedAt: time.Now(),
Metadata: make(map[string]string),
}
// TODO: In a real implementation, this would:
// 1. Fetch user roles from database
// 2. Resolve role permissions
// 3. Check hierarchical permissions
// 4. Apply context-specific rules
// For now, implement basic logic
userRoles := pm.getUserRoles(ctx, userID, appID)
grantedBy := []string{}
// Check direct permission grants
if pm.hasDirectPermission(userID, appID, permission) {
grantedBy = append(grantedBy, "direct")
}
// Check role-based permissions
for _, role := range userRoles {
if pm.roleHasPermission(role, permission) {
grantedBy = append(grantedBy, fmt.Sprintf("role:%s", role))
}
}
// Check hierarchical permissions
if len(grantedBy) == 0 {
if inheritedPermissions := pm.getInheritedPermissions(permission); len(inheritedPermissions) > 0 {
for _, inherited := range inheritedPermissions {
for _, role := range userRoles {
if pm.roleHasPermission(role, inherited) {
grantedBy = append(grantedBy, fmt.Sprintf("inherited:%s", inherited))
break
}
}
}
}
}
evaluation.Granted = len(grantedBy) > 0
evaluation.GrantedBy = grantedBy
if !evaluation.Granted {
evaluation.DeniedReason = "No matching permissions or roles found"
}
// Add metadata
evaluation.Metadata["user_roles"] = strings.Join(userRoles, ",")
evaluation.Metadata["app_id"] = appID
evaluation.Metadata["evaluation_method"] = "hierarchical"
return evaluation
}
// getUserRoles retrieves user roles (placeholder implementation)
func (pm *PermissionManager) getUserRoles(ctx context.Context, userID, appID string) []string {
// TODO: Implement actual role retrieval from database
// For now, return default roles based on user patterns
if strings.Contains(userID, "admin") {
return []string{"super_admin"}
}
if strings.Contains(userID, "dev") {
return []string{"developer"}
}
return []string{"viewer"}
}
// hasDirectPermission checks if user has direct permission grant
func (pm *PermissionManager) hasDirectPermission(userID, appID, permission string) bool {
// TODO: Implement database lookup for direct permission grants
return false
}
// roleHasPermission checks if a role has a specific permission
func (pm *PermissionManager) roleHasPermission(roleName, permission string) bool {
role, exists := pm.hierarchy.roles[roleName]
if !exists {
return false
}
// Check direct permissions
for _, perm := range role.Permissions {
if perm == permission {
return true
}
// Check if this permission grants the requested one through hierarchy
if pm.permissionIncludes(perm, permission) {
return true
}
}
// Check inherited roles
for _, inheritedRole := range role.Inherits {
if pm.roleHasPermission(inheritedRole, permission) {
return true
}
}
return false
}
// permissionIncludes checks if a permission includes another through hierarchy
func (pm *PermissionManager) permissionIncludes(granted, requested string) bool {
// Check if granted permission is a parent of requested permission
return pm.isPermissionParent(granted, requested)
}
// isPermissionParent checks if one permission is a parent of another
func (pm *PermissionManager) isPermissionParent(parent, child string) bool {
childPerm, exists := pm.hierarchy.permissions[child]
if !exists {
return false
}
// Traverse up the hierarchy
current := childPerm.Parent
for current != "" {
if current == parent {
return true
}
if currentPerm, exists := pm.hierarchy.permissions[current]; exists {
current = currentPerm.Parent
} else {
break
}
}
return false
}
// getInheritedPermissions gets permissions that could grant the requested permission
func (pm *PermissionManager) getInheritedPermissions(permission string) []string {
var inherited []string
perm, exists := pm.hierarchy.permissions[permission]
if !exists {
return inherited
}
// Get all parent permissions
current := perm.Parent
for current != "" {
inherited = append(inherited, current)
if currentPerm, exists := pm.hierarchy.permissions[current]; exists {
current = currentPerm.Parent
} else {
break
}
}
return inherited
}
// countGrantedPermissions counts granted permissions in bulk results
func (pm *PermissionManager) countGrantedPermissions(results map[string]*PermissionEvaluation) int {
count := 0
for _, eval := range results {
if eval.Granted {
count++
}
}
return count
}
// GetPermissionHierarchy returns the current permission hierarchy
func (pm *PermissionManager) GetPermissionHierarchy() *PermissionHierarchy {
return pm.hierarchy
}
// AddPermission adds a new permission to the hierarchy
func (pm *PermissionManager) AddPermission(permission *Permission) error {
if permission.Name == "" {
return errors.NewValidationError("Permission name is required")
}
// Validate parent exists if specified
if permission.Parent != "" {
if _, exists := pm.hierarchy.permissions[permission.Parent]; !exists {
return errors.NewValidationError(fmt.Sprintf("Parent permission '%s' does not exist", permission.Parent))
}
}
pm.hierarchy.permissions[permission.Name] = permission
pm.hierarchy.buildHierarchy()
pm.logger.Info("Permission added to hierarchy",
zap.String("permission", permission.Name),
zap.String("parent", permission.Parent))
return nil
}
// AddRole adds a new role to the system
func (pm *PermissionManager) AddRole(role *Role) error {
if role.Name == "" {
return errors.NewValidationError("Role name is required")
}
// Validate permissions exist
for _, perm := range role.Permissions {
if _, exists := pm.hierarchy.permissions[perm]; !exists {
return errors.NewValidationError(fmt.Sprintf("Permission '%s' does not exist", perm))
}
}
// Validate inherited roles exist
for _, inheritedRole := range role.Inherits {
if _, exists := pm.hierarchy.roles[inheritedRole]; !exists {
return errors.NewValidationError(fmt.Sprintf("Inherited role '%s' does not exist", inheritedRole))
}
}
pm.hierarchy.roles[role.Name] = role
pm.logger.Info("Role added to system",
zap.String("role", role.Name),
zap.Strings("permissions", role.Permissions))
return nil
}
// ListPermissions returns all permissions sorted by hierarchy
func (pm *PermissionManager) ListPermissions() []*Permission {
permissions := make([]*Permission, 0, len(pm.hierarchy.permissions))
for _, perm := range pm.hierarchy.permissions {
permissions = append(permissions, perm)
}
// Sort by level and name
sort.Slice(permissions, func(i, j int) bool {
if permissions[i].Level != permissions[j].Level {
return permissions[i].Level < permissions[j].Level
}
return permissions[i].Name < permissions[j].Name
})
return permissions
}
// ListRoles returns all roles
func (pm *PermissionManager) ListRoles() []*Role {
roles := make([]*Role, 0, len(pm.hierarchy.roles))
for _, role := range pm.hierarchy.roles {
roles = append(roles, role)
}
// Sort by name
sort.Slice(roles, func(i, j int) bool {
return roles[i].Name < roles[j].Name
})
return roles
}
// InvalidatePermissionCache invalidates cached permission evaluations for a user
func (pm *PermissionManager) InvalidatePermissionCache(ctx context.Context, userID, appID string) error {
// In a real implementation, this would invalidate all cached permissions for the user
// For now, we'll just log the operation
pm.logger.Info("Invalidating permission cache",
zap.String("user_id", userID),
zap.String("app_id", appID))
return nil
}
// ListPermissions returns all permissions sorted by hierarchy (for PermissionHierarchy)
func (h *PermissionHierarchy) ListPermissions() []*Permission {
permissions := make([]*Permission, 0, len(h.permissions))
for _, perm := range h.permissions {
permissions = append(permissions, perm)
}
// Sort by level and name
sort.Slice(permissions, func(i, j int) bool {
if permissions[i].Level != permissions[j].Level {
return permissions[i].Level < permissions[j].Level
}
return permissions[i].Name < permissions[j].Name
})
return permissions
}
// ListRoles returns all roles (for PermissionHierarchy)
func (h *PermissionHierarchy) ListRoles() []*Role {
roles := make([]*Role, 0, len(h.roles))
for _, role := range h.roles {
roles = append(roles, role)
}
// Sort by name
sort.Slice(roles, func(i, j int) bool {
return roles[i].Name < roles[j].Name
})
return roles
}

View File

@ -153,8 +153,18 @@ type CacheManager struct {
func NewCacheManager(config config.ConfigProvider, logger *zap.Logger) *CacheManager {
var provider CacheProvider
// For now, we'll use memory cache. In production, this could be Redis
provider = NewMemoryCache(config, logger)
// Use Redis if configured, otherwise fall back to memory cache
if config.GetBool("REDIS_ENABLED") {
redisProvider, err := NewRedisCache(config, logger)
if err != nil {
logger.Warn("Failed to initialize Redis cache, falling back to memory cache", zap.Error(err))
provider = NewMemoryCache(config, logger)
} else {
provider = redisProvider
}
} else {
provider = NewMemoryCache(config, logger)
}
return &CacheManager{
provider: provider,

191
internal/cache/redis.go vendored Normal file
View File

@ -0,0 +1,191 @@
package cache
import (
"context"
"time"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/errors"
)
// RedisCache implements CacheProvider using Redis
type RedisCache struct {
client *redis.Client
config config.ConfigProvider
logger *zap.Logger
}
// NewRedisCache creates a new Redis cache provider
func NewRedisCache(config config.ConfigProvider, logger *zap.Logger) (CacheProvider, error) {
// Redis configuration
redisAddr := config.GetString("REDIS_ADDR")
if redisAddr == "" {
redisAddr = "localhost:6379"
}
redisPassword := config.GetString("REDIS_PASSWORD")
redisDB := config.GetInt("REDIS_DB")
// Create Redis client
client := redis.NewClient(&redis.Options{
Addr: redisAddr,
Password: redisPassword,
DB: redisDB,
PoolSize: config.GetInt("REDIS_POOL_SIZE"),
MinIdleConns: config.GetInt("REDIS_MIN_IDLE_CONNS"),
MaxRetries: config.GetInt("REDIS_MAX_RETRIES"),
DialTimeout: config.GetDuration("REDIS_DIAL_TIMEOUT"),
ReadTimeout: config.GetDuration("REDIS_READ_TIMEOUT"),
WriteTimeout: config.GetDuration("REDIS_WRITE_TIMEOUT"),
})
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
logger.Error("Failed to connect to Redis", zap.Error(err))
return nil, errors.NewInternalError("Failed to connect to Redis").WithInternal(err)
}
logger.Info("Connected to Redis successfully", zap.String("addr", redisAddr))
return &RedisCache{
client: client,
config: config,
logger: logger,
}, nil
}
// Get retrieves a value from Redis cache
func (r *RedisCache) Get(ctx context.Context, key string) ([]byte, error) {
r.logger.Debug("Getting value from Redis cache", zap.String("key", key))
result, err := r.client.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return nil, errors.NewNotFoundError("cache key")
}
r.logger.Error("Failed to get value from Redis", zap.Error(err))
return nil, errors.NewInternalError("Failed to get cached value").WithInternal(err)
}
return []byte(result), nil
}
// Set stores a value in Redis cache with TTL
func (r *RedisCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
r.logger.Debug("Setting value in Redis cache",
zap.String("key", key),
zap.Duration("ttl", ttl))
err := r.client.Set(ctx, key, value, ttl).Err()
if err != nil {
r.logger.Error("Failed to set value in Redis", zap.Error(err))
return errors.NewInternalError("Failed to cache value").WithInternal(err)
}
return nil
}
// Delete removes a value from Redis cache
func (r *RedisCache) Delete(ctx context.Context, key string) error {
r.logger.Debug("Deleting value from Redis cache", zap.String("key", key))
err := r.client.Del(ctx, key).Err()
if err != nil {
r.logger.Error("Failed to delete value from Redis", zap.Error(err))
return errors.NewInternalError("Failed to delete cached value").WithInternal(err)
}
return nil
}
// Exists checks if a key exists in Redis cache
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
count, err := r.client.Exists(ctx, key).Result()
if err != nil {
r.logger.Error("Failed to check key existence in Redis", zap.Error(err))
return false, errors.NewInternalError("Failed to check cache key existence").WithInternal(err)
}
return count > 0, nil
}
// Clear removes all values from Redis cache (use with caution)
func (r *RedisCache) Clear(ctx context.Context) error {
r.logger.Warn("Clearing Redis cache - this will remove ALL cached data")
err := r.client.FlushDB(ctx).Err()
if err != nil {
r.logger.Error("Failed to clear Redis cache", zap.Error(err))
return errors.NewInternalError("Failed to clear cache").WithInternal(err)
}
return nil
}
// Close closes the Redis connection
func (r *RedisCache) Close() error {
r.logger.Info("Closing Redis connection")
return r.client.Close()
}
// SetNX sets a key only if it doesn't exist (Redis-specific operation)
func (r *RedisCache) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
r.logger.Debug("Setting value in Redis cache with NX",
zap.String("key", key),
zap.Duration("ttl", ttl))
result, err := r.client.SetNX(ctx, key, value, ttl).Result()
if err != nil {
r.logger.Error("Failed to set NX value in Redis", zap.Error(err))
return false, errors.NewInternalError("Failed to cache value with NX").WithInternal(err)
}
return result, nil
}
// Expire sets TTL for an existing key
func (r *RedisCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
r.logger.Debug("Setting TTL for Redis key",
zap.String("key", key),
zap.Duration("ttl", ttl))
result, err := r.client.Expire(ctx, key, ttl).Result()
if err != nil {
r.logger.Error("Failed to set TTL in Redis", zap.Error(err))
return errors.NewInternalError("Failed to set key TTL").WithInternal(err)
}
if !result {
return errors.NewNotFoundError("cache key")
}
return nil
}
// TTL returns the remaining time to live for a key
func (r *RedisCache) TTL(ctx context.Context, key string) (time.Duration, error) {
ttl, err := r.client.TTL(ctx, key).Result()
if err != nil {
r.logger.Error("Failed to get TTL from Redis", zap.Error(err))
return 0, errors.NewInternalError("Failed to get key TTL").WithInternal(err)
}
return ttl, nil
}
// Keys returns all keys matching a pattern
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
keys, err := r.client.Keys(ctx, pattern).Result()
if err != nil {
r.logger.Error("Failed to get keys from Redis", zap.Error(err))
return nil, errors.NewInternalError("Failed to get cache keys").WithInternal(err)
}
return keys, nil
}

View File

@ -117,6 +117,21 @@ func (c *Config) setDefaults() {
"INTERNAL_HMAC_KEY": "bootstrap-hmac-key-change-in-production",
"METRICS_ENABLED": "false",
"METRICS_PORT": "9090",
"REDIS_ENABLED": "false",
"REDIS_ADDR": "localhost:6379",
"REDIS_PASSWORD": "",
"REDIS_DB": "0",
"REDIS_POOL_SIZE": "10",
"REDIS_MIN_IDLE_CONNS": "5",
"REDIS_MAX_RETRIES": "3",
"REDIS_DIAL_TIMEOUT": "5s",
"REDIS_READ_TIMEOUT": "3s",
"REDIS_WRITE_TIMEOUT": "3s",
"MAX_AUTH_FAILURES": "5",
"AUTH_FAILURE_WINDOW": "15m",
"IP_BLOCK_DURATION": "1h",
"REQUEST_MAX_AGE": "5m",
"IP_WHITELIST": "",
}
for key, value := range defaults {

View File

@ -202,6 +202,11 @@ func NewAuthenticationError(message string) *AppError {
return New(ErrUnauthorized, message)
}
// NewConfigurationError creates a configuration error
func NewConfigurationError(message string) *AppError {
return New(ErrInternal, message)
}
// ErrorResponse represents the JSON error response format
type ErrorResponse struct {
Error string `json:"error"`

394
internal/handlers/oauth2.go Normal file
View File

@ -0,0 +1,394 @@
package handlers
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/auth"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/errors"
"github.com/kms/api-key-service/internal/services"
)
// OAuth2Handler handles OAuth2/OIDC authentication flows
type OAuth2Handler struct {
config config.ConfigProvider
logger *zap.Logger
oauth2Provider *auth.OAuth2Provider
authService services.AuthenticationService
}
// NewOAuth2Handler creates a new OAuth2 handler
func NewOAuth2Handler(
config config.ConfigProvider,
logger *zap.Logger,
authService services.AuthenticationService,
) *OAuth2Handler {
oauth2Provider := auth.NewOAuth2Provider(config, logger)
return &OAuth2Handler{
config: config,
logger: logger,
oauth2Provider: oauth2Provider,
authService: authService,
}
}
// AuthorizeRequest represents the OAuth2 authorization request
type AuthorizeRequest struct {
RedirectURI string `json:"redirect_uri" validate:"required,url"`
State string `json:"state,omitempty"`
}
// AuthorizeResponse represents the OAuth2 authorization response
type AuthorizeResponse struct {
AuthURL string `json:"auth_url"`
State string `json:"state"`
CodeVerifier string `json:"code_verifier"` // In production, this should be stored securely
}
// CallbackRequest represents the OAuth2 callback request
type CallbackRequest struct {
Code string `json:"code" validate:"required"`
State string `json:"state,omitempty"`
RedirectURI string `json:"redirect_uri" validate:"required,url"`
CodeVerifier string `json:"code_verifier" validate:"required"`
}
// CallbackResponse represents the OAuth2 callback response
type CallbackResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
UserInfo *auth.UserInfo `json:"user_info"`
JWTToken string `json:"jwt_token"`
}
// RefreshRequest represents the token refresh request
type RefreshRequest struct {
RefreshToken string `json:"refresh_token" validate:"required"`
}
// RefreshResponse represents the token refresh response
type RefreshResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
JWTToken string `json:"jwt_token"`
}
// RegisterRoutes registers OAuth2 routes
func (h *OAuth2Handler) RegisterRoutes(router *mux.Router) {
oauth2Router := router.PathPrefix("/oauth2").Subrouter()
oauth2Router.HandleFunc("/authorize", h.Authorize).Methods("POST")
oauth2Router.HandleFunc("/callback", h.Callback).Methods("POST")
oauth2Router.HandleFunc("/refresh", h.Refresh).Methods("POST")
oauth2Router.HandleFunc("/userinfo", h.GetUserInfo).Methods("GET")
}
// Authorize initiates the OAuth2 authorization flow
func (h *OAuth2Handler) Authorize(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
h.logger.Debug("Processing OAuth2 authorization request")
var req AuthorizeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
h.logger.Warn("Invalid authorization request", zap.Error(err))
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
// Generate state if not provided
if req.State == "" {
state, err := h.generateState()
if err != nil {
h.logger.Error("Failed to generate state", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
req.State = state
}
// Generate authorization URL
authURL, err := h.oauth2Provider.GenerateAuthURL(ctx, req.State, req.RedirectURI)
if err != nil {
h.logger.Error("Failed to generate authorization URL", zap.Error(err))
if appErr, ok := err.(*errors.AppError); ok {
http.Error(w, appErr.Message, appErr.StatusCode)
return
}
http.Error(w, "Failed to generate authorization URL", http.StatusInternalServerError)
return
}
// In production, store the code verifier securely (e.g., in session or cache)
// For now, we'll return it in the response
codeVerifier, err := h.generateCodeVerifier()
if err != nil {
h.logger.Error("Failed to generate code verifier", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
response := AuthorizeResponse{
AuthURL: authURL,
State: req.State,
CodeVerifier: codeVerifier,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
h.logger.Error("Failed to encode authorization response", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
h.logger.Debug("Authorization URL generated successfully",
zap.String("state", req.State),
zap.String("redirect_uri", req.RedirectURI))
}
// Callback handles the OAuth2 callback and exchanges code for tokens
func (h *OAuth2Handler) Callback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
h.logger.Debug("Processing OAuth2 callback")
var req CallbackRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
h.logger.Warn("Invalid callback request", zap.Error(err))
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
// Exchange authorization code for tokens
tokenResp, err := h.oauth2Provider.ExchangeCodeForToken(ctx, req.Code, req.RedirectURI, req.CodeVerifier)
if err != nil {
h.logger.Error("Failed to exchange code for token", zap.Error(err))
if appErr, ok := err.(*errors.AppError); ok {
http.Error(w, appErr.Message, appErr.StatusCode)
return
}
http.Error(w, "Failed to exchange authorization code", http.StatusInternalServerError)
return
}
// Get user information
userInfo, err := h.oauth2Provider.GetUserInfo(ctx, tokenResp.AccessToken)
if err != nil {
h.logger.Error("Failed to get user info", zap.Error(err))
if appErr, ok := err.(*errors.AppError); ok {
http.Error(w, appErr.Message, appErr.StatusCode)
return
}
http.Error(w, "Failed to get user information", http.StatusInternalServerError)
return
}
// Generate internal JWT token for the user
jwtToken, err := h.generateInternalJWTToken(ctx, userInfo)
if err != nil {
h.logger.Error("Failed to generate internal JWT token", zap.Error(err))
http.Error(w, "Failed to generate authentication token", http.StatusInternalServerError)
return
}
response := CallbackResponse{
AccessToken: tokenResp.AccessToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
RefreshToken: tokenResp.RefreshToken,
UserInfo: userInfo,
JWTToken: jwtToken,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
h.logger.Error("Failed to encode callback response", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
h.logger.Info("OAuth2 callback processed successfully",
zap.String("user_id", userInfo.Sub),
zap.String("email", userInfo.Email))
}
// Refresh refreshes an access token using refresh token
func (h *OAuth2Handler) Refresh(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
h.logger.Debug("Processing token refresh request")
var req RefreshRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
h.logger.Warn("Invalid refresh request", zap.Error(err))
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
// Refresh the access token
tokenResp, err := h.oauth2Provider.RefreshAccessToken(ctx, req.RefreshToken)
if err != nil {
h.logger.Error("Failed to refresh access token", zap.Error(err))
if appErr, ok := err.(*errors.AppError); ok {
http.Error(w, appErr.Message, appErr.StatusCode)
return
}
http.Error(w, "Failed to refresh access token", http.StatusInternalServerError)
return
}
// Get updated user information
userInfo, err := h.oauth2Provider.GetUserInfo(ctx, tokenResp.AccessToken)
if err != nil {
h.logger.Error("Failed to get user info during refresh", zap.Error(err))
if appErr, ok := err.(*errors.AppError); ok {
http.Error(w, appErr.Message, appErr.StatusCode)
return
}
http.Error(w, "Failed to get user information", http.StatusInternalServerError)
return
}
// Generate new internal JWT token
jwtToken, err := h.generateInternalJWTToken(ctx, userInfo)
if err != nil {
h.logger.Error("Failed to generate internal JWT token during refresh", zap.Error(err))
http.Error(w, "Failed to generate authentication token", http.StatusInternalServerError)
return
}
response := RefreshResponse{
AccessToken: tokenResp.AccessToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
RefreshToken: tokenResp.RefreshToken,
JWTToken: jwtToken,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
h.logger.Error("Failed to encode refresh response", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
h.logger.Debug("Token refresh completed successfully",
zap.String("user_id", userInfo.Sub))
}
// GetUserInfo retrieves user information from the current session
func (h *OAuth2Handler) GetUserInfo(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
h.logger.Debug("Processing user info request")
// Extract JWT token from Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, "Authorization header required", http.StatusUnauthorized)
return
}
// Remove "Bearer " prefix
tokenString := authHeader
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
tokenString = authHeader[7:]
}
// Validate JWT token
authContext, err := h.authService.ValidateJWTToken(ctx, tokenString)
if err != nil {
h.logger.Warn("Invalid JWT token in user info request", zap.Error(err))
http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
return
}
// Return user information from JWT claims
userInfo := map[string]interface{}{
"sub": authContext.UserID,
"email": authContext.Claims["email"],
"name": authContext.Claims["name"],
"permissions": authContext.Permissions,
"app_id": authContext.AppID,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(userInfo); err != nil {
h.logger.Error("Failed to encode user info response", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
h.logger.Debug("User info request completed successfully",
zap.String("user_id", authContext.UserID))
}
// generateState generates a random state parameter for OAuth2
func (h *OAuth2Handler) generateState() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(bytes), nil
}
// generateCodeVerifier generates a PKCE code verifier
func (h *OAuth2Handler) generateCodeVerifier() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(bytes), nil
}
// generateInternalJWTToken generates an internal JWT token for authenticated users
func (h *OAuth2Handler) generateInternalJWTToken(ctx context.Context, userInfo *auth.UserInfo) (string, error) {
// Create user token with information from OAuth2 provider
userToken := &domain.UserToken{
AppID: h.config.GetString("INTERNAL_APP_ID"),
UserID: userInfo.Sub,
Permissions: []string{"read", "write"}, // Default permissions, should be based on user roles
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(24 * time.Hour), // 24 hour expiration
MaxValidAt: time.Now().Add(7 * 24 * time.Hour), // 7 days max validity
TokenType: domain.TokenTypeUser,
Claims: map[string]string{
"sub": userInfo.Sub,
"email": userInfo.Email,
"name": userInfo.Name,
"email_verified": func() string {
if userInfo.EmailVerified {
return "true"
}
return "false"
}(),
},
}
// Generate JWT token using authentication service
return h.authService.GenerateJWTToken(ctx, userToken)
}

View File

@ -0,0 +1,423 @@
package middleware
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"go.uber.org/zap"
"golang.org/x/time/rate"
"github.com/kms/api-key-service/internal/cache"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/errors"
)
// SecurityMiddleware provides various security features
type SecurityMiddleware struct {
config config.ConfigProvider
logger *zap.Logger
cacheManager *cache.CacheManager
rateLimiters map[string]*rate.Limiter
mu sync.RWMutex
}
// NewSecurityMiddleware creates a new security middleware
func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger) *SecurityMiddleware {
cacheManager := cache.NewCacheManager(config, logger)
return &SecurityMiddleware{
config: config,
logger: logger,
cacheManager: cacheManager,
rateLimiters: make(map[string]*rate.Limiter),
}
}
// RateLimitMiddleware implements per-IP rate limiting
func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
next.ServeHTTP(w, r)
return
}
// Get client IP
clientIP := s.getClientIP(r)
// Get or create rate limiter for this IP
limiter := s.getRateLimiter(clientIP)
// Check if request is allowed
if !limiter.Allow() {
s.logger.Warn("Rate limit exceeded",
zap.String("client_ip", clientIP),
zap.String("path", r.URL.Path))
// Track rate limit violations
s.trackRateLimitViolation(clientIP)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error":"rate_limit_exceeded","message":"Too many requests"}`))
return
}
next.ServeHTTP(w, r)
})
}
// BruteForceProtectionMiddleware implements brute force protection
func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := s.getClientIP(r)
// Check if IP is temporarily blocked
if s.isIPBlocked(clientIP) {
s.logger.Warn("Blocked IP attempted access",
zap.String("client_ip", clientIP),
zap.String("path", r.URL.Path))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error":"ip_blocked","message":"IP temporarily blocked due to suspicious activity"}`))
return
}
next.ServeHTTP(w, r)
})
}
// IPWhitelistMiddleware implements IP whitelisting
func (s *SecurityMiddleware) IPWhitelistMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
whitelist := s.config.GetStringSlice("IP_WHITELIST")
if len(whitelist) == 0 {
// No whitelist configured, allow all
next.ServeHTTP(w, r)
return
}
clientIP := s.getClientIP(r)
// Check if IP is in whitelist
if !s.isIPInList(clientIP, whitelist) {
s.logger.Warn("Non-whitelisted IP attempted access",
zap.String("client_ip", clientIP),
zap.String("path", r.URL.Path))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error":"ip_not_whitelisted","message":"IP not in whitelist"}`))
return
}
next.ServeHTTP(w, r)
})
}
// SecurityHeadersMiddleware adds security headers
func (s *SecurityMiddleware) SecurityHeadersMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add security headers
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
w.Header().Set("Content-Security-Policy", "default-src 'self'")
// Add HSTS header for HTTPS
if r.TLS != nil {
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
}
next.ServeHTTP(w, r)
})
}
// AuthenticationFailureTracker tracks authentication failures for brute force protection
func (s *SecurityMiddleware) TrackAuthenticationFailure(clientIP, userID string) {
ctx := context.Background()
// Track failures by IP
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
s.incrementFailureCount(ctx, ipKey)
// Track failures by user ID if provided
if userID != "" {
userKey := cache.CacheKey("auth_failures_user", userID)
s.incrementFailureCount(ctx, userKey)
}
// Check if we should block the IP
s.checkAndBlockIP(clientIP)
}
// ClearAuthenticationFailures clears failure count on successful authentication
func (s *SecurityMiddleware) ClearAuthenticationFailures(clientIP, userID string) {
ctx := context.Background()
// Clear failures by IP
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
s.cacheManager.Delete(ctx, ipKey)
// Clear failures by user ID if provided
if userID != "" {
userKey := cache.CacheKey("auth_failures_user", userID)
s.cacheManager.Delete(ctx, userKey)
}
}
// Helper methods
func (s *SecurityMiddleware) getClientIP(r *http.Request) string {
// Check X-Forwarded-For header first
xff := r.Header.Get("X-Forwarded-For")
if xff != "" {
// Take the first IP in the chain
ips := strings.Split(xff, ",")
return strings.TrimSpace(ips[0])
}
// Check X-Real-IP header
xri := r.Header.Get("X-Real-IP")
if xri != "" {
return xri
}
// Fall back to RemoteAddr
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return ip
}
func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter {
s.mu.RLock()
limiter, exists := s.rateLimiters[clientIP]
s.mu.RUnlock()
if exists {
return limiter
}
// Create new rate limiter
rps := s.config.GetInt("RATE_LIMIT_RPS")
if rps <= 0 {
rps = 100 // Default
}
burst := s.config.GetInt("RATE_LIMIT_BURST")
if burst <= 0 {
burst = 200 // Default
}
limiter = rate.NewLimiter(rate.Limit(rps), burst)
s.mu.Lock()
s.rateLimiters[clientIP] = limiter
s.mu.Unlock()
return limiter
}
func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) {
ctx := context.Background()
key := cache.CacheKey("rate_limit_violations", clientIP)
s.incrementFailureCount(ctx, key)
}
func (s *SecurityMiddleware) isIPBlocked(clientIP string) bool {
ctx := context.Background()
key := cache.CacheKey("blocked_ips", clientIP)
exists, err := s.cacheManager.Exists(ctx, key)
if err != nil {
s.logger.Error("Failed to check IP block status",
zap.String("client_ip", clientIP),
zap.Error(err))
return false
}
return exists
}
func (s *SecurityMiddleware) isIPInList(clientIP string, ipList []string) bool {
for _, allowedIP := range ipList {
allowedIP = strings.TrimSpace(allowedIP)
// Support CIDR notation
if strings.Contains(allowedIP, "/") {
_, network, err := net.ParseCIDR(allowedIP)
if err != nil {
s.logger.Warn("Invalid CIDR in IP list", zap.String("cidr", allowedIP))
continue
}
ip := net.ParseIP(clientIP)
if ip != nil && network.Contains(ip) {
return true
}
} else {
// Exact IP match
if clientIP == allowedIP {
return true
}
}
}
return false
}
func (s *SecurityMiddleware) incrementFailureCount(ctx context.Context, key string) {
// Get current count
var count int
err := s.cacheManager.GetJSON(ctx, key, &count)
if err != nil {
// Key doesn't exist, start with 0
count = 0
}
count++
// Store updated count with TTL
ttl := s.config.GetDuration("AUTH_FAILURE_WINDOW")
if ttl <= 0 {
ttl = 15 * time.Minute // Default window
}
s.cacheManager.SetJSON(ctx, key, count, ttl)
}
func (s *SecurityMiddleware) checkAndBlockIP(clientIP string) {
ctx := context.Background()
key := cache.CacheKey("auth_failures_ip", clientIP)
var count int
err := s.cacheManager.GetJSON(ctx, key, &count)
if err != nil {
return // No failures recorded
}
maxFailures := s.config.GetInt("MAX_AUTH_FAILURES")
if maxFailures <= 0 {
maxFailures = 5 // Default
}
if count >= maxFailures {
// Block the IP
blockKey := cache.CacheKey("blocked_ips", clientIP)
blockDuration := s.config.GetDuration("IP_BLOCK_DURATION")
if blockDuration <= 0 {
blockDuration = 1 * time.Hour // Default
}
blockInfo := map[string]interface{}{
"blocked_at": time.Now().Unix(),
"failure_count": count,
"reason": "excessive_auth_failures",
}
s.cacheManager.SetJSON(ctx, blockKey, blockInfo, blockDuration)
s.logger.Warn("IP blocked due to excessive authentication failures",
zap.String("client_ip", clientIP),
zap.Int("failure_count", count),
zap.Duration("block_duration", blockDuration))
}
}
// RequestSignatureMiddleware validates request signatures (for API key requests)
func (s *SecurityMiddleware) RequestSignatureMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Only validate signatures for certain endpoints
if !s.shouldValidateSignature(r) {
next.ServeHTTP(w, r)
return
}
signature := r.Header.Get("X-Signature")
timestamp := r.Header.Get("X-Timestamp")
if signature == "" || timestamp == "" {
s.logger.Warn("Missing signature headers",
zap.String("path", r.URL.Path),
zap.String("client_ip", s.getClientIP(r)))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"missing_signature","message":"Request signature required"}`))
return
}
// Validate timestamp (prevent replay attacks)
if !s.isTimestampValid(timestamp) {
s.logger.Warn("Invalid timestamp in request",
zap.String("timestamp", timestamp),
zap.String("client_ip", s.getClientIP(r)))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"invalid_timestamp","message":"Request timestamp is invalid or too old"}`))
return
}
// TODO: Implement actual signature validation
// This would involve validating the HMAC signature using the client's secret
next.ServeHTTP(w, r)
})
}
func (s *SecurityMiddleware) shouldValidateSignature(r *http.Request) bool {
// Define which endpoints require signature validation
signatureRequiredPaths := []string{
"/api/v1/tokens",
"/api/v1/applications",
}
for _, path := range signatureRequiredPaths {
if strings.HasPrefix(r.URL.Path, path) {
return true
}
}
return false
}
func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool {
// Parse timestamp
timestamp, err := time.Parse(time.RFC3339, timestampStr)
if err != nil {
return false
}
// Check if timestamp is within acceptable window
now := time.Now()
maxAge := s.config.GetDuration("REQUEST_MAX_AGE")
if maxAge <= 0 {
maxAge = 5 * time.Minute // Default
}
return now.Sub(timestamp) <= maxAge && timestamp.Before(now.Add(1*time.Minute))
}
// GetSecurityMetrics returns security-related metrics
func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} {
ctx := context.Background()
// This is a simplified version - in production you'd want more comprehensive metrics
metrics := map[string]interface{}{
"active_rate_limiters": len(s.rateLimiters),
"timestamp": time.Now().Unix(),
}
// Count blocked IPs (this is expensive, so you might want to cache this)
// For now, we'll just return the basic metrics
return metrics
}

View File

@ -14,231 +14,7 @@ import (
"github.com/kms/api-key-service/internal/services"
)
// MockConfig implements ConfigProvider for testing
type MockConfig struct {
values map[string]string
}
func NewMockConfig() *MockConfig {
return &MockConfig{
values: map[string]string{
"JWT_SECRET": "test-jwt-secret-for-testing-only",
},
}
}
func (m *MockConfig) GetString(key string) string {
return m.values[key]
}
func (m *MockConfig) GetInt(key string) int { return 0 }
func (m *MockConfig) GetBool(key string) bool { return false }
func (m *MockConfig) GetDuration(key string) time.Duration { return 0 }
func (m *MockConfig) GetStringSlice(key string) []string { return nil }
func (m *MockConfig) IsSet(key string) bool { return m.values[key] != "" }
func (m *MockConfig) Validate() error { return nil }
func (m *MockConfig) GetDatabaseDSN() string { return "" }
func (m *MockConfig) GetServerAddress() string { return "" }
func (m *MockConfig) GetMetricsAddress() string { return "" }
func (m *MockConfig) GetJWTSecret() string { return m.GetString("JWT_SECRET") }
func (m *MockConfig) IsDevelopment() bool { return true }
func (m *MockConfig) IsProduction() bool { return false }
func TestJWTManager_GenerateToken(t *testing.T) {
config := NewMockConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(config, logger)
userToken := &domain.UserToken{
AppID: "test-app",
UserID: "test-user",
Permissions: []string{"read", "write"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
Claims: map[string]string{
"email": "test@example.com",
"name": "Test User",
},
}
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
assert.NotEmpty(t, tokenString)
// Verify the token can be validated
claims, err := jwtManager.ValidateToken(tokenString)
require.NoError(t, err)
assert.Equal(t, userToken.UserID, claims.UserID)
assert.Equal(t, userToken.AppID, claims.AppID)
assert.Equal(t, userToken.Permissions, claims.Permissions)
assert.Equal(t, userToken.TokenType, claims.TokenType)
assert.Equal(t, userToken.Claims, claims.Claims)
}
func TestJWTManager_ValidateToken(t *testing.T) {
config := NewMockConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(config, logger)
userToken := &domain.UserToken{
AppID: "test-app",
UserID: "test-user",
Permissions: []string{"read"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Test valid token
claims, err := jwtManager.ValidateToken(tokenString)
require.NoError(t, err)
assert.Equal(t, userToken.UserID, claims.UserID)
assert.Equal(t, userToken.AppID, claims.AppID)
// Test invalid token
_, err = jwtManager.ValidateToken("invalid-token")
assert.Error(t, err)
// Test empty token
_, err = jwtManager.ValidateToken("")
assert.Error(t, err)
}
func TestJWTManager_ExpiredToken(t *testing.T) {
config := NewMockConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(config, logger)
// Create an expired token
userToken := &domain.UserToken{
AppID: "test-app",
UserID: "test-user",
Permissions: []string{"read"},
IssuedAt: time.Now().Add(-2 * time.Hour),
ExpiresAt: time.Now().Add(-time.Hour), // Expired 1 hour ago
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Validation should fail for expired token
_, err = jwtManager.ValidateToken(tokenString)
assert.Error(t, err)
}
func TestJWTManager_MaxValidAtExpired(t *testing.T) {
config := NewMockConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(config, logger)
// Create a token that's past max valid time
userToken := &domain.UserToken{
AppID: "test-app",
UserID: "test-user",
Permissions: []string{"read"},
IssuedAt: time.Now().Add(-2 * time.Hour),
ExpiresAt: time.Now().Add(time.Hour),
MaxValidAt: time.Now().Add(-time.Hour), // Max valid time expired
TokenType: domain.TokenTypeUser,
}
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Validation should fail for token past max valid time
_, err = jwtManager.ValidateToken(tokenString)
assert.Error(t, err)
}
func TestJWTManager_RefreshToken(t *testing.T) {
config := NewMockConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(config, logger)
userToken := &domain.UserToken{
AppID: "test-app",
UserID: "test-user",
Permissions: []string{"read"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
originalToken, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Refresh the token
newExpiration := time.Now().Add(2 * time.Hour)
refreshedToken, err := jwtManager.RefreshToken(originalToken, newExpiration)
require.NoError(t, err)
assert.NotEmpty(t, refreshedToken)
assert.NotEqual(t, originalToken, refreshedToken)
// Validate the refreshed token
claims, err := jwtManager.ValidateToken(refreshedToken)
require.NoError(t, err)
assert.Equal(t, userToken.UserID, claims.UserID)
assert.Equal(t, userToken.AppID, claims.AppID)
}
func TestJWTManager_ExtractClaims(t *testing.T) {
config := NewMockConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(config, logger)
userToken := &domain.UserToken{
AppID: "test-app",
UserID: "test-user",
Permissions: []string{"read"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(-time.Hour), // Expired token
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Extract claims from expired token (should work)
claims, err := jwtManager.ExtractClaims(tokenString)
require.NoError(t, err)
assert.Equal(t, userToken.UserID, claims.UserID)
assert.Equal(t, userToken.AppID, claims.AppID)
}
func TestJWTManager_GetTokenInfo(t *testing.T) {
config := NewMockConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(config, logger)
userToken := &domain.UserToken{
AppID: "test-app",
UserID: "test-user",
Permissions: []string{"read"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
info := jwtManager.GetTokenInfo(tokenString)
assert.Equal(t, userToken.UserID, info["user_id"])
assert.Equal(t, userToken.AppID, info["app_id"])
assert.Equal(t, userToken.Permissions, info["permissions"])
assert.Equal(t, userToken.TokenType, info["token_type"])
}
func TestAuthenticationService_ValidateJWTToken(t *testing.T) {
config := NewMockConfig()
@ -330,7 +106,8 @@ func TestAuthenticationService_RefreshJWTToken(t *testing.T) {
func TestJWTManager_InvalidSecret(t *testing.T) {
// Test with empty JWT secret
config := &MockConfig{values: map[string]string{"JWT_SECRET": ""}}
config := NewTestConfig()
config.values["JWT_SECRET"] = ""
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(config, logger)

View File

@ -12,48 +12,6 @@ import (
"github.com/kms/api-key-service/internal/cache"
)
// MockConfig implements ConfigProvider for testing
type MockConfig struct {
values map[string]string
}
func NewMockConfig() *MockConfig {
return &MockConfig{
values: map[string]string{
"CACHE_ENABLED": "true",
"CACHE_TTL": "1h",
},
}
}
func (m *MockConfig) GetString(key string) string {
return m.values[key]
}
func (m *MockConfig) GetInt(key string) int { return 0 }
func (m *MockConfig) GetBool(key string) bool {
if key == "CACHE_ENABLED" {
return m.values[key] == "true"
}
return false
}
func (m *MockConfig) GetDuration(key string) time.Duration {
if key == "CACHE_TTL" {
if d, err := time.ParseDuration(m.values[key]); err == nil {
return d
}
}
return 0
}
func (m *MockConfig) GetStringSlice(key string) []string { return nil }
func (m *MockConfig) IsSet(key string) bool { return m.values[key] != "" }
func (m *MockConfig) Validate() error { return nil }
func (m *MockConfig) GetDatabaseDSN() string { return "" }
func (m *MockConfig) GetServerAddress() string { return "" }
func (m *MockConfig) GetMetricsAddress() string { return "" }
func (m *MockConfig) GetJWTSecret() string { return m.GetString("JWT_SECRET") }
func (m *MockConfig) IsDevelopment() bool { return true }
func (m *MockConfig) IsProduction() bool { return false }
func TestMemoryCache_SetAndGet(t *testing.T) {
config := NewMockConfig()
@ -315,12 +273,9 @@ func TestCacheKeyPrefixes(t *testing.T) {
func TestCacheManager_ConfigMethods(t *testing.T) {
// Create mock config with cache settings
config := &MockConfig{
values: map[string]string{
"CACHE_ENABLED": "true",
"CACHE_TTL": "1h",
},
}
config := NewMockConfig()
config.values["CACHE_ENABLED"] = "true"
config.values["CACHE_TTL"] = "1h"
logger := zap.NewNop()
cacheManager := cache.NewCacheManager(config, logger)
defer cacheManager.Close()

382
test/jwt_test.go Normal file
View File

@ -0,0 +1,382 @@
package test
import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/auth"
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/domain"
)
func TestJWTManager_GenerateToken(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read", "write"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(1 * time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
Claims: map[string]string{
"department": "engineering",
"role": "developer",
},
}
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
assert.NotEmpty(t, tokenString)
// Verify token structure (should have 3 parts separated by dots)
parts := len(tokenString)
assert.Greater(t, parts, 100) // JWT tokens are typically longer than 100 chars
}
func TestJWTManager_ValidateToken(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read", "write"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(1 * time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
Claims: map[string]string{
"department": "engineering",
},
}
// Generate token
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Validate token
claims, err := jwtManager.ValidateToken(tokenString)
require.NoError(t, err)
assert.Equal(t, userToken.UserID, claims.UserID)
assert.Equal(t, userToken.AppID, claims.AppID)
assert.Equal(t, userToken.Permissions, claims.Permissions)
assert.Equal(t, userToken.TokenType, claims.TokenType)
assert.Equal(t, userToken.Claims, claims.Claims)
}
func TestJWTManager_ValidateToken_InvalidToken(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
// Test with invalid token
_, err := jwtManager.ValidateToken("invalid.token.here")
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid token")
}
func TestJWTManager_ValidateToken_ExpiredToken(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read"},
IssuedAt: time.Now().Add(-2 * time.Hour),
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired 1 hour ago
MaxValidAt: time.Now().Add(-30 * time.Minute), // Max valid also expired
TokenType: domain.TokenTypeUser,
}
// Generate token (this should work even with past dates)
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Validate token (this should fail due to expiration)
_, err = jwtManager.ValidateToken(tokenString)
assert.Error(t, err)
// The error could be either JWT expiration or our custom max valid check
assert.True(t,
strings.Contains(err.Error(), "expired beyond maximum validity") ||
strings.Contains(err.Error(), "token is expired"),
"Expected expiration error, got: %s", err.Error())
}
func TestJWTManager_RefreshToken(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read", "write"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(1 * time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
// Generate original token
originalToken, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Refresh token with new expiration
newExpiration := time.Now().Add(2 * time.Hour)
refreshedToken, err := jwtManager.RefreshToken(originalToken, newExpiration)
require.NoError(t, err)
assert.NotEmpty(t, refreshedToken)
assert.NotEqual(t, originalToken, refreshedToken)
// Validate refreshed token
claims, err := jwtManager.ValidateToken(refreshedToken)
require.NoError(t, err)
assert.Equal(t, userToken.UserID, claims.UserID)
assert.Equal(t, userToken.AppID, claims.AppID)
}
func TestJWTManager_RefreshToken_ExpiredMaxValid(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read"},
IssuedAt: time.Now().Add(-2 * time.Hour),
ExpiresAt: time.Now().Add(1 * time.Hour),
MaxValidAt: time.Now().Add(-30 * time.Minute), // Max valid expired
TokenType: domain.TokenTypeUser,
}
// Generate token
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Try to refresh (should fail due to max valid expiration)
newExpiration := time.Now().Add(2 * time.Hour)
_, err = jwtManager.RefreshToken(tokenString, newExpiration)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired beyond maximum validity")
}
func TestJWTManager_ExtractClaims(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read", "write"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired token
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
// Generate expired token
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Extract claims (should work even for expired tokens)
claims, err := jwtManager.ExtractClaims(tokenString)
require.NoError(t, err)
assert.Equal(t, userToken.UserID, claims.UserID)
assert.Equal(t, userToken.AppID, claims.AppID)
assert.Equal(t, userToken.Permissions, claims.Permissions)
}
func TestJWTManager_RevokeToken(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(1 * time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
// Generate token
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Revoke token
err = jwtManager.RevokeToken(tokenString)
assert.NoError(t, err)
// Check if token is revoked
revoked, err := jwtManager.IsTokenRevoked(tokenString)
assert.NoError(t, err)
assert.True(t, revoked)
}
func TestJWTManager_RevokeToken_AlreadyExpired(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read"},
IssuedAt: time.Now().Add(-2 * time.Hour),
ExpiresAt: time.Now().Add(-1 * time.Hour), // Already expired
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
// Generate expired token
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Revoke expired token (should succeed but not add to blacklist)
err = jwtManager.RevokeToken(tokenString)
assert.NoError(t, err)
// Check if token is revoked (should be false since it was already expired)
revoked, err := jwtManager.IsTokenRevoked(tokenString)
assert.NoError(t, err)
assert.False(t, revoked)
}
func TestJWTManager_IsTokenRevoked_NotRevoked(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(1 * time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
// Generate token
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Check if token is revoked (should be false)
revoked, err := jwtManager.IsTokenRevoked(tokenString)
assert.NoError(t, err)
assert.False(t, revoked)
}
func TestJWTManager_GetTokenInfo(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read", "write"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(1 * time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
Claims: map[string]string{
"department": "engineering",
},
}
// Generate token
tokenString, err := jwtManager.GenerateToken(userToken)
require.NoError(t, err)
// Get token info
info := jwtManager.GetTokenInfo(tokenString)
assert.Equal(t, userToken.UserID, info["user_id"])
assert.Equal(t, userToken.AppID, info["app_id"])
assert.Equal(t, userToken.Permissions, info["permissions"])
assert.Equal(t, userToken.TokenType, info["token_type"])
assert.NotNil(t, info["issued_at"])
assert.NotNil(t, info["expires_at"])
assert.NotNil(t, info["max_valid_at"])
assert.NotNil(t, info["jti"])
}
func TestJWTManager_GetTokenInfo_InvalidToken(t *testing.T) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
// Get info for invalid token
info := jwtManager.GetTokenInfo("invalid.token.here")
assert.Contains(t, info["error"], "Invalid token format")
}
// Benchmark tests
func BenchmarkJWTManager_GenerateToken(b *testing.B) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read", "write"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(1 * time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := jwtManager.GenerateToken(userToken)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkJWTManager_ValidateToken(b *testing.B) {
cfg := config.NewConfig()
logger := zap.NewNop()
jwtManager := auth.NewJWTManager(cfg, logger)
userToken := &domain.UserToken{
UserID: "test-user-123",
AppID: "test-app-456",
Permissions: []string{"read", "write"},
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(1 * time.Hour),
MaxValidAt: time.Now().Add(24 * time.Hour),
TokenType: domain.TokenTypeUser,
}
tokenString, err := jwtManager.GenerateToken(userToken)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := jwtManager.ValidateToken(tokenString)
if err != nil {
b.Fatal(err)
}
}
}

552
test/oauth2_test.go Normal file
View File

@ -0,0 +1,552 @@
package test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/auth"
)
func TestOAuth2Provider_GetDiscoveryDocument(t *testing.T) {
tests := []struct {
name string
providerURL string
mockResponse string
mockStatusCode int
expectError bool
expectedIssuer string
}{
{
name: "successful discovery",
providerURL: "https://example.com",
mockResponse: `{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"userinfo_endpoint": "https://example.com/userinfo",
"jwks_uri": "https://example.com/jwks"
}`,
mockStatusCode: http.StatusOK,
expectError: false,
expectedIssuer: "https://example.com",
},
{
name: "missing provider URL",
providerURL: "",
expectError: true,
},
{
name: "invalid response status",
providerURL: "https://example.com",
mockResponse: `{"error": "not found"}`,
mockStatusCode: http.StatusNotFound,
expectError: true,
},
{
name: "invalid JSON response",
providerURL: "https://example.com",
mockResponse: `invalid json`,
mockStatusCode: http.StatusOK,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock server if needed
var server *httptest.Server
if tt.providerURL != "" && !tt.expectError {
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/.well-known/openid_configuration", r.URL.Path)
w.WriteHeader(tt.mockStatusCode)
w.Write([]byte(tt.mockResponse))
}))
defer server.Close()
tt.providerURL = server.URL
}
// Create config mock
configMock := NewMockConfig()
configMock.values["SSO_PROVIDER_URL"] = tt.providerURL
logger := zap.NewNop()
provider := auth.NewOAuth2Provider(configMock, logger)
ctx := context.Background()
discovery, err := provider.GetDiscoveryDocument(ctx)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, discovery)
} else {
assert.NoError(t, err)
assert.NotNil(t, discovery)
assert.Equal(t, tt.expectedIssuer, discovery.Issuer)
}
})
}
}
func TestOAuth2Provider_GenerateAuthURL(t *testing.T) {
// Create mock discovery server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := `{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"userinfo_endpoint": "https://example.com/userinfo"
}`
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(response))
}))
defer server.Close()
tests := []struct {
name string
clientID string
state string
redirectURI string
expectError bool
}{
{
name: "successful URL generation",
clientID: "test-client-id",
state: "test-state",
redirectURI: "https://app.example.com/callback",
expectError: false,
},
{
name: "missing client ID",
clientID: "",
state: "test-state",
redirectURI: "https://app.example.com/callback",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configMock := NewMockConfig()
configMock.values["SSO_PROVIDER_URL"] = server.URL
configMock.values["SSO_CLIENT_ID"] = tt.clientID
logger := zap.NewNop()
provider := auth.NewOAuth2Provider(configMock, logger)
ctx := context.Background()
authURL, err := provider.GenerateAuthURL(ctx, tt.state, tt.redirectURI)
if tt.expectError {
assert.Error(t, err)
assert.Empty(t, authURL)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, authURL)
assert.Contains(t, authURL, "https://example.com/auth")
assert.Contains(t, authURL, "client_id="+tt.clientID)
assert.Contains(t, authURL, "state="+tt.state)
assert.Contains(t, authURL, "redirect_uri=")
}
})
}
}
func TestOAuth2Provider_ExchangeCodeForToken(t *testing.T) {
tests := []struct {
name string
code string
redirectURI string
codeVerifier string
clientID string
clientSecret string
mockResponse string
mockStatusCode int
expectError bool
expectedToken string
}{
{
name: "successful token exchange",
code: "test-code",
redirectURI: "https://app.example.com/callback",
codeVerifier: "test-verifier",
clientID: "test-client-id",
clientSecret: "test-client-secret",
mockResponse: `{
"access_token": "test-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "test-refresh-token"
}`,
mockStatusCode: http.StatusOK,
expectError: false,
expectedToken: "test-access-token",
},
{
name: "missing client ID",
code: "test-code",
redirectURI: "https://app.example.com/callback",
codeVerifier: "test-verifier",
clientID: "",
clientSecret: "test-client-secret",
expectError: true,
},
{
name: "token endpoint error",
code: "test-code",
redirectURI: "https://app.example.com/callback",
codeVerifier: "test-verifier",
clientID: "test-client-id",
clientSecret: "test-client-secret",
mockResponse: `{"error": "invalid_grant"}`,
mockStatusCode: http.StatusBadRequest,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock servers
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := `{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"userinfo_endpoint": "https://example.com/userinfo"
}`
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(response))
}))
defer discoveryServer.Close()
var tokenServer *httptest.Server
if !tt.expectError {
tokenServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
w.WriteHeader(tt.mockStatusCode)
w.Write([]byte(tt.mockResponse))
}))
defer tokenServer.Close()
// Update discovery server to return the token server URL
discoveryServer.Close()
discoveryServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := `{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "` + tokenServer.URL + `",
"userinfo_endpoint": "https://example.com/userinfo"
}`
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(response))
}))
}
configMock := NewMockConfig()
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
configMock.values["SSO_CLIENT_ID"] = tt.clientID
configMock.values["SSO_CLIENT_SECRET"] = tt.clientSecret
logger := zap.NewNop()
provider := auth.NewOAuth2Provider(configMock, logger)
ctx := context.Background()
tokenResp, err := provider.ExchangeCodeForToken(ctx, tt.code, tt.redirectURI, tt.codeVerifier)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, tokenResp)
} else {
assert.NoError(t, err)
assert.NotNil(t, tokenResp)
assert.Equal(t, tt.expectedToken, tokenResp.AccessToken)
assert.Equal(t, "Bearer", tokenResp.TokenType)
}
})
}
}
func TestOAuth2Provider_GetUserInfo(t *testing.T) {
tests := []struct {
name string
accessToken string
mockResponse string
mockStatusCode int
expectError bool
expectedSub string
expectedEmail string
}{
{
name: "successful user info retrieval",
accessToken: "test-access-token",
mockResponse: `{
"sub": "user123",
"email": "user@example.com",
"name": "Test User",
"email_verified": true
}`,
mockStatusCode: http.StatusOK,
expectError: false,
expectedSub: "user123",
expectedEmail: "user@example.com",
},
{
name: "unauthorized access token",
accessToken: "invalid-token",
mockResponse: `{"error": "invalid_token"}`,
mockStatusCode: http.StatusUnauthorized,
expectError: true,
},
{
name: "invalid JSON response",
accessToken: "test-access-token",
mockResponse: `invalid json`,
mockStatusCode: http.StatusOK,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock servers
userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
assert.Equal(t, "Bearer "+tt.accessToken, r.Header.Get("Authorization"))
w.WriteHeader(tt.mockStatusCode)
w.Write([]byte(tt.mockResponse))
}))
defer userInfoServer.Close()
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := `{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"userinfo_endpoint": "` + userInfoServer.URL + `"
}`
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(response))
}))
defer discoveryServer.Close()
configMock := NewMockConfig()
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
logger := zap.NewNop()
provider := auth.NewOAuth2Provider(configMock, logger)
ctx := context.Background()
userInfo, err := provider.GetUserInfo(ctx, tt.accessToken)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, userInfo)
} else {
assert.NoError(t, err)
assert.NotNil(t, userInfo)
assert.Equal(t, tt.expectedSub, userInfo.Sub)
assert.Equal(t, tt.expectedEmail, userInfo.Email)
}
})
}
}
func TestOAuth2Provider_ValidateIDToken(t *testing.T) {
tests := []struct {
name string
idToken string
expectError bool
expectedSub string
}{
{
name: "valid ID token",
// This is a mock JWT token with payload: {"sub": "user123", "email": "user@example.com", "name": "Test User"}
idToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIiwiZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwibmFtZSI6IlRlc3QgVXNlciJ9.invalid-signature",
expectError: false,
expectedSub: "user123",
},
{
name: "invalid token format",
idToken: "invalid.token",
expectError: true,
},
{
name: "empty token",
idToken: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configMock := NewMockConfig()
logger := zap.NewNop()
provider := auth.NewOAuth2Provider(configMock, logger)
ctx := context.Background()
authContext, err := provider.ValidateIDToken(ctx, tt.idToken)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, authContext)
} else {
assert.NoError(t, err)
assert.NotNil(t, authContext)
assert.Equal(t, tt.expectedSub, authContext.UserID)
}
})
}
}
func TestOAuth2Provider_RefreshAccessToken(t *testing.T) {
tests := []struct {
name string
refreshToken string
clientID string
clientSecret string
mockResponse string
mockStatusCode int
expectError bool
expectedToken string
}{
{
name: "successful token refresh",
refreshToken: "test-refresh-token",
clientID: "test-client-id",
clientSecret: "test-client-secret",
mockResponse: `{
"access_token": "new-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "new-refresh-token"
}`,
mockStatusCode: http.StatusOK,
expectError: false,
expectedToken: "new-access-token",
},
{
name: "invalid refresh token",
refreshToken: "invalid-refresh-token",
clientID: "test-client-id",
clientSecret: "test-client-secret",
mockResponse: `{"error": "invalid_grant"}`,
mockStatusCode: http.StatusBadRequest,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock servers
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
w.WriteHeader(tt.mockStatusCode)
w.Write([]byte(tt.mockResponse))
}))
defer tokenServer.Close()
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := `{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "` + tokenServer.URL + `",
"userinfo_endpoint": "https://example.com/userinfo"
}`
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(response))
}))
defer discoveryServer.Close()
configMock := NewMockConfig()
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
configMock.values["SSO_CLIENT_ID"] = tt.clientID
configMock.values["SSO_CLIENT_SECRET"] = tt.clientSecret
logger := zap.NewNop()
provider := auth.NewOAuth2Provider(configMock, logger)
ctx := context.Background()
tokenResp, err := provider.RefreshAccessToken(ctx, tt.refreshToken)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, tokenResp)
} else {
assert.NoError(t, err)
assert.NotNil(t, tokenResp)
assert.Equal(t, tt.expectedToken, tokenResp.AccessToken)
assert.Equal(t, "Bearer", tokenResp.TokenType)
}
})
}
}
// Benchmark tests for OAuth2 operations
func BenchmarkOAuth2Provider_GetDiscoveryDocument(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := `{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"userinfo_endpoint": "https://example.com/userinfo"
}`
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(response))
}))
defer server.Close()
configMock := NewMockConfig()
configMock.values["SSO_PROVIDER_URL"] = server.URL
logger := zap.NewNop()
provider := auth.NewOAuth2Provider(configMock, logger)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := provider.GetDiscoveryDocument(ctx)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkOAuth2Provider_GenerateAuthURL(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := `{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"userinfo_endpoint": "https://example.com/userinfo"
}`
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(response))
}))
defer server.Close()
configMock := NewMockConfig()
configMock.values["SSO_PROVIDER_URL"] = server.URL
configMock.values["SSO_CLIENT_ID"] = "test-client-id"
logger := zap.NewNop()
provider := auth.NewOAuth2Provider(configMock, logger)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := provider.GenerateAuthURL(ctx, "test-state", "https://app.example.com/callback")
if err != nil {
b.Fatal(err)
}
}
}

594
test/permissions_test.go Normal file
View File

@ -0,0 +1,594 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/auth"
)
func TestPermissionHierarchy_InitializeDefaultPermissions(t *testing.T) {
hierarchy := auth.NewPermissionHierarchy()
// Test that default permissions are created
permissions := hierarchy.ListPermissions()
assert.NotEmpty(t, permissions)
// Test specific permissions exist
permissionNames := make(map[string]bool)
for _, perm := range permissions {
permissionNames[perm.Name] = true
}
expectedPermissions := []string{
"admin", "read", "write",
"app.admin", "app.read", "app.write", "app.create", "app.update", "app.delete",
"token.admin", "token.read", "token.write", "token.create", "token.revoke", "token.verify",
"permission.admin", "permission.read", "permission.write", "permission.grant", "permission.revoke",
"user.admin", "user.read", "user.write",
}
for _, expected := range expectedPermissions {
assert.True(t, permissionNames[expected], "Permission %s should exist", expected)
}
}
func TestPermissionHierarchy_InitializeDefaultRoles(t *testing.T) {
hierarchy := auth.NewPermissionHierarchy()
// Test that default roles are created
roles := hierarchy.ListRoles()
assert.NotEmpty(t, roles)
// Test specific roles exist
roleNames := make(map[string]bool)
for _, role := range roles {
roleNames[role.Name] = true
}
expectedRoles := []string{
"super_admin", "app_admin", "developer", "viewer", "token_manager",
}
for _, expected := range expectedRoles {
assert.True(t, roleNames[expected], "Role %s should exist", expected)
}
}
func TestPermissionManager_HasPermission(t *testing.T) {
configMock := NewTestConfig()
configMock.values["CACHE_ENABLED"] = "false" // Disable cache for testing
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
tests := []struct {
name string
userID string
appID string
permission string
expectedResult bool
description string
}{
{
name: "admin user has admin permission",
userID: "admin@example.com",
appID: "test-app",
permission: "admin",
expectedResult: true,
description: "Admin users should have admin permissions",
},
{
name: "developer user has token.create permission",
userID: "dev@example.com",
appID: "test-app",
permission: "token.create",
expectedResult: true,
description: "Developer users should have token creation permissions",
},
{
name: "viewer user has read permission",
userID: "viewer@example.com",
appID: "test-app",
permission: "app.read",
expectedResult: true,
description: "Viewer users should have read permissions",
},
{
name: "viewer user denied write permission",
userID: "viewer@example.com",
appID: "test-app",
permission: "app.write",
expectedResult: false,
description: "Viewer users should not have write permissions",
},
{
name: "non-existent permission",
userID: "admin@example.com",
appID: "test-app",
permission: "non.existent",
expectedResult: false,
description: "Non-existent permissions should be denied",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
evaluation, err := pm.HasPermission(ctx, tt.userID, tt.appID, tt.permission)
require.NoError(t, err)
assert.NotNil(t, evaluation)
assert.Equal(t, tt.expectedResult, evaluation.Granted, tt.description)
assert.Equal(t, tt.permission, evaluation.Permission)
assert.NotZero(t, evaluation.EvaluatedAt)
if evaluation.Granted {
assert.NotEmpty(t, evaluation.GrantedBy, "Granted permissions should have GrantedBy information")
} else {
assert.NotEmpty(t, evaluation.DeniedReason, "Denied permissions should have a reason")
}
})
}
}
func TestPermissionManager_EvaluateBulkPermissions(t *testing.T) {
configMock := NewTestConfig()
configMock.values["CACHE_ENABLED"] = "false"
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
ctx := context.Background()
req := &auth.BulkPermissionRequest{
UserID: "dev@example.com",
AppID: "test-app",
Permissions: []string{
"app.read",
"token.create",
"token.read",
"app.delete", // Should be denied for developer
"admin", // Should be denied for developer
},
}
response, err := pm.EvaluateBulkPermissions(ctx, req)
require.NoError(t, err)
assert.NotNil(t, response)
assert.Equal(t, req.UserID, response.UserID)
assert.Equal(t, req.AppID, response.AppID)
assert.Len(t, response.Results, len(req.Permissions))
// Check specific results
assert.True(t, response.Results["app.read"].Granted, "Developer should have app.read permission")
assert.True(t, response.Results["token.create"].Granted, "Developer should have token.create permission")
assert.True(t, response.Results["token.read"].Granted, "Developer should have token.read permission")
assert.False(t, response.Results["app.delete"].Granted, "Developer should not have app.delete permission")
assert.False(t, response.Results["admin"].Granted, "Developer should not have admin permission")
}
func TestPermissionManager_AddPermission(t *testing.T) {
configMock := NewTestConfig()
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
tests := []struct {
name string
permission *auth.Permission
expectError bool
description string
}{
{
name: "add valid permission",
permission: &auth.Permission{
Name: "custom.permission",
Description: "Custom permission for testing",
Parent: "read",
Level: 2,
Resource: "custom",
Action: "test",
},
expectError: false,
description: "Valid permissions should be added successfully",
},
{
name: "add permission without name",
permission: &auth.Permission{
Description: "Permission without name",
Parent: "read",
Level: 2,
},
expectError: true,
description: "Permissions without names should be rejected",
},
{
name: "add permission with non-existent parent",
permission: &auth.Permission{
Name: "invalid.permission",
Description: "Permission with invalid parent",
Parent: "non.existent",
Level: 2,
},
expectError: true,
description: "Permissions with non-existent parents should be rejected",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := pm.AddPermission(tt.permission)
if tt.expectError {
assert.Error(t, err, tt.description)
} else {
assert.NoError(t, err, tt.description)
// Verify permission was added
permissions := pm.ListPermissions()
found := false
for _, perm := range permissions {
if perm.Name == tt.permission.Name {
found = true
assert.Equal(t, tt.permission.Description, perm.Description)
assert.Equal(t, tt.permission.Parent, perm.Parent)
break
}
}
assert.True(t, found, "Added permission should be found in the list")
}
})
}
}
func TestPermissionManager_AddRole(t *testing.T) {
configMock := NewTestConfig()
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
tests := []struct {
name string
role *auth.Role
expectError bool
description string
}{
{
name: "add valid role",
role: &auth.Role{
Name: "custom_role",
Description: "Custom role for testing",
Permissions: []string{"read", "app.read"},
Metadata: map[string]string{"level": "custom"},
},
expectError: false,
description: "Valid roles should be added successfully",
},
{
name: "add role without name",
role: &auth.Role{
Description: "Role without name",
Permissions: []string{"read"},
},
expectError: true,
description: "Roles without names should be rejected",
},
{
name: "add role with non-existent permission",
role: &auth.Role{
Name: "invalid_role",
Description: "Role with invalid permission",
Permissions: []string{"non.existent.permission"},
},
expectError: true,
description: "Roles with non-existent permissions should be rejected",
},
{
name: "add role with non-existent inherited role",
role: &auth.Role{
Name: "invalid_inherited_role",
Description: "Role with invalid inheritance",
Permissions: []string{"read"},
Inherits: []string{"non_existent_role"},
},
expectError: true,
description: "Roles with non-existent inherited roles should be rejected",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := pm.AddRole(tt.role)
if tt.expectError {
assert.Error(t, err, tt.description)
} else {
assert.NoError(t, err, tt.description)
// Verify role was added
roles := pm.ListRoles()
found := false
for _, role := range roles {
if role.Name == tt.role.Name {
found = true
assert.Equal(t, tt.role.Description, role.Description)
assert.Equal(t, tt.role.Permissions, role.Permissions)
break
}
}
assert.True(t, found, "Added role should be found in the list")
}
})
}
}
func TestPermissionManager_ListPermissions(t *testing.T) {
configMock := NewTestConfig()
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
permissions := pm.ListPermissions()
// Should have default permissions
assert.NotEmpty(t, permissions)
// Should be sorted by level and name
for i := 1; i < len(permissions); i++ {
prev := permissions[i-1]
curr := permissions[i]
if prev.Level == curr.Level {
assert.True(t, prev.Name <= curr.Name, "Permissions at same level should be sorted by name")
} else {
assert.True(t, prev.Level <= curr.Level, "Permissions should be sorted by level")
}
}
// Verify hierarchy structure
for _, perm := range permissions {
if perm.Parent != "" {
// Find parent permission
parentFound := false
for _, parent := range permissions {
if parent.Name == perm.Parent {
parentFound = true
assert.True(t, parent.Level < perm.Level, "Parent should have lower level than child")
assert.Contains(t, parent.Children, perm.Name, "Parent should contain child in children list")
break
}
}
assert.True(t, parentFound, "Parent permission should exist for %s", perm.Name)
}
}
}
func TestPermissionManager_ListRoles(t *testing.T) {
configMock := NewTestConfig()
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
roles := pm.ListRoles()
// Should have default roles
assert.NotEmpty(t, roles)
// Should be sorted by name
for i := 1; i < len(roles); i++ {
assert.True(t, roles[i-1].Name <= roles[i].Name, "Roles should be sorted by name")
}
// Verify all permissions in roles exist
allPermissions := pm.ListPermissions()
permissionNames := make(map[string]bool)
for _, perm := range allPermissions {
permissionNames[perm.Name] = true
}
for _, role := range roles {
for _, perm := range role.Permissions {
assert.True(t, permissionNames[perm], "Role %s references non-existent permission %s", role.Name, perm)
}
}
}
func TestPermissionManager_InvalidatePermissionCache(t *testing.T) {
configMock := NewTestConfig()
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
ctx := context.Background()
err := pm.InvalidatePermissionCache(ctx, "user123", "app123")
// Should not error (currently just logs)
assert.NoError(t, err)
}
func TestPermissionHierarchy_BuildHierarchy(t *testing.T) {
hierarchy := auth.NewPermissionHierarchy()
// Test that parent-child relationships are built correctly
permissions := hierarchy.ListPermissions()
// Find admin permission
var adminPerm *auth.Permission
for _, perm := range permissions {
if perm.Name == "admin" {
adminPerm = perm
break
}
}
require.NotNil(t, adminPerm, "Admin permission should exist")
// Admin should have children
assert.NotEmpty(t, adminPerm.Children, "Admin permission should have children")
// Check that app.admin is a child of admin
assert.Contains(t, adminPerm.Children, "app.admin", "app.admin should be a child of admin")
// Find app.write permission
var appWritePerm *auth.Permission
for _, perm := range permissions {
if perm.Name == "app.write" {
appWritePerm = perm
break
}
}
require.NotNil(t, appWritePerm, "app.write permission should exist")
// app.write should have children
assert.NotEmpty(t, appWritePerm.Children, "app.write permission should have children")
assert.Contains(t, appWritePerm.Children, "app.create", "app.create should be a child of app.write")
assert.Contains(t, appWritePerm.Children, "app.update", "app.update should be a child of app.write")
assert.Contains(t, appWritePerm.Children, "app.delete", "app.delete should be a child of app.write")
}
// Benchmark tests for permission operations
func BenchmarkPermissionManager_HasPermission(b *testing.B) {
configMock := NewTestConfig()
configMock.values["CACHE_ENABLED"] = "false"
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := pm.HasPermission(ctx, "dev@example.com", "test-app", "token.create")
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkPermissionManager_EvaluateBulkPermissions(b *testing.B) {
configMock := NewTestConfig()
configMock.values["CACHE_ENABLED"] = "false"
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
ctx := context.Background()
req := &auth.BulkPermissionRequest{
UserID: "dev@example.com",
AppID: "test-app",
Permissions: []string{
"app.read", "token.create", "token.read", "app.delete", "admin",
},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := pm.EvaluateBulkPermissions(ctx, req)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkPermissionManager_ListPermissions(b *testing.B) {
configMock := NewTestConfig()
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
permissions := pm.ListPermissions()
if len(permissions) == 0 {
b.Fatal("No permissions returned")
}
}
}
func BenchmarkPermissionManager_ListRoles(b *testing.B) {
configMock := NewTestConfig()
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
roles := pm.ListRoles()
if len(roles) == 0 {
b.Fatal("No roles returned")
}
}
}
// Test permission hierarchy traversal
func TestPermissionHierarchy_PermissionInheritance(t *testing.T) {
configMock := NewTestConfig()
configMock.values["CACHE_ENABLED"] = "false"
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
// Test that admin users get hierarchical permissions
ctx := context.Background()
// Admin should have all permissions through hierarchy
adminPermissions := []string{
"admin",
"app.admin",
"token.admin",
"permission.admin",
"user.admin",
}
for _, perm := range adminPermissions {
evaluation, err := pm.HasPermission(ctx, "admin@example.com", "test-app", perm)
require.NoError(t, err)
assert.True(t, evaluation.Granted, "Admin should have %s permission", perm)
}
}
// Test role inheritance
func TestPermissionManager_RoleInheritance(t *testing.T) {
configMock := NewTestConfig()
logger := zap.NewNop()
pm := auth.NewPermissionManager(configMock, logger)
// Add a role that inherits from another role
parentRole := &auth.Role{
Name: "base_role",
Description: "Base role with basic permissions",
Permissions: []string{"read", "app.read"},
Metadata: map[string]string{"level": "base"},
}
childRole := &auth.Role{
Name: "extended_role",
Description: "Extended role that inherits from base",
Permissions: []string{"write"},
Inherits: []string{"base_role"},
Metadata: map[string]string{"level": "extended"},
}
err := pm.AddRole(parentRole)
require.NoError(t, err)
err = pm.AddRole(childRole)
require.NoError(t, err)
// Verify roles were added
roles := pm.ListRoles()
roleNames := make(map[string]*auth.Role)
for _, role := range roles {
roleNames[role.Name] = role
}
assert.Contains(t, roleNames, "base_role")
assert.Contains(t, roleNames, "extended_role")
assert.Equal(t, []string{"base_role"}, roleNames["extended_role"].Inherits)
}

View File

@ -29,6 +29,10 @@ func (c *TestConfig) GetBool(key string) bool {
return boolVal
}
}
// Special handling for cache enabled
if key == "CACHE_ENABLED" {
return c.values[key] == "true"
}
return false
}
@ -86,6 +90,10 @@ func (c *TestConfig) IsProduction() bool {
return c.GetString("APP_ENV") == "production"
}
func (c *TestConfig) GetJWTSecret() string {
return c.GetString("JWT_SECRET")
}
// NewTestConfig creates a test configuration with default values
func NewTestConfig() *TestConfig {
return &TestConfig{
@ -99,6 +107,12 @@ func NewTestConfig() *TestConfig {
"SERVER_HOST": "localhost",
"SERVER_PORT": "8080",
"APP_ENV": "test",
"JWT_SECRET": "test-jwt-secret-for-testing-only",
},
}
}
// NewMockConfig creates a mock configuration (alias for NewTestConfig for backward compatibility)
func NewMockConfig() *TestConfig {
return NewTestConfig()
}