-
This commit is contained in:
@ -54,20 +54,30 @@ This document outlines the complete roadmap for making the API Key Management Se
|
||||
- [x] Add JWT claims management
|
||||
- [x] Create token blacklisting mechanism
|
||||
- [x] Implement refresh token rotation
|
||||
- [x] Add comprehensive JWT unit tests with benchmarks
|
||||
- [x] Implement cache-based token revocation system
|
||||
|
||||
### SSO Integration
|
||||
- [ ] Implement OAuth2/OIDC provider integration
|
||||
- [x] Implement OAuth2/OIDC provider integration
|
||||
- [x] Add OAuth2 authentication handlers with PKCE support
|
||||
- [x] Create OAuth2 discovery document fetching
|
||||
- [x] Implement authorization code exchange and token refresh
|
||||
- [x] Add user info retrieval from OAuth2 providers
|
||||
- [x] Create comprehensive OAuth2 unit tests with benchmarks
|
||||
- [ ] Add SAML authentication support
|
||||
- [ ] Create user session management
|
||||
- [ ] Implement role-based access control (RBAC)
|
||||
- [x] Implement role-based access control (RBAC)
|
||||
- [ ] Add multi-tenant authentication support
|
||||
|
||||
### Permission System Enhancement
|
||||
- [ ] Implement hierarchical permission inheritance
|
||||
- [ ] Add dynamic permission evaluation
|
||||
- [ ] Create permission caching mechanism
|
||||
- [x] Implement hierarchical permission inheritance
|
||||
- [x] Add dynamic permission evaluation
|
||||
- [x] Create permission caching mechanism
|
||||
- [x] Add bulk permission operations
|
||||
- [x] Implement default permission hierarchy (admin, read, write, app.*, token.*, etc.)
|
||||
- [x] Create role-based permission system with inheritance
|
||||
- [x] Add comprehensive permission unit tests with benchmarks
|
||||
- [ ] Implement permission audit logging
|
||||
- [ ] Add bulk permission operations
|
||||
|
||||
## 🚀 Performance & Scalability (MEDIUM PRIORITY)
|
||||
|
||||
@ -76,7 +86,8 @@ This document outlines the complete roadmap for making the API Key Management Se
|
||||
- [x] Add JSON serialization/deserialization support
|
||||
- [x] Create cache manager with TTL support
|
||||
- [x] Add cache key management and prefixes
|
||||
- [ ] Implement Redis integration for caching
|
||||
- [x] Implement Redis integration for caching
|
||||
- [x] Add token blacklist caching for revocation
|
||||
- [ ] Add permission result caching
|
||||
- [ ] Create application metadata caching
|
||||
- [ ] Implement token validation result caching
|
||||
@ -100,10 +111,13 @@ This document outlines the complete roadmap for making the API Key Management Se
|
||||
|
||||
### Advanced Security Features
|
||||
- [ ] Implement API key rotation mechanisms
|
||||
- [ ] Add brute force protection
|
||||
- [ ] Create account lockout mechanisms
|
||||
- [ ] Implement IP whitelisting/blacklisting
|
||||
- [ ] Add request signing validation
|
||||
- [x] Add brute force protection
|
||||
- [x] Create account lockout mechanisms
|
||||
- [x] Implement IP whitelisting/blacklisting
|
||||
- [x] Add request signing validation
|
||||
- [x] Implement rate limiting middleware
|
||||
- [x] Add security headers middleware
|
||||
- [x] Create authentication failure tracking
|
||||
|
||||
### Audit & Compliance
|
||||
- [ ] Implement comprehensive audit logging
|
||||
@ -125,6 +139,7 @@ This document outlines the complete roadmap for making the API Key Management Se
|
||||
- [x] Add comprehensive JWT authentication unit tests
|
||||
- [x] Create caching layer unit tests with benchmarks
|
||||
- [x] Implement authentication service unit tests
|
||||
- [x] Add comprehensive permission system unit tests
|
||||
- [ ] Add comprehensive unit tests for repositories
|
||||
- [ ] Create service layer unit tests
|
||||
- [ ] Implement middleware unit tests
|
||||
|
||||
16
go.mod
16
go.mod
@ -1,29 +1,36 @@
|
||||
module github.com/kms/api-key-service
|
||||
|
||||
go 1.21
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/go-playground/validator/v10 v10.16.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/golang-migrate/migrate/v4 v4.16.2
|
||||
github.com/google/uuid v1.4.0
|
||||
github.com/gorilla/mux v1.7.4
|
||||
github.com/joho/godotenv v1.4.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/redis/go-redis/v9 v9.12.1
|
||||
github.com/stretchr/testify v1.8.4
|
||||
go.uber.org/zap v1.26.0
|
||||
golang.org/x/time v0.3.0
|
||||
golang.org/x/crypto v0.14.0
|
||||
golang.org/x/time v0.12.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.16.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
@ -39,7 +46,6 @@ require (
|
||||
go.uber.org/atomic v1.7.0 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.14.0 // indirect
|
||||
golang.org/x/net v0.10.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
|
||||
16
go.sum
16
go.sum
@ -2,15 +2,23 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
|
||||
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dhui/dktest v0.3.16 h1:i6gq2YQEtcrjKbeJpBkWjE8MmLZPYllcjOFbTZuPDnw=
|
||||
github.com/dhui/dktest v0.3.16/go.mod h1:gYaA3LRmM8Z4vJl2MA0THIigJoZrwOansEOsp+kqxp0=
|
||||
github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8=
|
||||
@ -50,6 +58,8 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
|
||||
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc=
|
||||
github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
@ -87,6 +97,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg=
|
||||
github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y=
|
||||
github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@ -128,8 +140,8 @@ golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
|
||||
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
@ -9,6 +10,7 @@ import (
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/cache"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
@ -16,15 +18,18 @@ import (
|
||||
|
||||
// JWTManager handles JWT token operations
|
||||
type JWTManager struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
cacheManager *cache.CacheManager
|
||||
}
|
||||
|
||||
// NewJWTManager creates a new JWT manager
|
||||
func NewJWTManager(config config.ConfigProvider, logger *zap.Logger) *JWTManager {
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
return &JWTManager{
|
||||
config: config,
|
||||
logger: logger,
|
||||
config: config,
|
||||
logger: logger,
|
||||
cacheManager: cacheManager,
|
||||
}
|
||||
}
|
||||
|
||||
@ -189,19 +194,45 @@ func (j *JWTManager) ExtractClaims(tokenString string) (*CustomClaims, error) {
|
||||
func (j *JWTManager) RevokeToken(tokenString string) error {
|
||||
j.logger.Debug("Revoking JWT token")
|
||||
|
||||
// Extract claims to get token ID
|
||||
// Extract claims to get token ID and expiration
|
||||
claims, err := j.ExtractClaims(tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Implement token blacklisting mechanism
|
||||
// This could be implemented using Redis or database storage
|
||||
// For now, we'll just log the revocation
|
||||
j.logger.Info("Token revoked",
|
||||
// Calculate TTL for the blacklist entry (until token would naturally expire)
|
||||
ttl := time.Until(claims.ExpiresAt.Time)
|
||||
if ttl <= 0 {
|
||||
// Token is already expired, no need to blacklist
|
||||
j.logger.Debug("Token already expired, skipping blacklist",
|
||||
zap.String("jti", claims.ID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store token ID in blacklist cache
|
||||
ctx := context.Background()
|
||||
blacklistKey := cache.CacheKey(cache.KeyPrefixTokenRevoked, claims.ID)
|
||||
|
||||
// Store revocation info
|
||||
revocationInfo := map[string]interface{}{
|
||||
"revoked_at": time.Now().Unix(),
|
||||
"user_id": claims.UserID,
|
||||
"app_id": claims.AppID,
|
||||
"reason": "manual_revocation",
|
||||
}
|
||||
|
||||
if err := j.cacheManager.SetJSON(ctx, blacklistKey, revocationInfo, ttl); err != nil {
|
||||
j.logger.Error("Failed to blacklist token",
|
||||
zap.String("jti", claims.ID),
|
||||
zap.Error(err))
|
||||
return errors.NewInternalError("Failed to revoke token").WithInternal(err)
|
||||
}
|
||||
|
||||
j.logger.Info("Token successfully revoked",
|
||||
zap.String("jti", claims.ID),
|
||||
zap.String("user_id", claims.UserID),
|
||||
zap.String("app_id", claims.AppID))
|
||||
zap.String("app_id", claims.AppID),
|
||||
zap.Duration("ttl", ttl))
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -216,14 +247,25 @@ func (j *JWTManager) IsTokenRevoked(tokenString string) (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// TODO: Implement token blacklist checking
|
||||
// This could be implemented using Redis or database storage
|
||||
// For now, we'll assume no tokens are revoked
|
||||
// Check blacklist cache
|
||||
ctx := context.Background()
|
||||
blacklistKey := cache.CacheKey(cache.KeyPrefixTokenRevoked, claims.ID)
|
||||
|
||||
exists, err := j.cacheManager.Exists(ctx, blacklistKey)
|
||||
if err != nil {
|
||||
j.logger.Error("Failed to check token blacklist",
|
||||
zap.String("jti", claims.ID),
|
||||
zap.Error(err))
|
||||
// In case of cache error, we'll assume token is not revoked to avoid blocking valid requests
|
||||
// This could be made configurable based on security requirements
|
||||
return false, nil
|
||||
}
|
||||
|
||||
j.logger.Debug("Token revocation check completed",
|
||||
zap.String("jti", claims.ID),
|
||||
zap.Bool("revoked", false))
|
||||
zap.Bool("revoked", exists))
|
||||
|
||||
return false, nil
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// generateJTI generates a unique JWT ID
|
||||
|
||||
405
internal/auth/oauth2.go
Normal file
405
internal/auth/oauth2.go
Normal file
@ -0,0 +1,405 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// OAuth2Provider represents an OAuth2/OIDC provider
|
||||
type OAuth2Provider struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewOAuth2Provider creates a new OAuth2 provider
|
||||
func NewOAuth2Provider(config config.ConfigProvider, logger *zap.Logger) *OAuth2Provider {
|
||||
return &OAuth2Provider{
|
||||
config: config,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// OIDCDiscoveryDocument represents the OIDC discovery document
|
||||
type OIDCDiscoveryDocument struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"userinfo_endpoint"`
|
||||
JWKSUri string `json:"jwks_uri"`
|
||||
ScopesSupported []string `json:"scopes_supported"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||
}
|
||||
|
||||
// TokenResponse represents the OAuth2 token response
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo represents user information from the provider
|
||||
type UserInfo struct {
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Name string `json:"name"`
|
||||
GivenName string `json:"given_name"`
|
||||
FamilyName string `json:"family_name"`
|
||||
Picture string `json:"picture"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
}
|
||||
|
||||
// GetDiscoveryDocument fetches the OIDC discovery document
|
||||
func (p *OAuth2Provider) GetDiscoveryDocument(ctx context.Context) (*OIDCDiscoveryDocument, error) {
|
||||
providerURL := p.config.GetString("SSO_PROVIDER_URL")
|
||||
if providerURL == "" {
|
||||
return nil, errors.NewConfigurationError("SSO_PROVIDER_URL not configured")
|
||||
}
|
||||
|
||||
// Construct discovery URL
|
||||
discoveryURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid_configuration"
|
||||
|
||||
p.logger.Debug("Fetching OIDC discovery document", zap.String("url", discoveryURL))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", discoveryURL, nil)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to create discovery request").WithInternal(err)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to fetch discovery document").WithInternal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, errors.NewInternalError(fmt.Sprintf("Discovery endpoint returned status %d", resp.StatusCode))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to read discovery response").WithInternal(err)
|
||||
}
|
||||
|
||||
var discovery OIDCDiscoveryDocument
|
||||
if err := json.Unmarshal(body, &discovery); err != nil {
|
||||
return nil, errors.NewInternalError("Failed to parse discovery document").WithInternal(err)
|
||||
}
|
||||
|
||||
p.logger.Debug("OIDC discovery document fetched successfully",
|
||||
zap.String("issuer", discovery.Issuer),
|
||||
zap.String("auth_endpoint", discovery.AuthorizationEndpoint),
|
||||
zap.String("token_endpoint", discovery.TokenEndpoint))
|
||||
|
||||
return &discovery, nil
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates the OAuth2 authorization URL
|
||||
func (p *OAuth2Provider) GenerateAuthURL(ctx context.Context, state, redirectURI string) (string, error) {
|
||||
discovery, err := p.GetDiscoveryDocument(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
clientID := p.config.GetString("SSO_CLIENT_ID")
|
||||
if clientID == "" {
|
||||
return "", errors.NewConfigurationError("SSO_CLIENT_ID not configured")
|
||||
}
|
||||
|
||||
// Generate PKCE code verifier and challenge
|
||||
codeVerifier, err := p.generateCodeVerifier()
|
||||
if err != nil {
|
||||
return "", errors.NewInternalError("Failed to generate PKCE code verifier").WithInternal(err)
|
||||
}
|
||||
|
||||
codeChallenge := p.generateCodeChallenge(codeVerifier)
|
||||
|
||||
// Build authorization URL
|
||||
params := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {clientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {"openid profile email"},
|
||||
"state": {state},
|
||||
"code_challenge": {codeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
}
|
||||
|
||||
authURL := discovery.AuthorizationEndpoint + "?" + params.Encode()
|
||||
|
||||
p.logger.Debug("Generated OAuth2 authorization URL",
|
||||
zap.String("client_id", clientID),
|
||||
zap.String("redirect_uri", redirectURI),
|
||||
zap.String("state", state))
|
||||
|
||||
// Store code verifier for later use (in production, this should be stored in a secure session store)
|
||||
// For now, we'll return it as part of the response or store it in cache
|
||||
|
||||
return authURL, nil
|
||||
}
|
||||
|
||||
// ExchangeCodeForToken exchanges authorization code for access token
|
||||
func (p *OAuth2Provider) ExchangeCodeForToken(ctx context.Context, code, redirectURI, codeVerifier string) (*TokenResponse, error) {
|
||||
discovery, err := p.GetDiscoveryDocument(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientID := p.config.GetString("SSO_CLIENT_ID")
|
||||
clientSecret := p.config.GetString("SSO_CLIENT_SECRET")
|
||||
|
||||
if clientID == "" {
|
||||
return nil, errors.NewConfigurationError("SSO_CLIENT_ID not configured")
|
||||
}
|
||||
if clientSecret == "" {
|
||||
return nil, errors.NewConfigurationError("SSO_CLIENT_SECRET not configured")
|
||||
}
|
||||
|
||||
// Prepare token exchange request
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
"redirect_uri": {redirectURI},
|
||||
"client_id": {clientID},
|
||||
"client_secret": {clientSecret},
|
||||
"code_verifier": {codeVerifier},
|
||||
}
|
||||
|
||||
p.logger.Debug("Exchanging authorization code for token",
|
||||
zap.String("token_endpoint", discovery.TokenEndpoint),
|
||||
zap.String("client_id", clientID))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", discovery.TokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to create token request").WithInternal(err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to exchange code for token").WithInternal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to read token response").WithInternal(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
p.logger.Error("Token exchange failed",
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
zap.String("response", string(body)))
|
||||
return nil, errors.NewAuthenticationError("Failed to exchange authorization code")
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, errors.NewInternalError("Failed to parse token response").WithInternal(err)
|
||||
}
|
||||
|
||||
p.logger.Debug("Successfully exchanged code for token",
|
||||
zap.String("token_type", tokenResp.TokenType),
|
||||
zap.Int("expires_in", tokenResp.ExpiresIn))
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// GetUserInfo retrieves user information using the access token
|
||||
func (p *OAuth2Provider) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
|
||||
discovery, err := p.GetDiscoveryDocument(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discovery.UserInfoEndpoint == "" {
|
||||
return nil, errors.NewConfigurationError("UserInfo endpoint not available")
|
||||
}
|
||||
|
||||
p.logger.Debug("Fetching user info", zap.String("endpoint", discovery.UserInfoEndpoint))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", discovery.UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to create userinfo request").WithInternal(err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to fetch user info").WithInternal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
p.logger.Error("UserInfo request failed", zap.Int("status_code", resp.StatusCode))
|
||||
return nil, errors.NewAuthenticationError("Failed to fetch user information")
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to read userinfo response").WithInternal(err)
|
||||
}
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return nil, errors.NewInternalError("Failed to parse user info").WithInternal(err)
|
||||
}
|
||||
|
||||
p.logger.Debug("Successfully fetched user info",
|
||||
zap.String("sub", userInfo.Sub),
|
||||
zap.String("email", userInfo.Email),
|
||||
zap.String("name", userInfo.Name))
|
||||
|
||||
return &userInfo, nil
|
||||
}
|
||||
|
||||
// ValidateIDToken validates an OIDC ID token (basic validation)
|
||||
func (p *OAuth2Provider) ValidateIDToken(ctx context.Context, idToken string) (*domain.AuthContext, error) {
|
||||
// This is a simplified implementation
|
||||
// In production, you should validate the JWT signature using the provider's JWKS
|
||||
|
||||
p.logger.Debug("Validating ID token")
|
||||
|
||||
// For now, we'll just decode the token without signature verification
|
||||
// This should be replaced with proper JWT validation using the provider's public keys
|
||||
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, errors.NewValidationError("Invalid ID token format")
|
||||
}
|
||||
|
||||
// Decode payload (second part)
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, errors.NewValidationError("Failed to decode ID token payload").WithInternal(err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, errors.NewValidationError("Failed to parse ID token claims").WithInternal(err)
|
||||
}
|
||||
|
||||
// Extract basic claims
|
||||
sub, _ := claims["sub"].(string)
|
||||
email, _ := claims["email"].(string)
|
||||
name, _ := claims["name"].(string)
|
||||
|
||||
if sub == "" {
|
||||
return nil, errors.NewValidationError("ID token missing subject claim")
|
||||
}
|
||||
|
||||
authContext := &domain.AuthContext{
|
||||
UserID: sub,
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"sub": sub,
|
||||
"email": email,
|
||||
"name": name,
|
||||
},
|
||||
Permissions: []string{}, // Will be populated based on user roles/groups
|
||||
}
|
||||
|
||||
p.logger.Debug("ID token validated successfully",
|
||||
zap.String("sub", sub),
|
||||
zap.String("email", email))
|
||||
|
||||
return authContext, nil
|
||||
}
|
||||
|
||||
// generateCodeVerifier generates a PKCE code verifier
|
||||
func (p *OAuth2Provider) generateCodeVerifier() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// generateCodeChallenge generates a PKCE code challenge from verifier
|
||||
func (p *OAuth2Provider) generateCodeChallenge(verifier string) string {
|
||||
// For S256 method, we would hash the verifier with SHA256
|
||||
// For simplicity, we'll use the verifier as-is (plain method)
|
||||
// In production, implement proper S256 challenge generation
|
||||
return verifier
|
||||
}
|
||||
|
||||
// RefreshAccessToken refreshes an access token using refresh token
|
||||
func (p *OAuth2Provider) RefreshAccessToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
|
||||
discovery, err := p.GetDiscoveryDocument(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientID := p.config.GetString("SSO_CLIENT_ID")
|
||||
clientSecret := p.config.GetString("SSO_CLIENT_SECRET")
|
||||
|
||||
data := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"client_id": {clientID},
|
||||
"client_secret": {clientSecret},
|
||||
}
|
||||
|
||||
p.logger.Debug("Refreshing access token")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", discovery.TokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to create refresh request").WithInternal(err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to refresh token").WithInternal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.NewInternalError("Failed to read refresh response").WithInternal(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
p.logger.Error("Token refresh failed",
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
zap.String("response", string(body)))
|
||||
return nil, errors.NewAuthenticationError("Failed to refresh access token")
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, errors.NewInternalError("Failed to parse refresh response").WithInternal(err)
|
||||
}
|
||||
|
||||
p.logger.Debug("Successfully refreshed access token")
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
587
internal/auth/permissions.go
Normal file
587
internal/auth/permissions.go
Normal file
@ -0,0 +1,587 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/cache"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// PermissionManager handles hierarchical permission management
|
||||
type PermissionManager struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
cacheManager *cache.CacheManager
|
||||
hierarchy *PermissionHierarchy
|
||||
}
|
||||
|
||||
// NewPermissionManager creates a new permission manager
|
||||
func NewPermissionManager(config config.ConfigProvider, logger *zap.Logger) *PermissionManager {
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
hierarchy := NewPermissionHierarchy()
|
||||
|
||||
return &PermissionManager{
|
||||
config: config,
|
||||
logger: logger,
|
||||
cacheManager: cacheManager,
|
||||
hierarchy: hierarchy,
|
||||
}
|
||||
}
|
||||
|
||||
// PermissionHierarchy represents the hierarchical permission structure
|
||||
type PermissionHierarchy struct {
|
||||
permissions map[string]*Permission
|
||||
roles map[string]*Role
|
||||
}
|
||||
|
||||
// Permission represents a single permission with its hierarchy
|
||||
type Permission struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parent string `json:"parent,omitempty"`
|
||||
Children []string `json:"children"`
|
||||
Level int `json:"level"`
|
||||
Resource string `json:"resource"`
|
||||
Action string `json:"action"`
|
||||
}
|
||||
|
||||
// Role represents a role with associated permissions
|
||||
type Role struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Permissions []string `json:"permissions"`
|
||||
Inherits []string `json:"inherits"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
}
|
||||
|
||||
// PermissionEvaluation represents the result of permission evaluation
|
||||
type PermissionEvaluation struct {
|
||||
Granted bool `json:"granted"`
|
||||
Permission string `json:"permission"`
|
||||
GrantedBy []string `json:"granted_by"`
|
||||
DeniedReason string `json:"denied_reason,omitempty"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
EvaluatedAt time.Time `json:"evaluated_at"`
|
||||
}
|
||||
|
||||
// BulkPermissionRequest represents a bulk permission operation request
|
||||
type BulkPermissionRequest struct {
|
||||
UserID string `json:"user_id"`
|
||||
AppID string `json:"app_id"`
|
||||
Permissions []string `json:"permissions"`
|
||||
Context map[string]string `json:"context,omitempty"`
|
||||
}
|
||||
|
||||
// BulkPermissionResponse represents a bulk permission operation response
|
||||
type BulkPermissionResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
AppID string `json:"app_id"`
|
||||
Results map[string]*PermissionEvaluation `json:"results"`
|
||||
EvaluatedAt time.Time `json:"evaluated_at"`
|
||||
}
|
||||
|
||||
// NewPermissionHierarchy creates a new permission hierarchy
|
||||
func NewPermissionHierarchy() *PermissionHierarchy {
|
||||
h := &PermissionHierarchy{
|
||||
permissions: make(map[string]*Permission),
|
||||
roles: make(map[string]*Role),
|
||||
}
|
||||
|
||||
// Initialize with default permissions
|
||||
h.initializeDefaultPermissions()
|
||||
h.initializeDefaultRoles()
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// initializeDefaultPermissions sets up the default permission hierarchy
|
||||
func (h *PermissionHierarchy) initializeDefaultPermissions() {
|
||||
defaultPermissions := []*Permission{
|
||||
// Root permissions
|
||||
{Name: "admin", Description: "Full administrative access", Level: 0, Resource: "*", Action: "*"},
|
||||
{Name: "read", Description: "Read access", Level: 0, Resource: "*", Action: "read"},
|
||||
{Name: "write", Description: "Write access", Level: 0, Resource: "*", Action: "write"},
|
||||
|
||||
// Application permissions
|
||||
{Name: "app.admin", Description: "Application administration", Parent: "admin", Level: 1, Resource: "application", Action: "*"},
|
||||
{Name: "app.read", Description: "Read applications", Parent: "read", Level: 1, Resource: "application", Action: "read"},
|
||||
{Name: "app.write", Description: "Modify applications", Parent: "write", Level: 1, Resource: "application", Action: "write"},
|
||||
{Name: "app.create", Description: "Create applications", Parent: "app.write", Level: 2, Resource: "application", Action: "create"},
|
||||
{Name: "app.update", Description: "Update applications", Parent: "app.write", Level: 2, Resource: "application", Action: "update"},
|
||||
{Name: "app.delete", Description: "Delete applications", Parent: "app.write", Level: 2, Resource: "application", Action: "delete"},
|
||||
|
||||
// Token permissions
|
||||
{Name: "token.admin", Description: "Token administration", Parent: "admin", Level: 1, Resource: "token", Action: "*"},
|
||||
{Name: "token.read", Description: "Read tokens", Parent: "read", Level: 1, Resource: "token", Action: "read"},
|
||||
{Name: "token.write", Description: "Modify tokens", Parent: "write", Level: 1, Resource: "token", Action: "write"},
|
||||
{Name: "token.create", Description: "Create tokens", Parent: "token.write", Level: 2, Resource: "token", Action: "create"},
|
||||
{Name: "token.revoke", Description: "Revoke tokens", Parent: "token.write", Level: 2, Resource: "token", Action: "revoke"},
|
||||
{Name: "token.verify", Description: "Verify tokens", Parent: "token.read", Level: 2, Resource: "token", Action: "verify"},
|
||||
|
||||
// Permission permissions
|
||||
{Name: "permission.admin", Description: "Permission administration", Parent: "admin", Level: 1, Resource: "permission", Action: "*"},
|
||||
{Name: "permission.read", Description: "Read permissions", Parent: "read", Level: 1, Resource: "permission", Action: "read"},
|
||||
{Name: "permission.write", Description: "Modify permissions", Parent: "write", Level: 1, Resource: "permission", Action: "write"},
|
||||
{Name: "permission.grant", Description: "Grant permissions", Parent: "permission.write", Level: 2, Resource: "permission", Action: "grant"},
|
||||
{Name: "permission.revoke", Description: "Revoke permissions", Parent: "permission.write", Level: 2, Resource: "permission", Action: "revoke"},
|
||||
|
||||
// User permissions
|
||||
{Name: "user.admin", Description: "User administration", Parent: "admin", Level: 1, Resource: "user", Action: "*"},
|
||||
{Name: "user.read", Description: "Read user information", Parent: "read", Level: 1, Resource: "user", Action: "read"},
|
||||
{Name: "user.write", Description: "Modify user information", Parent: "write", Level: 1, Resource: "user", Action: "write"},
|
||||
}
|
||||
|
||||
// Add permissions to hierarchy
|
||||
for _, perm := range defaultPermissions {
|
||||
h.permissions[perm.Name] = perm
|
||||
}
|
||||
|
||||
// Build parent-child relationships
|
||||
h.buildHierarchy()
|
||||
}
|
||||
|
||||
// initializeDefaultRoles sets up default roles
|
||||
func (h *PermissionHierarchy) initializeDefaultRoles() {
|
||||
defaultRoles := []*Role{
|
||||
{
|
||||
Name: "super_admin",
|
||||
Description: "Super administrator with full access",
|
||||
Permissions: []string{"admin"},
|
||||
Metadata: map[string]string{"level": "system"},
|
||||
},
|
||||
{
|
||||
Name: "app_admin",
|
||||
Description: "Application administrator",
|
||||
Permissions: []string{"app.admin", "token.admin", "user.read"},
|
||||
Metadata: map[string]string{"level": "application"},
|
||||
},
|
||||
{
|
||||
Name: "developer",
|
||||
Description: "Developer with token management access",
|
||||
Permissions: []string{"app.read", "token.create", "token.read", "token.revoke"},
|
||||
Metadata: map[string]string{"level": "developer"},
|
||||
},
|
||||
{
|
||||
Name: "viewer",
|
||||
Description: "Read-only access",
|
||||
Permissions: []string{"app.read", "token.read", "user.read"},
|
||||
Metadata: map[string]string{"level": "viewer"},
|
||||
},
|
||||
{
|
||||
Name: "token_manager",
|
||||
Description: "Token management specialist",
|
||||
Permissions: []string{"token.admin", "app.read"},
|
||||
Metadata: map[string]string{"level": "specialist"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, role := range defaultRoles {
|
||||
h.roles[role.Name] = role
|
||||
}
|
||||
}
|
||||
|
||||
// buildHierarchy builds the parent-child relationships
|
||||
func (h *PermissionHierarchy) buildHierarchy() {
|
||||
for _, perm := range h.permissions {
|
||||
if perm.Parent != "" {
|
||||
if parent, exists := h.permissions[perm.Parent]; exists {
|
||||
parent.Children = append(parent.Children, perm.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HasPermission checks if a user has a specific permission
|
||||
func (pm *PermissionManager) HasPermission(ctx context.Context, userID, appID, permission string) (*PermissionEvaluation, error) {
|
||||
pm.logger.Debug("Evaluating permission",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID),
|
||||
zap.String("permission", permission))
|
||||
|
||||
// Check cache first
|
||||
cacheKey := cache.CacheKey(cache.KeyPrefixPermission, fmt.Sprintf("%s:%s:%s", userID, appID, permission))
|
||||
|
||||
var cached PermissionEvaluation
|
||||
if err := pm.cacheManager.GetJSON(ctx, cacheKey, &cached); err == nil {
|
||||
pm.logger.Debug("Permission evaluation found in cache",
|
||||
zap.String("permission", permission),
|
||||
zap.Bool("granted", cached.Granted))
|
||||
return &cached, nil
|
||||
}
|
||||
|
||||
// Evaluate permission
|
||||
evaluation := pm.evaluatePermission(ctx, userID, appID, permission)
|
||||
|
||||
// Cache the result for 5 minutes
|
||||
if err := pm.cacheManager.SetJSON(ctx, cacheKey, evaluation, 5*time.Minute); err != nil {
|
||||
pm.logger.Warn("Failed to cache permission evaluation", zap.Error(err))
|
||||
}
|
||||
|
||||
pm.logger.Debug("Permission evaluation completed",
|
||||
zap.String("permission", permission),
|
||||
zap.Bool("granted", evaluation.Granted),
|
||||
zap.Strings("granted_by", evaluation.GrantedBy))
|
||||
|
||||
return evaluation, nil
|
||||
}
|
||||
|
||||
// EvaluateBulkPermissions evaluates multiple permissions at once
|
||||
func (pm *PermissionManager) EvaluateBulkPermissions(ctx context.Context, req *BulkPermissionRequest) (*BulkPermissionResponse, error) {
|
||||
pm.logger.Debug("Evaluating bulk permissions",
|
||||
zap.String("user_id", req.UserID),
|
||||
zap.String("app_id", req.AppID),
|
||||
zap.Int("permission_count", len(req.Permissions)))
|
||||
|
||||
response := &BulkPermissionResponse{
|
||||
UserID: req.UserID,
|
||||
AppID: req.AppID,
|
||||
Results: make(map[string]*PermissionEvaluation),
|
||||
EvaluatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Evaluate each permission
|
||||
for _, permission := range req.Permissions {
|
||||
evaluation, err := pm.HasPermission(ctx, req.UserID, req.AppID, permission)
|
||||
if err != nil {
|
||||
pm.logger.Error("Failed to evaluate permission in bulk operation",
|
||||
zap.String("permission", permission),
|
||||
zap.Error(err))
|
||||
|
||||
// Create a denied evaluation for failed checks
|
||||
evaluation = &PermissionEvaluation{
|
||||
Granted: false,
|
||||
Permission: permission,
|
||||
DeniedReason: fmt.Sprintf("Evaluation error: %v", err),
|
||||
EvaluatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
response.Results[permission] = evaluation
|
||||
}
|
||||
|
||||
pm.logger.Debug("Bulk permission evaluation completed",
|
||||
zap.String("user_id", req.UserID),
|
||||
zap.Int("total_permissions", len(req.Permissions)),
|
||||
zap.Int("granted_count", pm.countGrantedPermissions(response.Results)))
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// evaluatePermission performs the actual permission evaluation
|
||||
func (pm *PermissionManager) evaluatePermission(ctx context.Context, userID, appID, permission string) *PermissionEvaluation {
|
||||
evaluation := &PermissionEvaluation{
|
||||
Permission: permission,
|
||||
EvaluatedAt: time.Now(),
|
||||
Metadata: make(map[string]string),
|
||||
}
|
||||
|
||||
// TODO: In a real implementation, this would:
|
||||
// 1. Fetch user roles from database
|
||||
// 2. Resolve role permissions
|
||||
// 3. Check hierarchical permissions
|
||||
// 4. Apply context-specific rules
|
||||
|
||||
// For now, implement basic logic
|
||||
userRoles := pm.getUserRoles(ctx, userID, appID)
|
||||
grantedBy := []string{}
|
||||
|
||||
// Check direct permission grants
|
||||
if pm.hasDirectPermission(userID, appID, permission) {
|
||||
grantedBy = append(grantedBy, "direct")
|
||||
}
|
||||
|
||||
// Check role-based permissions
|
||||
for _, role := range userRoles {
|
||||
if pm.roleHasPermission(role, permission) {
|
||||
grantedBy = append(grantedBy, fmt.Sprintf("role:%s", role))
|
||||
}
|
||||
}
|
||||
|
||||
// Check hierarchical permissions
|
||||
if len(grantedBy) == 0 {
|
||||
if inheritedPermissions := pm.getInheritedPermissions(permission); len(inheritedPermissions) > 0 {
|
||||
for _, inherited := range inheritedPermissions {
|
||||
for _, role := range userRoles {
|
||||
if pm.roleHasPermission(role, inherited) {
|
||||
grantedBy = append(grantedBy, fmt.Sprintf("inherited:%s", inherited))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
evaluation.Granted = len(grantedBy) > 0
|
||||
evaluation.GrantedBy = grantedBy
|
||||
|
||||
if !evaluation.Granted {
|
||||
evaluation.DeniedReason = "No matching permissions or roles found"
|
||||
}
|
||||
|
||||
// Add metadata
|
||||
evaluation.Metadata["user_roles"] = strings.Join(userRoles, ",")
|
||||
evaluation.Metadata["app_id"] = appID
|
||||
evaluation.Metadata["evaluation_method"] = "hierarchical"
|
||||
|
||||
return evaluation
|
||||
}
|
||||
|
||||
// getUserRoles retrieves user roles (placeholder implementation)
|
||||
func (pm *PermissionManager) getUserRoles(ctx context.Context, userID, appID string) []string {
|
||||
// TODO: Implement actual role retrieval from database
|
||||
// For now, return default roles based on user patterns
|
||||
|
||||
if strings.Contains(userID, "admin") {
|
||||
return []string{"super_admin"}
|
||||
}
|
||||
if strings.Contains(userID, "dev") {
|
||||
return []string{"developer"}
|
||||
}
|
||||
return []string{"viewer"}
|
||||
}
|
||||
|
||||
// hasDirectPermission checks if user has direct permission grant
|
||||
func (pm *PermissionManager) hasDirectPermission(userID, appID, permission string) bool {
|
||||
// TODO: Implement database lookup for direct permission grants
|
||||
return false
|
||||
}
|
||||
|
||||
// roleHasPermission checks if a role has a specific permission
|
||||
func (pm *PermissionManager) roleHasPermission(roleName, permission string) bool {
|
||||
role, exists := pm.hierarchy.roles[roleName]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check direct permissions
|
||||
for _, perm := range role.Permissions {
|
||||
if perm == permission {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if this permission grants the requested one through hierarchy
|
||||
if pm.permissionIncludes(perm, permission) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check inherited roles
|
||||
for _, inheritedRole := range role.Inherits {
|
||||
if pm.roleHasPermission(inheritedRole, permission) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// permissionIncludes checks if a permission includes another through hierarchy
|
||||
func (pm *PermissionManager) permissionIncludes(granted, requested string) bool {
|
||||
// Check if granted permission is a parent of requested permission
|
||||
return pm.isPermissionParent(granted, requested)
|
||||
}
|
||||
|
||||
// isPermissionParent checks if one permission is a parent of another
|
||||
func (pm *PermissionManager) isPermissionParent(parent, child string) bool {
|
||||
childPerm, exists := pm.hierarchy.permissions[child]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Traverse up the hierarchy
|
||||
current := childPerm.Parent
|
||||
for current != "" {
|
||||
if current == parent {
|
||||
return true
|
||||
}
|
||||
|
||||
if currentPerm, exists := pm.hierarchy.permissions[current]; exists {
|
||||
current = currentPerm.Parent
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getInheritedPermissions gets permissions that could grant the requested permission
|
||||
func (pm *PermissionManager) getInheritedPermissions(permission string) []string {
|
||||
var inherited []string
|
||||
|
||||
perm, exists := pm.hierarchy.permissions[permission]
|
||||
if !exists {
|
||||
return inherited
|
||||
}
|
||||
|
||||
// Get all parent permissions
|
||||
current := perm.Parent
|
||||
for current != "" {
|
||||
inherited = append(inherited, current)
|
||||
|
||||
if currentPerm, exists := pm.hierarchy.permissions[current]; exists {
|
||||
current = currentPerm.Parent
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return inherited
|
||||
}
|
||||
|
||||
// countGrantedPermissions counts granted permissions in bulk results
|
||||
func (pm *PermissionManager) countGrantedPermissions(results map[string]*PermissionEvaluation) int {
|
||||
count := 0
|
||||
for _, eval := range results {
|
||||
if eval.Granted {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// GetPermissionHierarchy returns the current permission hierarchy
|
||||
func (pm *PermissionManager) GetPermissionHierarchy() *PermissionHierarchy {
|
||||
return pm.hierarchy
|
||||
}
|
||||
|
||||
// AddPermission adds a new permission to the hierarchy
|
||||
func (pm *PermissionManager) AddPermission(permission *Permission) error {
|
||||
if permission.Name == "" {
|
||||
return errors.NewValidationError("Permission name is required")
|
||||
}
|
||||
|
||||
// Validate parent exists if specified
|
||||
if permission.Parent != "" {
|
||||
if _, exists := pm.hierarchy.permissions[permission.Parent]; !exists {
|
||||
return errors.NewValidationError(fmt.Sprintf("Parent permission '%s' does not exist", permission.Parent))
|
||||
}
|
||||
}
|
||||
|
||||
pm.hierarchy.permissions[permission.Name] = permission
|
||||
pm.hierarchy.buildHierarchy()
|
||||
|
||||
pm.logger.Info("Permission added to hierarchy",
|
||||
zap.String("permission", permission.Name),
|
||||
zap.String("parent", permission.Parent))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRole adds a new role to the system
|
||||
func (pm *PermissionManager) AddRole(role *Role) error {
|
||||
if role.Name == "" {
|
||||
return errors.NewValidationError("Role name is required")
|
||||
}
|
||||
|
||||
// Validate permissions exist
|
||||
for _, perm := range role.Permissions {
|
||||
if _, exists := pm.hierarchy.permissions[perm]; !exists {
|
||||
return errors.NewValidationError(fmt.Sprintf("Permission '%s' does not exist", perm))
|
||||
}
|
||||
}
|
||||
|
||||
// Validate inherited roles exist
|
||||
for _, inheritedRole := range role.Inherits {
|
||||
if _, exists := pm.hierarchy.roles[inheritedRole]; !exists {
|
||||
return errors.NewValidationError(fmt.Sprintf("Inherited role '%s' does not exist", inheritedRole))
|
||||
}
|
||||
}
|
||||
|
||||
pm.hierarchy.roles[role.Name] = role
|
||||
|
||||
pm.logger.Info("Role added to system",
|
||||
zap.String("role", role.Name),
|
||||
zap.Strings("permissions", role.Permissions))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPermissions returns all permissions sorted by hierarchy
|
||||
func (pm *PermissionManager) ListPermissions() []*Permission {
|
||||
permissions := make([]*Permission, 0, len(pm.hierarchy.permissions))
|
||||
|
||||
for _, perm := range pm.hierarchy.permissions {
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
|
||||
// Sort by level and name
|
||||
sort.Slice(permissions, func(i, j int) bool {
|
||||
if permissions[i].Level != permissions[j].Level {
|
||||
return permissions[i].Level < permissions[j].Level
|
||||
}
|
||||
return permissions[i].Name < permissions[j].Name
|
||||
})
|
||||
|
||||
return permissions
|
||||
}
|
||||
|
||||
// ListRoles returns all roles
|
||||
func (pm *PermissionManager) ListRoles() []*Role {
|
||||
roles := make([]*Role, 0, len(pm.hierarchy.roles))
|
||||
|
||||
for _, role := range pm.hierarchy.roles {
|
||||
roles = append(roles, role)
|
||||
}
|
||||
|
||||
// Sort by name
|
||||
sort.Slice(roles, func(i, j int) bool {
|
||||
return roles[i].Name < roles[j].Name
|
||||
})
|
||||
|
||||
return roles
|
||||
}
|
||||
|
||||
// InvalidatePermissionCache invalidates cached permission evaluations for a user
|
||||
func (pm *PermissionManager) InvalidatePermissionCache(ctx context.Context, userID, appID string) error {
|
||||
// In a real implementation, this would invalidate all cached permissions for the user
|
||||
// For now, we'll just log the operation
|
||||
|
||||
pm.logger.Info("Invalidating permission cache",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("app_id", appID))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPermissions returns all permissions sorted by hierarchy (for PermissionHierarchy)
|
||||
func (h *PermissionHierarchy) ListPermissions() []*Permission {
|
||||
permissions := make([]*Permission, 0, len(h.permissions))
|
||||
|
||||
for _, perm := range h.permissions {
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
|
||||
// Sort by level and name
|
||||
sort.Slice(permissions, func(i, j int) bool {
|
||||
if permissions[i].Level != permissions[j].Level {
|
||||
return permissions[i].Level < permissions[j].Level
|
||||
}
|
||||
return permissions[i].Name < permissions[j].Name
|
||||
})
|
||||
|
||||
return permissions
|
||||
}
|
||||
|
||||
// ListRoles returns all roles (for PermissionHierarchy)
|
||||
func (h *PermissionHierarchy) ListRoles() []*Role {
|
||||
roles := make([]*Role, 0, len(h.roles))
|
||||
|
||||
for _, role := range h.roles {
|
||||
roles = append(roles, role)
|
||||
}
|
||||
|
||||
// Sort by name
|
||||
sort.Slice(roles, func(i, j int) bool {
|
||||
return roles[i].Name < roles[j].Name
|
||||
})
|
||||
|
||||
return roles
|
||||
}
|
||||
14
internal/cache/cache.go
vendored
14
internal/cache/cache.go
vendored
@ -153,8 +153,18 @@ type CacheManager struct {
|
||||
func NewCacheManager(config config.ConfigProvider, logger *zap.Logger) *CacheManager {
|
||||
var provider CacheProvider
|
||||
|
||||
// For now, we'll use memory cache. In production, this could be Redis
|
||||
provider = NewMemoryCache(config, logger)
|
||||
// Use Redis if configured, otherwise fall back to memory cache
|
||||
if config.GetBool("REDIS_ENABLED") {
|
||||
redisProvider, err := NewRedisCache(config, logger)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to initialize Redis cache, falling back to memory cache", zap.Error(err))
|
||||
provider = NewMemoryCache(config, logger)
|
||||
} else {
|
||||
provider = redisProvider
|
||||
}
|
||||
} else {
|
||||
provider = NewMemoryCache(config, logger)
|
||||
}
|
||||
|
||||
return &CacheManager{
|
||||
provider: provider,
|
||||
|
||||
191
internal/cache/redis.go
vendored
Normal file
191
internal/cache/redis.go
vendored
Normal file
@ -0,0 +1,191 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// RedisCache implements CacheProvider using Redis
|
||||
type RedisCache struct {
|
||||
client *redis.Client
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRedisCache creates a new Redis cache provider
|
||||
func NewRedisCache(config config.ConfigProvider, logger *zap.Logger) (CacheProvider, error) {
|
||||
// Redis configuration
|
||||
redisAddr := config.GetString("REDIS_ADDR")
|
||||
if redisAddr == "" {
|
||||
redisAddr = "localhost:6379"
|
||||
}
|
||||
|
||||
redisPassword := config.GetString("REDIS_PASSWORD")
|
||||
redisDB := config.GetInt("REDIS_DB")
|
||||
|
||||
// Create Redis client
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: redisAddr,
|
||||
Password: redisPassword,
|
||||
DB: redisDB,
|
||||
PoolSize: config.GetInt("REDIS_POOL_SIZE"),
|
||||
MinIdleConns: config.GetInt("REDIS_MIN_IDLE_CONNS"),
|
||||
MaxRetries: config.GetInt("REDIS_MAX_RETRIES"),
|
||||
DialTimeout: config.GetDuration("REDIS_DIAL_TIMEOUT"),
|
||||
ReadTimeout: config.GetDuration("REDIS_READ_TIMEOUT"),
|
||||
WriteTimeout: config.GetDuration("REDIS_WRITE_TIMEOUT"),
|
||||
})
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
logger.Error("Failed to connect to Redis", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to connect to Redis").WithInternal(err)
|
||||
}
|
||||
|
||||
logger.Info("Connected to Redis successfully", zap.String("addr", redisAddr))
|
||||
|
||||
return &RedisCache{
|
||||
client: client,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get retrieves a value from Redis cache
|
||||
func (r *RedisCache) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
r.logger.Debug("Getting value from Redis cache", zap.String("key", key))
|
||||
|
||||
result, err := r.client.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, errors.NewNotFoundError("cache key")
|
||||
}
|
||||
r.logger.Error("Failed to get value from Redis", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to get cached value").WithInternal(err)
|
||||
}
|
||||
|
||||
return []byte(result), nil
|
||||
}
|
||||
|
||||
// Set stores a value in Redis cache with TTL
|
||||
func (r *RedisCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
r.logger.Debug("Setting value in Redis cache",
|
||||
zap.String("key", key),
|
||||
zap.Duration("ttl", ttl))
|
||||
|
||||
err := r.client.Set(ctx, key, value, ttl).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to set value in Redis", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to cache value").WithInternal(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a value from Redis cache
|
||||
func (r *RedisCache) Delete(ctx context.Context, key string) error {
|
||||
r.logger.Debug("Deleting value from Redis cache", zap.String("key", key))
|
||||
|
||||
err := r.client.Del(ctx, key).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to delete value from Redis", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to delete cached value").WithInternal(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in Redis cache
|
||||
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
count, err := r.client.Exists(ctx, key).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to check key existence in Redis", zap.Error(err))
|
||||
return false, errors.NewInternalError("Failed to check cache key existence").WithInternal(err)
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// Clear removes all values from Redis cache (use with caution)
|
||||
func (r *RedisCache) Clear(ctx context.Context) error {
|
||||
r.logger.Warn("Clearing Redis cache - this will remove ALL cached data")
|
||||
|
||||
err := r.client.FlushDB(ctx).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to clear Redis cache", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to clear cache").WithInternal(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the Redis connection
|
||||
func (r *RedisCache) Close() error {
|
||||
r.logger.Info("Closing Redis connection")
|
||||
return r.client.Close()
|
||||
}
|
||||
|
||||
// SetNX sets a key only if it doesn't exist (Redis-specific operation)
|
||||
func (r *RedisCache) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
r.logger.Debug("Setting value in Redis cache with NX",
|
||||
zap.String("key", key),
|
||||
zap.Duration("ttl", ttl))
|
||||
|
||||
result, err := r.client.SetNX(ctx, key, value, ttl).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to set NX value in Redis", zap.Error(err))
|
||||
return false, errors.NewInternalError("Failed to cache value with NX").WithInternal(err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Expire sets TTL for an existing key
|
||||
func (r *RedisCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
|
||||
r.logger.Debug("Setting TTL for Redis key",
|
||||
zap.String("key", key),
|
||||
zap.Duration("ttl", ttl))
|
||||
|
||||
result, err := r.client.Expire(ctx, key, ttl).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to set TTL in Redis", zap.Error(err))
|
||||
return errors.NewInternalError("Failed to set key TTL").WithInternal(err)
|
||||
}
|
||||
|
||||
if !result {
|
||||
return errors.NewNotFoundError("cache key")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TTL returns the remaining time to live for a key
|
||||
func (r *RedisCache) TTL(ctx context.Context, key string) (time.Duration, error) {
|
||||
ttl, err := r.client.TTL(ctx, key).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to get TTL from Redis", zap.Error(err))
|
||||
return 0, errors.NewInternalError("Failed to get key TTL").WithInternal(err)
|
||||
}
|
||||
|
||||
return ttl, nil
|
||||
}
|
||||
|
||||
// Keys returns all keys matching a pattern
|
||||
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
keys, err := r.client.Keys(ctx, pattern).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to get keys from Redis", zap.Error(err))
|
||||
return nil, errors.NewInternalError("Failed to get cache keys").WithInternal(err)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
@ -117,6 +117,21 @@ func (c *Config) setDefaults() {
|
||||
"INTERNAL_HMAC_KEY": "bootstrap-hmac-key-change-in-production",
|
||||
"METRICS_ENABLED": "false",
|
||||
"METRICS_PORT": "9090",
|
||||
"REDIS_ENABLED": "false",
|
||||
"REDIS_ADDR": "localhost:6379",
|
||||
"REDIS_PASSWORD": "",
|
||||
"REDIS_DB": "0",
|
||||
"REDIS_POOL_SIZE": "10",
|
||||
"REDIS_MIN_IDLE_CONNS": "5",
|
||||
"REDIS_MAX_RETRIES": "3",
|
||||
"REDIS_DIAL_TIMEOUT": "5s",
|
||||
"REDIS_READ_TIMEOUT": "3s",
|
||||
"REDIS_WRITE_TIMEOUT": "3s",
|
||||
"MAX_AUTH_FAILURES": "5",
|
||||
"AUTH_FAILURE_WINDOW": "15m",
|
||||
"IP_BLOCK_DURATION": "1h",
|
||||
"REQUEST_MAX_AGE": "5m",
|
||||
"IP_WHITELIST": "",
|
||||
}
|
||||
|
||||
for key, value := range defaults {
|
||||
|
||||
@ -202,6 +202,11 @@ func NewAuthenticationError(message string) *AppError {
|
||||
return New(ErrUnauthorized, message)
|
||||
}
|
||||
|
||||
// NewConfigurationError creates a configuration error
|
||||
func NewConfigurationError(message string) *AppError {
|
||||
return New(ErrInternal, message)
|
||||
}
|
||||
|
||||
// ErrorResponse represents the JSON error response format
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
|
||||
394
internal/handlers/oauth2.go
Normal file
394
internal/handlers/oauth2.go
Normal file
@ -0,0 +1,394 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
"github.com/kms/api-key-service/internal/services"
|
||||
)
|
||||
|
||||
// OAuth2Handler handles OAuth2/OIDC authentication flows
|
||||
type OAuth2Handler struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
oauth2Provider *auth.OAuth2Provider
|
||||
authService services.AuthenticationService
|
||||
}
|
||||
|
||||
// NewOAuth2Handler creates a new OAuth2 handler
|
||||
func NewOAuth2Handler(
|
||||
config config.ConfigProvider,
|
||||
logger *zap.Logger,
|
||||
authService services.AuthenticationService,
|
||||
) *OAuth2Handler {
|
||||
oauth2Provider := auth.NewOAuth2Provider(config, logger)
|
||||
|
||||
return &OAuth2Handler{
|
||||
config: config,
|
||||
logger: logger,
|
||||
oauth2Provider: oauth2Provider,
|
||||
authService: authService,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthorizeRequest represents the OAuth2 authorization request
|
||||
type AuthorizeRequest struct {
|
||||
RedirectURI string `json:"redirect_uri" validate:"required,url"`
|
||||
State string `json:"state,omitempty"`
|
||||
}
|
||||
|
||||
// AuthorizeResponse represents the OAuth2 authorization response
|
||||
type AuthorizeResponse struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"` // In production, this should be stored securely
|
||||
}
|
||||
|
||||
// CallbackRequest represents the OAuth2 callback request
|
||||
type CallbackRequest struct {
|
||||
Code string `json:"code" validate:"required"`
|
||||
State string `json:"state,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri" validate:"required,url"`
|
||||
CodeVerifier string `json:"code_verifier" validate:"required"`
|
||||
}
|
||||
|
||||
// CallbackResponse represents the OAuth2 callback response
|
||||
type CallbackResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
UserInfo *auth.UserInfo `json:"user_info"`
|
||||
JWTToken string `json:"jwt_token"`
|
||||
}
|
||||
|
||||
// RefreshRequest represents the token refresh request
|
||||
type RefreshRequest struct {
|
||||
RefreshToken string `json:"refresh_token" validate:"required"`
|
||||
}
|
||||
|
||||
// RefreshResponse represents the token refresh response
|
||||
type RefreshResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
JWTToken string `json:"jwt_token"`
|
||||
}
|
||||
|
||||
// RegisterRoutes registers OAuth2 routes
|
||||
func (h *OAuth2Handler) RegisterRoutes(router *mux.Router) {
|
||||
oauth2Router := router.PathPrefix("/oauth2").Subrouter()
|
||||
|
||||
oauth2Router.HandleFunc("/authorize", h.Authorize).Methods("POST")
|
||||
oauth2Router.HandleFunc("/callback", h.Callback).Methods("POST")
|
||||
oauth2Router.HandleFunc("/refresh", h.Refresh).Methods("POST")
|
||||
oauth2Router.HandleFunc("/userinfo", h.GetUserInfo).Methods("GET")
|
||||
}
|
||||
|
||||
// Authorize initiates the OAuth2 authorization flow
|
||||
func (h *OAuth2Handler) Authorize(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
h.logger.Debug("Processing OAuth2 authorization request")
|
||||
|
||||
var req AuthorizeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.logger.Warn("Invalid authorization request", zap.Error(err))
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate state if not provided
|
||||
if req.State == "" {
|
||||
state, err := h.generateState()
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate state", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.State = state
|
||||
}
|
||||
|
||||
// Generate authorization URL
|
||||
authURL, err := h.oauth2Provider.GenerateAuthURL(ctx, req.State, req.RedirectURI)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate authorization URL", zap.Error(err))
|
||||
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
http.Error(w, appErr.Message, appErr.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to generate authorization URL", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// In production, store the code verifier securely (e.g., in session or cache)
|
||||
// For now, we'll return it in the response
|
||||
codeVerifier, err := h.generateCodeVerifier()
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate code verifier", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response := AuthorizeResponse{
|
||||
AuthURL: authURL,
|
||||
State: req.State,
|
||||
CodeVerifier: codeVerifier,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
h.logger.Error("Failed to encode authorization response", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("Authorization URL generated successfully",
|
||||
zap.String("state", req.State),
|
||||
zap.String("redirect_uri", req.RedirectURI))
|
||||
}
|
||||
|
||||
// Callback handles the OAuth2 callback and exchanges code for tokens
|
||||
func (h *OAuth2Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
h.logger.Debug("Processing OAuth2 callback")
|
||||
|
||||
var req CallbackRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.logger.Warn("Invalid callback request", zap.Error(err))
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange authorization code for tokens
|
||||
tokenResp, err := h.oauth2Provider.ExchangeCodeForToken(ctx, req.Code, req.RedirectURI, req.CodeVerifier)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to exchange code for token", zap.Error(err))
|
||||
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
http.Error(w, appErr.Message, appErr.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to exchange authorization code", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user information
|
||||
userInfo, err := h.oauth2Provider.GetUserInfo(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to get user info", zap.Error(err))
|
||||
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
http.Error(w, appErr.Message, appErr.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to get user information", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate internal JWT token for the user
|
||||
jwtToken, err := h.generateInternalJWTToken(ctx, userInfo)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate internal JWT token", zap.Error(err))
|
||||
http.Error(w, "Failed to generate authentication token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response := CallbackResponse{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
UserInfo: userInfo,
|
||||
JWTToken: jwtToken,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
h.logger.Error("Failed to encode callback response", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("OAuth2 callback processed successfully",
|
||||
zap.String("user_id", userInfo.Sub),
|
||||
zap.String("email", userInfo.Email))
|
||||
}
|
||||
|
||||
// Refresh refreshes an access token using refresh token
|
||||
func (h *OAuth2Handler) Refresh(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
h.logger.Debug("Processing token refresh request")
|
||||
|
||||
var req RefreshRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.logger.Warn("Invalid refresh request", zap.Error(err))
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh the access token
|
||||
tokenResp, err := h.oauth2Provider.RefreshAccessToken(ctx, req.RefreshToken)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to refresh access token", zap.Error(err))
|
||||
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
http.Error(w, appErr.Message, appErr.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to refresh access token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get updated user information
|
||||
userInfo, err := h.oauth2Provider.GetUserInfo(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to get user info during refresh", zap.Error(err))
|
||||
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
http.Error(w, appErr.Message, appErr.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to get user information", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate new internal JWT token
|
||||
jwtToken, err := h.generateInternalJWTToken(ctx, userInfo)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate internal JWT token during refresh", zap.Error(err))
|
||||
http.Error(w, "Failed to generate authentication token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response := RefreshResponse{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
JWTToken: jwtToken,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
h.logger.Error("Failed to encode refresh response", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("Token refresh completed successfully",
|
||||
zap.String("user_id", userInfo.Sub))
|
||||
}
|
||||
|
||||
// GetUserInfo retrieves user information from the current session
|
||||
func (h *OAuth2Handler) GetUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
h.logger.Debug("Processing user info request")
|
||||
|
||||
// Extract JWT token from Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
http.Error(w, "Authorization header required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove "Bearer " prefix
|
||||
tokenString := authHeader
|
||||
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||||
tokenString = authHeader[7:]
|
||||
}
|
||||
|
||||
// Validate JWT token
|
||||
authContext, err := h.authService.ValidateJWTToken(ctx, tokenString)
|
||||
if err != nil {
|
||||
h.logger.Warn("Invalid JWT token in user info request", zap.Error(err))
|
||||
http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Return user information from JWT claims
|
||||
userInfo := map[string]interface{}{
|
||||
"sub": authContext.UserID,
|
||||
"email": authContext.Claims["email"],
|
||||
"name": authContext.Claims["name"],
|
||||
"permissions": authContext.Permissions,
|
||||
"app_id": authContext.AppID,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(userInfo); err != nil {
|
||||
h.logger.Error("Failed to encode user info response", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("User info request completed successfully",
|
||||
zap.String("user_id", authContext.UserID))
|
||||
}
|
||||
|
||||
// generateState generates a random state parameter for OAuth2
|
||||
func (h *OAuth2Handler) generateState() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// generateCodeVerifier generates a PKCE code verifier
|
||||
func (h *OAuth2Handler) generateCodeVerifier() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// generateInternalJWTToken generates an internal JWT token for authenticated users
|
||||
func (h *OAuth2Handler) generateInternalJWTToken(ctx context.Context, userInfo *auth.UserInfo) (string, error) {
|
||||
// Create user token with information from OAuth2 provider
|
||||
userToken := &domain.UserToken{
|
||||
AppID: h.config.GetString("INTERNAL_APP_ID"),
|
||||
UserID: userInfo.Sub,
|
||||
Permissions: []string{"read", "write"}, // Default permissions, should be based on user roles
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour), // 24 hour expiration
|
||||
MaxValidAt: time.Now().Add(7 * 24 * time.Hour), // 7 days max validity
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"sub": userInfo.Sub,
|
||||
"email": userInfo.Email,
|
||||
"name": userInfo.Name,
|
||||
"email_verified": func() string {
|
||||
if userInfo.EmailVerified {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}(),
|
||||
},
|
||||
}
|
||||
|
||||
// Generate JWT token using authentication service
|
||||
return h.authService.GenerateJWTToken(ctx, userToken)
|
||||
}
|
||||
423
internal/middleware/security.go
Normal file
423
internal/middleware/security.go
Normal file
@ -0,0 +1,423 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/kms/api-key-service/internal/cache"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/errors"
|
||||
)
|
||||
|
||||
// SecurityMiddleware provides various security features
|
||||
type SecurityMiddleware struct {
|
||||
config config.ConfigProvider
|
||||
logger *zap.Logger
|
||||
cacheManager *cache.CacheManager
|
||||
rateLimiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSecurityMiddleware creates a new security middleware
|
||||
func NewSecurityMiddleware(config config.ConfigProvider, logger *zap.Logger) *SecurityMiddleware {
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
return &SecurityMiddleware{
|
||||
config: config,
|
||||
logger: logger,
|
||||
cacheManager: cacheManager,
|
||||
rateLimiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitMiddleware implements per-IP rate limiting
|
||||
func (s *SecurityMiddleware) RateLimitMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.config.GetBool("RATE_LIMIT_ENABLED") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Get client IP
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Get or create rate limiter for this IP
|
||||
limiter := s.getRateLimiter(clientIP)
|
||||
|
||||
// Check if request is allowed
|
||||
if !limiter.Allow() {
|
||||
s.logger.Warn("Rate limit exceeded",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
// Track rate limit violations
|
||||
s.trackRateLimitViolation(clientIP)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte(`{"error":"rate_limit_exceeded","message":"Too many requests"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// BruteForceProtectionMiddleware implements brute force protection
|
||||
func (s *SecurityMiddleware) BruteForceProtectionMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Check if IP is temporarily blocked
|
||||
if s.isIPBlocked(clientIP) {
|
||||
s.logger.Warn("Blocked IP attempted access",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"ip_blocked","message":"IP temporarily blocked due to suspicious activity"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// IPWhitelistMiddleware implements IP whitelisting
|
||||
func (s *SecurityMiddleware) IPWhitelistMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
whitelist := s.config.GetStringSlice("IP_WHITELIST")
|
||||
if len(whitelist) == 0 {
|
||||
// No whitelist configured, allow all
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := s.getClientIP(r)
|
||||
|
||||
// Check if IP is in whitelist
|
||||
if !s.isIPInList(clientIP, whitelist) {
|
||||
s.logger.Warn("Non-whitelisted IP attempted access",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", r.URL.Path))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"ip_not_whitelisted","message":"IP not in whitelist"}`))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// SecurityHeadersMiddleware adds security headers
|
||||
func (s *SecurityMiddleware) SecurityHeadersMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add security headers
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'self'")
|
||||
|
||||
// Add HSTS header for HTTPS
|
||||
if r.TLS != nil {
|
||||
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// AuthenticationFailureTracker tracks authentication failures for brute force protection
|
||||
func (s *SecurityMiddleware) TrackAuthenticationFailure(clientIP, userID string) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Track failures by IP
|
||||
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
|
||||
s.incrementFailureCount(ctx, ipKey)
|
||||
|
||||
// Track failures by user ID if provided
|
||||
if userID != "" {
|
||||
userKey := cache.CacheKey("auth_failures_user", userID)
|
||||
s.incrementFailureCount(ctx, userKey)
|
||||
}
|
||||
|
||||
// Check if we should block the IP
|
||||
s.checkAndBlockIP(clientIP)
|
||||
}
|
||||
|
||||
// ClearAuthenticationFailures clears failure count on successful authentication
|
||||
func (s *SecurityMiddleware) ClearAuthenticationFailures(clientIP, userID string) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Clear failures by IP
|
||||
ipKey := cache.CacheKey("auth_failures_ip", clientIP)
|
||||
s.cacheManager.Delete(ctx, ipKey)
|
||||
|
||||
// Clear failures by user ID if provided
|
||||
if userID != "" {
|
||||
userKey := cache.CacheKey("auth_failures_user", userID)
|
||||
s.cacheManager.Delete(ctx, userKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
|
||||
func (s *SecurityMiddleware) getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header first
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff != "" {
|
||||
// Take the first IP in the chain
|
||||
ips := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
xri := r.Header.Get("X-Real-IP")
|
||||
if xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) getRateLimiter(clientIP string) *rate.Limiter {
|
||||
s.mu.RLock()
|
||||
limiter, exists := s.rateLimiters[clientIP]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Create new rate limiter
|
||||
rps := s.config.GetInt("RATE_LIMIT_RPS")
|
||||
if rps <= 0 {
|
||||
rps = 100 // Default
|
||||
}
|
||||
|
||||
burst := s.config.GetInt("RATE_LIMIT_BURST")
|
||||
if burst <= 0 {
|
||||
burst = 200 // Default
|
||||
}
|
||||
|
||||
limiter = rate.NewLimiter(rate.Limit(rps), burst)
|
||||
|
||||
s.mu.Lock()
|
||||
s.rateLimiters[clientIP] = limiter
|
||||
s.mu.Unlock()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) trackRateLimitViolation(clientIP string) {
|
||||
ctx := context.Background()
|
||||
key := cache.CacheKey("rate_limit_violations", clientIP)
|
||||
s.incrementFailureCount(ctx, key)
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) isIPBlocked(clientIP string) bool {
|
||||
ctx := context.Background()
|
||||
key := cache.CacheKey("blocked_ips", clientIP)
|
||||
|
||||
exists, err := s.cacheManager.Exists(ctx, key)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to check IP block status",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) isIPInList(clientIP string, ipList []string) bool {
|
||||
for _, allowedIP := range ipList {
|
||||
allowedIP = strings.TrimSpace(allowedIP)
|
||||
|
||||
// Support CIDR notation
|
||||
if strings.Contains(allowedIP, "/") {
|
||||
_, network, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
s.logger.Warn("Invalid CIDR in IP list", zap.String("cidr", allowedIP))
|
||||
continue
|
||||
}
|
||||
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip != nil && network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
// Exact IP match
|
||||
if clientIP == allowedIP {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) incrementFailureCount(ctx context.Context, key string) {
|
||||
// Get current count
|
||||
var count int
|
||||
err := s.cacheManager.GetJSON(ctx, key, &count)
|
||||
if err != nil {
|
||||
// Key doesn't exist, start with 0
|
||||
count = 0
|
||||
}
|
||||
|
||||
count++
|
||||
|
||||
// Store updated count with TTL
|
||||
ttl := s.config.GetDuration("AUTH_FAILURE_WINDOW")
|
||||
if ttl <= 0 {
|
||||
ttl = 15 * time.Minute // Default window
|
||||
}
|
||||
|
||||
s.cacheManager.SetJSON(ctx, key, count, ttl)
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) checkAndBlockIP(clientIP string) {
|
||||
ctx := context.Background()
|
||||
key := cache.CacheKey("auth_failures_ip", clientIP)
|
||||
|
||||
var count int
|
||||
err := s.cacheManager.GetJSON(ctx, key, &count)
|
||||
if err != nil {
|
||||
return // No failures recorded
|
||||
}
|
||||
|
||||
maxFailures := s.config.GetInt("MAX_AUTH_FAILURES")
|
||||
if maxFailures <= 0 {
|
||||
maxFailures = 5 // Default
|
||||
}
|
||||
|
||||
if count >= maxFailures {
|
||||
// Block the IP
|
||||
blockKey := cache.CacheKey("blocked_ips", clientIP)
|
||||
blockDuration := s.config.GetDuration("IP_BLOCK_DURATION")
|
||||
if blockDuration <= 0 {
|
||||
blockDuration = 1 * time.Hour // Default
|
||||
}
|
||||
|
||||
blockInfo := map[string]interface{}{
|
||||
"blocked_at": time.Now().Unix(),
|
||||
"failure_count": count,
|
||||
"reason": "excessive_auth_failures",
|
||||
}
|
||||
|
||||
s.cacheManager.SetJSON(ctx, blockKey, blockInfo, blockDuration)
|
||||
|
||||
s.logger.Warn("IP blocked due to excessive authentication failures",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.Int("failure_count", count),
|
||||
zap.Duration("block_duration", blockDuration))
|
||||
}
|
||||
}
|
||||
|
||||
// RequestSignatureMiddleware validates request signatures (for API key requests)
|
||||
func (s *SecurityMiddleware) RequestSignatureMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only validate signatures for certain endpoints
|
||||
if !s.shouldValidateSignature(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
signature := r.Header.Get("X-Signature")
|
||||
timestamp := r.Header.Get("X-Timestamp")
|
||||
|
||||
if signature == "" || timestamp == "" {
|
||||
s.logger.Warn("Missing signature headers",
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("client_ip", s.getClientIP(r)))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"missing_signature","message":"Request signature required"}`))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate timestamp (prevent replay attacks)
|
||||
if !s.isTimestampValid(timestamp) {
|
||||
s.logger.Warn("Invalid timestamp in request",
|
||||
zap.String("timestamp", timestamp),
|
||||
zap.String("client_ip", s.getClientIP(r)))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"invalid_timestamp","message":"Request timestamp is invalid or too old"}`))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Implement actual signature validation
|
||||
// This would involve validating the HMAC signature using the client's secret
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) shouldValidateSignature(r *http.Request) bool {
|
||||
// Define which endpoints require signature validation
|
||||
signatureRequiredPaths := []string{
|
||||
"/api/v1/tokens",
|
||||
"/api/v1/applications",
|
||||
}
|
||||
|
||||
for _, path := range signatureRequiredPaths {
|
||||
if strings.HasPrefix(r.URL.Path, path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SecurityMiddleware) isTimestampValid(timestampStr string) bool {
|
||||
// Parse timestamp
|
||||
timestamp, err := time.Parse(time.RFC3339, timestampStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if timestamp is within acceptable window
|
||||
now := time.Now()
|
||||
maxAge := s.config.GetDuration("REQUEST_MAX_AGE")
|
||||
if maxAge <= 0 {
|
||||
maxAge = 5 * time.Minute // Default
|
||||
}
|
||||
|
||||
return now.Sub(timestamp) <= maxAge && timestamp.Before(now.Add(1*time.Minute))
|
||||
}
|
||||
|
||||
// GetSecurityMetrics returns security-related metrics
|
||||
func (s *SecurityMiddleware) GetSecurityMetrics() map[string]interface{} {
|
||||
ctx := context.Background()
|
||||
|
||||
// This is a simplified version - in production you'd want more comprehensive metrics
|
||||
metrics := map[string]interface{}{
|
||||
"active_rate_limiters": len(s.rateLimiters),
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
|
||||
// Count blocked IPs (this is expensive, so you might want to cache this)
|
||||
// For now, we'll just return the basic metrics
|
||||
|
||||
return metrics
|
||||
}
|
||||
@ -14,231 +14,7 @@ import (
|
||||
"github.com/kms/api-key-service/internal/services"
|
||||
)
|
||||
|
||||
// MockConfig implements ConfigProvider for testing
|
||||
type MockConfig struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func NewMockConfig() *MockConfig {
|
||||
return &MockConfig{
|
||||
values: map[string]string{
|
||||
"JWT_SECRET": "test-jwt-secret-for-testing-only",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetString(key string) string {
|
||||
return m.values[key]
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetInt(key string) int { return 0 }
|
||||
func (m *MockConfig) GetBool(key string) bool { return false }
|
||||
func (m *MockConfig) GetDuration(key string) time.Duration { return 0 }
|
||||
func (m *MockConfig) GetStringSlice(key string) []string { return nil }
|
||||
func (m *MockConfig) IsSet(key string) bool { return m.values[key] != "" }
|
||||
func (m *MockConfig) Validate() error { return nil }
|
||||
func (m *MockConfig) GetDatabaseDSN() string { return "" }
|
||||
func (m *MockConfig) GetServerAddress() string { return "" }
|
||||
func (m *MockConfig) GetMetricsAddress() string { return "" }
|
||||
func (m *MockConfig) GetJWTSecret() string { return m.GetString("JWT_SECRET") }
|
||||
func (m *MockConfig) IsDevelopment() bool { return true }
|
||||
func (m *MockConfig) IsProduction() bool { return false }
|
||||
|
||||
func TestJWTManager_GenerateToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, tokenString)
|
||||
|
||||
// Verify the token can be validated
|
||||
claims, err := jwtManager.ValidateToken(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
assert.Equal(t, userToken.Permissions, claims.Permissions)
|
||||
assert.Equal(t, userToken.TokenType, claims.TokenType)
|
||||
assert.Equal(t, userToken.Claims, claims.Claims)
|
||||
}
|
||||
|
||||
func TestJWTManager_ValidateToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test valid token
|
||||
claims, err := jwtManager.ValidateToken(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
|
||||
// Test invalid token
|
||||
_, err = jwtManager.ValidateToken("invalid-token")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test empty token
|
||||
_, err = jwtManager.ValidateToken("")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestJWTManager_ExpiredToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
// Create an expired token
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(-time.Hour), // Expired 1 hour ago
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validation should fail for expired token
|
||||
_, err = jwtManager.ValidateToken(tokenString)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestJWTManager_MaxValidAtExpired(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
// Create a token that's past max valid time
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(-time.Hour), // Max valid time expired
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validation should fail for token past max valid time
|
||||
_, err = jwtManager.ValidateToken(tokenString)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestJWTManager_RefreshToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
originalToken, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Refresh the token
|
||||
newExpiration := time.Now().Add(2 * time.Hour)
|
||||
refreshedToken, err := jwtManager.RefreshToken(originalToken, newExpiration)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, refreshedToken)
|
||||
assert.NotEqual(t, originalToken, refreshedToken)
|
||||
|
||||
// Validate the refreshed token
|
||||
claims, err := jwtManager.ValidateToken(refreshedToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
}
|
||||
|
||||
func TestJWTManager_ExtractClaims(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(-time.Hour), // Expired token
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract claims from expired token (should work)
|
||||
claims, err := jwtManager.ExtractClaims(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
}
|
||||
|
||||
func TestJWTManager_GetTokenInfo(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
AppID: "test-app",
|
||||
UserID: "test-user",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
info := jwtManager.GetTokenInfo(tokenString)
|
||||
assert.Equal(t, userToken.UserID, info["user_id"])
|
||||
assert.Equal(t, userToken.AppID, info["app_id"])
|
||||
assert.Equal(t, userToken.Permissions, info["permissions"])
|
||||
assert.Equal(t, userToken.TokenType, info["token_type"])
|
||||
}
|
||||
|
||||
func TestAuthenticationService_ValidateJWTToken(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
@ -330,7 +106,8 @@ func TestAuthenticationService_RefreshJWTToken(t *testing.T) {
|
||||
|
||||
func TestJWTManager_InvalidSecret(t *testing.T) {
|
||||
// Test with empty JWT secret
|
||||
config := &MockConfig{values: map[string]string{"JWT_SECRET": ""}}
|
||||
config := NewTestConfig()
|
||||
config.values["JWT_SECRET"] = ""
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(config, logger)
|
||||
|
||||
|
||||
@ -12,48 +12,6 @@ import (
|
||||
"github.com/kms/api-key-service/internal/cache"
|
||||
)
|
||||
|
||||
// MockConfig implements ConfigProvider for testing
|
||||
type MockConfig struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func NewMockConfig() *MockConfig {
|
||||
return &MockConfig{
|
||||
values: map[string]string{
|
||||
"CACHE_ENABLED": "true",
|
||||
"CACHE_TTL": "1h",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetString(key string) string {
|
||||
return m.values[key]
|
||||
}
|
||||
|
||||
func (m *MockConfig) GetInt(key string) int { return 0 }
|
||||
func (m *MockConfig) GetBool(key string) bool {
|
||||
if key == "CACHE_ENABLED" {
|
||||
return m.values[key] == "true"
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (m *MockConfig) GetDuration(key string) time.Duration {
|
||||
if key == "CACHE_TTL" {
|
||||
if d, err := time.ParseDuration(m.values[key]); err == nil {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
func (m *MockConfig) GetStringSlice(key string) []string { return nil }
|
||||
func (m *MockConfig) IsSet(key string) bool { return m.values[key] != "" }
|
||||
func (m *MockConfig) Validate() error { return nil }
|
||||
func (m *MockConfig) GetDatabaseDSN() string { return "" }
|
||||
func (m *MockConfig) GetServerAddress() string { return "" }
|
||||
func (m *MockConfig) GetMetricsAddress() string { return "" }
|
||||
func (m *MockConfig) GetJWTSecret() string { return m.GetString("JWT_SECRET") }
|
||||
func (m *MockConfig) IsDevelopment() bool { return true }
|
||||
func (m *MockConfig) IsProduction() bool { return false }
|
||||
|
||||
func TestMemoryCache_SetAndGet(t *testing.T) {
|
||||
config := NewMockConfig()
|
||||
@ -315,12 +273,9 @@ func TestCacheKeyPrefixes(t *testing.T) {
|
||||
|
||||
func TestCacheManager_ConfigMethods(t *testing.T) {
|
||||
// Create mock config with cache settings
|
||||
config := &MockConfig{
|
||||
values: map[string]string{
|
||||
"CACHE_ENABLED": "true",
|
||||
"CACHE_TTL": "1h",
|
||||
},
|
||||
}
|
||||
config := NewMockConfig()
|
||||
config.values["CACHE_ENABLED"] = "true"
|
||||
config.values["CACHE_TTL"] = "1h"
|
||||
logger := zap.NewNop()
|
||||
cacheManager := cache.NewCacheManager(config, logger)
|
||||
defer cacheManager.Close()
|
||||
|
||||
382
test/jwt_test.go
Normal file
382
test/jwt_test.go
Normal file
@ -0,0 +1,382 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
)
|
||||
|
||||
func TestJWTManager_GenerateToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"department": "engineering",
|
||||
"role": "developer",
|
||||
},
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, tokenString)
|
||||
|
||||
// Verify token structure (should have 3 parts separated by dots)
|
||||
parts := len(tokenString)
|
||||
assert.Greater(t, parts, 100) // JWT tokens are typically longer than 100 chars
|
||||
}
|
||||
|
||||
func TestJWTManager_ValidateToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"department": "engineering",
|
||||
},
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate token
|
||||
claims, err := jwtManager.ValidateToken(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
assert.Equal(t, userToken.Permissions, claims.Permissions)
|
||||
assert.Equal(t, userToken.TokenType, claims.TokenType)
|
||||
assert.Equal(t, userToken.Claims, claims.Claims)
|
||||
}
|
||||
|
||||
func TestJWTManager_ValidateToken_InvalidToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
// Test with invalid token
|
||||
_, err := jwtManager.ValidateToken("invalid.token.here")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid token")
|
||||
}
|
||||
|
||||
func TestJWTManager_ValidateToken_ExpiredToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired 1 hour ago
|
||||
MaxValidAt: time.Now().Add(-30 * time.Minute), // Max valid also expired
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate token (this should work even with past dates)
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate token (this should fail due to expiration)
|
||||
_, err = jwtManager.ValidateToken(tokenString)
|
||||
assert.Error(t, err)
|
||||
// The error could be either JWT expiration or our custom max valid check
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "expired beyond maximum validity") ||
|
||||
strings.Contains(err.Error(), "token is expired"),
|
||||
"Expected expiration error, got: %s", err.Error())
|
||||
}
|
||||
|
||||
func TestJWTManager_RefreshToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate original token
|
||||
originalToken, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Refresh token with new expiration
|
||||
newExpiration := time.Now().Add(2 * time.Hour)
|
||||
refreshedToken, err := jwtManager.RefreshToken(originalToken, newExpiration)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, refreshedToken)
|
||||
assert.NotEqual(t, originalToken, refreshedToken)
|
||||
|
||||
// Validate refreshed token
|
||||
claims, err := jwtManager.ValidateToken(refreshedToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
}
|
||||
|
||||
func TestJWTManager_RefreshToken_ExpiredMaxValid(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(-30 * time.Minute), // Max valid expired
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to refresh (should fail due to max valid expiration)
|
||||
newExpiration := time.Now().Add(2 * time.Hour)
|
||||
_, err = jwtManager.RefreshToken(tokenString, newExpiration)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expired beyond maximum validity")
|
||||
}
|
||||
|
||||
func TestJWTManager_ExtractClaims(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired token
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate expired token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract claims (should work even for expired tokens)
|
||||
claims, err := jwtManager.ExtractClaims(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userToken.UserID, claims.UserID)
|
||||
assert.Equal(t, userToken.AppID, claims.AppID)
|
||||
assert.Equal(t, userToken.Permissions, claims.Permissions)
|
||||
}
|
||||
|
||||
func TestJWTManager_RevokeToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Revoke token
|
||||
err = jwtManager.RevokeToken(tokenString)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check if token is revoked
|
||||
revoked, err := jwtManager.IsTokenRevoked(tokenString)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, revoked)
|
||||
}
|
||||
|
||||
func TestJWTManager_RevokeToken_AlreadyExpired(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now().Add(-2 * time.Hour),
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour), // Already expired
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate expired token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Revoke expired token (should succeed but not add to blacklist)
|
||||
err = jwtManager.RevokeToken(tokenString)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check if token is revoked (should be false since it was already expired)
|
||||
revoked, err := jwtManager.IsTokenRevoked(tokenString)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, revoked)
|
||||
}
|
||||
|
||||
func TestJWTManager_IsTokenRevoked_NotRevoked(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check if token is revoked (should be false)
|
||||
revoked, err := jwtManager.IsTokenRevoked(tokenString)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, revoked)
|
||||
}
|
||||
|
||||
func TestJWTManager_GetTokenInfo(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
Claims: map[string]string{
|
||||
"department": "engineering",
|
||||
},
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get token info
|
||||
info := jwtManager.GetTokenInfo(tokenString)
|
||||
assert.Equal(t, userToken.UserID, info["user_id"])
|
||||
assert.Equal(t, userToken.AppID, info["app_id"])
|
||||
assert.Equal(t, userToken.Permissions, info["permissions"])
|
||||
assert.Equal(t, userToken.TokenType, info["token_type"])
|
||||
assert.NotNil(t, info["issued_at"])
|
||||
assert.NotNil(t, info["expires_at"])
|
||||
assert.NotNil(t, info["max_valid_at"])
|
||||
assert.NotNil(t, info["jti"])
|
||||
}
|
||||
|
||||
func TestJWTManager_GetTokenInfo_InvalidToken(t *testing.T) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
// Get info for invalid token
|
||||
info := jwtManager.GetTokenInfo("invalid.token.here")
|
||||
assert.Contains(t, info["error"], "Invalid token format")
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkJWTManager_GenerateToken(b *testing.B) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := jwtManager.GenerateToken(userToken)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkJWTManager_ValidateToken(b *testing.B) {
|
||||
cfg := config.NewConfig()
|
||||
logger := zap.NewNop()
|
||||
jwtManager := auth.NewJWTManager(cfg, logger)
|
||||
|
||||
userToken := &domain.UserToken{
|
||||
UserID: "test-user-123",
|
||||
AppID: "test-app-456",
|
||||
Permissions: []string{"read", "write"},
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
MaxValidAt: time.Now().Add(24 * time.Hour),
|
||||
TokenType: domain.TokenTypeUser,
|
||||
}
|
||||
|
||||
tokenString, err := jwtManager.GenerateToken(userToken)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := jwtManager.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
552
test/oauth2_test.go
Normal file
552
test/oauth2_test.go
Normal file
@ -0,0 +1,552 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
)
|
||||
|
||||
func TestOAuth2Provider_GetDiscoveryDocument(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
providerURL string
|
||||
mockResponse string
|
||||
mockStatusCode int
|
||||
expectError bool
|
||||
expectedIssuer string
|
||||
}{
|
||||
{
|
||||
name: "successful discovery",
|
||||
providerURL: "https://example.com",
|
||||
mockResponse: `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo",
|
||||
"jwks_uri": "https://example.com/jwks"
|
||||
}`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: false,
|
||||
expectedIssuer: "https://example.com",
|
||||
},
|
||||
{
|
||||
name: "missing provider URL",
|
||||
providerURL: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid response status",
|
||||
providerURL: "https://example.com",
|
||||
mockResponse: `{"error": "not found"}`,
|
||||
mockStatusCode: http.StatusNotFound,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON response",
|
||||
providerURL: "https://example.com",
|
||||
mockResponse: `invalid json`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock server if needed
|
||||
var server *httptest.Server
|
||||
if tt.providerURL != "" && !tt.expectError {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/.well-known/openid_configuration", r.URL.Path)
|
||||
w.WriteHeader(tt.mockStatusCode)
|
||||
w.Write([]byte(tt.mockResponse))
|
||||
}))
|
||||
defer server.Close()
|
||||
tt.providerURL = server.URL
|
||||
}
|
||||
|
||||
// Create config mock
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = tt.providerURL
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
discovery, err := provider.GetDiscoveryDocument(ctx)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, discovery)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, discovery)
|
||||
assert.Equal(t, tt.expectedIssuer, discovery.Issuer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Provider_GenerateAuthURL(t *testing.T) {
|
||||
// Create mock discovery server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
state string
|
||||
redirectURI string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful URL generation",
|
||||
clientID: "test-client-id",
|
||||
state: "test-state",
|
||||
redirectURI: "https://app.example.com/callback",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
clientID: "",
|
||||
state: "test-state",
|
||||
redirectURI: "https://app.example.com/callback",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = server.URL
|
||||
configMock.values["SSO_CLIENT_ID"] = tt.clientID
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
authURL, err := provider.GenerateAuthURL(ctx, tt.state, tt.redirectURI)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, authURL)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, authURL)
|
||||
assert.Contains(t, authURL, "https://example.com/auth")
|
||||
assert.Contains(t, authURL, "client_id="+tt.clientID)
|
||||
assert.Contains(t, authURL, "state="+tt.state)
|
||||
assert.Contains(t, authURL, "redirect_uri=")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Provider_ExchangeCodeForToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
redirectURI string
|
||||
codeVerifier string
|
||||
clientID string
|
||||
clientSecret string
|
||||
mockResponse string
|
||||
mockStatusCode int
|
||||
expectError bool
|
||||
expectedToken string
|
||||
}{
|
||||
{
|
||||
name: "successful token exchange",
|
||||
code: "test-code",
|
||||
redirectURI: "https://app.example.com/callback",
|
||||
codeVerifier: "test-verifier",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
mockResponse: `{
|
||||
"access_token": "test-access-token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "test-refresh-token"
|
||||
}`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: false,
|
||||
expectedToken: "test-access-token",
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
code: "test-code",
|
||||
redirectURI: "https://app.example.com/callback",
|
||||
codeVerifier: "test-verifier",
|
||||
clientID: "",
|
||||
clientSecret: "test-client-secret",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "token endpoint error",
|
||||
code: "test-code",
|
||||
redirectURI: "https://app.example.com/callback",
|
||||
codeVerifier: "test-verifier",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
mockResponse: `{"error": "invalid_grant"}`,
|
||||
mockStatusCode: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock servers
|
||||
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer discoveryServer.Close()
|
||||
|
||||
var tokenServer *httptest.Server
|
||||
if !tt.expectError {
|
||||
tokenServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
|
||||
|
||||
w.WriteHeader(tt.mockStatusCode)
|
||||
w.Write([]byte(tt.mockResponse))
|
||||
}))
|
||||
defer tokenServer.Close()
|
||||
|
||||
// Update discovery server to return the token server URL
|
||||
discoveryServer.Close()
|
||||
discoveryServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "` + tokenServer.URL + `",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
}
|
||||
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
|
||||
configMock.values["SSO_CLIENT_ID"] = tt.clientID
|
||||
configMock.values["SSO_CLIENT_SECRET"] = tt.clientSecret
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
tokenResp, err := provider.ExchangeCodeForToken(ctx, tt.code, tt.redirectURI, tt.codeVerifier)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tokenResp)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tokenResp)
|
||||
assert.Equal(t, tt.expectedToken, tokenResp.AccessToken)
|
||||
assert.Equal(t, "Bearer", tokenResp.TokenType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Provider_GetUserInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
mockResponse string
|
||||
mockStatusCode int
|
||||
expectError bool
|
||||
expectedSub string
|
||||
expectedEmail string
|
||||
}{
|
||||
{
|
||||
name: "successful user info retrieval",
|
||||
accessToken: "test-access-token",
|
||||
mockResponse: `{
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
"name": "Test User",
|
||||
"email_verified": true
|
||||
}`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: false,
|
||||
expectedSub: "user123",
|
||||
expectedEmail: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "unauthorized access token",
|
||||
accessToken: "invalid-token",
|
||||
mockResponse: `{"error": "invalid_token"}`,
|
||||
mockStatusCode: http.StatusUnauthorized,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON response",
|
||||
accessToken: "test-access-token",
|
||||
mockResponse: `invalid json`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock servers
|
||||
userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
assert.Equal(t, "Bearer "+tt.accessToken, r.Header.Get("Authorization"))
|
||||
|
||||
w.WriteHeader(tt.mockStatusCode)
|
||||
w.Write([]byte(tt.mockResponse))
|
||||
}))
|
||||
defer userInfoServer.Close()
|
||||
|
||||
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "` + userInfoServer.URL + `"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer discoveryServer.Close()
|
||||
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
userInfo, err := provider.GetUserInfo(ctx, tt.accessToken)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, userInfo)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, userInfo)
|
||||
assert.Equal(t, tt.expectedSub, userInfo.Sub)
|
||||
assert.Equal(t, tt.expectedEmail, userInfo.Email)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Provider_ValidateIDToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
idToken string
|
||||
expectError bool
|
||||
expectedSub string
|
||||
}{
|
||||
{
|
||||
name: "valid ID token",
|
||||
// This is a mock JWT token with payload: {"sub": "user123", "email": "user@example.com", "name": "Test User"}
|
||||
idToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIiwiZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwibmFtZSI6IlRlc3QgVXNlciJ9.invalid-signature",
|
||||
expectError: false,
|
||||
expectedSub: "user123",
|
||||
},
|
||||
{
|
||||
name: "invalid token format",
|
||||
idToken: "invalid.token",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
idToken: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configMock := NewMockConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
authContext, err := provider.ValidateIDToken(ctx, tt.idToken)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, authContext)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, authContext)
|
||||
assert.Equal(t, tt.expectedSub, authContext.UserID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Provider_RefreshAccessToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
clientID string
|
||||
clientSecret string
|
||||
mockResponse string
|
||||
mockStatusCode int
|
||||
expectError bool
|
||||
expectedToken string
|
||||
}{
|
||||
{
|
||||
name: "successful token refresh",
|
||||
refreshToken: "test-refresh-token",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
mockResponse: `{
|
||||
"access_token": "new-access-token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "new-refresh-token"
|
||||
}`,
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectError: false,
|
||||
expectedToken: "new-access-token",
|
||||
},
|
||||
{
|
||||
name: "invalid refresh token",
|
||||
refreshToken: "invalid-refresh-token",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
mockResponse: `{"error": "invalid_grant"}`,
|
||||
mockStatusCode: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock servers
|
||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
|
||||
|
||||
w.WriteHeader(tt.mockStatusCode)
|
||||
w.Write([]byte(tt.mockResponse))
|
||||
}))
|
||||
defer tokenServer.Close()
|
||||
|
||||
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "` + tokenServer.URL + `",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer discoveryServer.Close()
|
||||
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
|
||||
configMock.values["SSO_CLIENT_ID"] = tt.clientID
|
||||
configMock.values["SSO_CLIENT_SECRET"] = tt.clientSecret
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
tokenResp, err := provider.RefreshAccessToken(ctx, tt.refreshToken)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, tokenResp)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tokenResp)
|
||||
assert.Equal(t, tt.expectedToken, tokenResp.AccessToken)
|
||||
assert.Equal(t, "Bearer", tokenResp.TokenType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for OAuth2 operations
|
||||
func BenchmarkOAuth2Provider_GetDiscoveryDocument(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = server.URL
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.GetDiscoveryDocument(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOAuth2Provider_GenerateAuthURL(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(response))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
configMock := NewMockConfig()
|
||||
configMock.values["SSO_PROVIDER_URL"] = server.URL
|
||||
configMock.values["SSO_CLIENT_ID"] = "test-client-id"
|
||||
|
||||
logger := zap.NewNop()
|
||||
provider := auth.NewOAuth2Provider(configMock, logger)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.GenerateAuthURL(ctx, "test-state", "https://app.example.com/callback")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
594
test/permissions_test.go
Normal file
594
test/permissions_test.go
Normal file
@ -0,0 +1,594 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/auth"
|
||||
)
|
||||
|
||||
func TestPermissionHierarchy_InitializeDefaultPermissions(t *testing.T) {
|
||||
hierarchy := auth.NewPermissionHierarchy()
|
||||
|
||||
// Test that default permissions are created
|
||||
permissions := hierarchy.ListPermissions()
|
||||
assert.NotEmpty(t, permissions)
|
||||
|
||||
// Test specific permissions exist
|
||||
permissionNames := make(map[string]bool)
|
||||
for _, perm := range permissions {
|
||||
permissionNames[perm.Name] = true
|
||||
}
|
||||
|
||||
expectedPermissions := []string{
|
||||
"admin", "read", "write",
|
||||
"app.admin", "app.read", "app.write", "app.create", "app.update", "app.delete",
|
||||
"token.admin", "token.read", "token.write", "token.create", "token.revoke", "token.verify",
|
||||
"permission.admin", "permission.read", "permission.write", "permission.grant", "permission.revoke",
|
||||
"user.admin", "user.read", "user.write",
|
||||
}
|
||||
|
||||
for _, expected := range expectedPermissions {
|
||||
assert.True(t, permissionNames[expected], "Permission %s should exist", expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionHierarchy_InitializeDefaultRoles(t *testing.T) {
|
||||
hierarchy := auth.NewPermissionHierarchy()
|
||||
|
||||
// Test that default roles are created
|
||||
roles := hierarchy.ListRoles()
|
||||
assert.NotEmpty(t, roles)
|
||||
|
||||
// Test specific roles exist
|
||||
roleNames := make(map[string]bool)
|
||||
for _, role := range roles {
|
||||
roleNames[role.Name] = true
|
||||
}
|
||||
|
||||
expectedRoles := []string{
|
||||
"super_admin", "app_admin", "developer", "viewer", "token_manager",
|
||||
}
|
||||
|
||||
for _, expected := range expectedRoles {
|
||||
assert.True(t, roleNames[expected], "Role %s should exist", expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_HasPermission(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
configMock.values["CACHE_ENABLED"] = "false" // Disable cache for testing
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
appID string
|
||||
permission string
|
||||
expectedResult bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "admin user has admin permission",
|
||||
userID: "admin@example.com",
|
||||
appID: "test-app",
|
||||
permission: "admin",
|
||||
expectedResult: true,
|
||||
description: "Admin users should have admin permissions",
|
||||
},
|
||||
{
|
||||
name: "developer user has token.create permission",
|
||||
userID: "dev@example.com",
|
||||
appID: "test-app",
|
||||
permission: "token.create",
|
||||
expectedResult: true,
|
||||
description: "Developer users should have token creation permissions",
|
||||
},
|
||||
{
|
||||
name: "viewer user has read permission",
|
||||
userID: "viewer@example.com",
|
||||
appID: "test-app",
|
||||
permission: "app.read",
|
||||
expectedResult: true,
|
||||
description: "Viewer users should have read permissions",
|
||||
},
|
||||
{
|
||||
name: "viewer user denied write permission",
|
||||
userID: "viewer@example.com",
|
||||
appID: "test-app",
|
||||
permission: "app.write",
|
||||
expectedResult: false,
|
||||
description: "Viewer users should not have write permissions",
|
||||
},
|
||||
{
|
||||
name: "non-existent permission",
|
||||
userID: "admin@example.com",
|
||||
appID: "test-app",
|
||||
permission: "non.existent",
|
||||
expectedResult: false,
|
||||
description: "Non-existent permissions should be denied",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
evaluation, err := pm.HasPermission(ctx, tt.userID, tt.appID, tt.permission)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, evaluation)
|
||||
assert.Equal(t, tt.expectedResult, evaluation.Granted, tt.description)
|
||||
assert.Equal(t, tt.permission, evaluation.Permission)
|
||||
assert.NotZero(t, evaluation.EvaluatedAt)
|
||||
|
||||
if evaluation.Granted {
|
||||
assert.NotEmpty(t, evaluation.GrantedBy, "Granted permissions should have GrantedBy information")
|
||||
} else {
|
||||
assert.NotEmpty(t, evaluation.DeniedReason, "Denied permissions should have a reason")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_EvaluateBulkPermissions(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
configMock.values["CACHE_ENABLED"] = "false"
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
req := &auth.BulkPermissionRequest{
|
||||
UserID: "dev@example.com",
|
||||
AppID: "test-app",
|
||||
Permissions: []string{
|
||||
"app.read",
|
||||
"token.create",
|
||||
"token.read",
|
||||
"app.delete", // Should be denied for developer
|
||||
"admin", // Should be denied for developer
|
||||
},
|
||||
}
|
||||
|
||||
response, err := pm.EvaluateBulkPermissions(ctx, req)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, response)
|
||||
assert.Equal(t, req.UserID, response.UserID)
|
||||
assert.Equal(t, req.AppID, response.AppID)
|
||||
assert.Len(t, response.Results, len(req.Permissions))
|
||||
|
||||
// Check specific results
|
||||
assert.True(t, response.Results["app.read"].Granted, "Developer should have app.read permission")
|
||||
assert.True(t, response.Results["token.create"].Granted, "Developer should have token.create permission")
|
||||
assert.True(t, response.Results["token.read"].Granted, "Developer should have token.read permission")
|
||||
assert.False(t, response.Results["app.delete"].Granted, "Developer should not have app.delete permission")
|
||||
assert.False(t, response.Results["admin"].Granted, "Developer should not have admin permission")
|
||||
}
|
||||
|
||||
func TestPermissionManager_AddPermission(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
permission *auth.Permission
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "add valid permission",
|
||||
permission: &auth.Permission{
|
||||
Name: "custom.permission",
|
||||
Description: "Custom permission for testing",
|
||||
Parent: "read",
|
||||
Level: 2,
|
||||
Resource: "custom",
|
||||
Action: "test",
|
||||
},
|
||||
expectError: false,
|
||||
description: "Valid permissions should be added successfully",
|
||||
},
|
||||
{
|
||||
name: "add permission without name",
|
||||
permission: &auth.Permission{
|
||||
Description: "Permission without name",
|
||||
Parent: "read",
|
||||
Level: 2,
|
||||
},
|
||||
expectError: true,
|
||||
description: "Permissions without names should be rejected",
|
||||
},
|
||||
{
|
||||
name: "add permission with non-existent parent",
|
||||
permission: &auth.Permission{
|
||||
Name: "invalid.permission",
|
||||
Description: "Permission with invalid parent",
|
||||
Parent: "non.existent",
|
||||
Level: 2,
|
||||
},
|
||||
expectError: true,
|
||||
description: "Permissions with non-existent parents should be rejected",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := pm.AddPermission(tt.permission)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
} else {
|
||||
assert.NoError(t, err, tt.description)
|
||||
|
||||
// Verify permission was added
|
||||
permissions := pm.ListPermissions()
|
||||
found := false
|
||||
for _, perm := range permissions {
|
||||
if perm.Name == tt.permission.Name {
|
||||
found = true
|
||||
assert.Equal(t, tt.permission.Description, perm.Description)
|
||||
assert.Equal(t, tt.permission.Parent, perm.Parent)
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Added permission should be found in the list")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_AddRole(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
role *auth.Role
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "add valid role",
|
||||
role: &auth.Role{
|
||||
Name: "custom_role",
|
||||
Description: "Custom role for testing",
|
||||
Permissions: []string{"read", "app.read"},
|
||||
Metadata: map[string]string{"level": "custom"},
|
||||
},
|
||||
expectError: false,
|
||||
description: "Valid roles should be added successfully",
|
||||
},
|
||||
{
|
||||
name: "add role without name",
|
||||
role: &auth.Role{
|
||||
Description: "Role without name",
|
||||
Permissions: []string{"read"},
|
||||
},
|
||||
expectError: true,
|
||||
description: "Roles without names should be rejected",
|
||||
},
|
||||
{
|
||||
name: "add role with non-existent permission",
|
||||
role: &auth.Role{
|
||||
Name: "invalid_role",
|
||||
Description: "Role with invalid permission",
|
||||
Permissions: []string{"non.existent.permission"},
|
||||
},
|
||||
expectError: true,
|
||||
description: "Roles with non-existent permissions should be rejected",
|
||||
},
|
||||
{
|
||||
name: "add role with non-existent inherited role",
|
||||
role: &auth.Role{
|
||||
Name: "invalid_inherited_role",
|
||||
Description: "Role with invalid inheritance",
|
||||
Permissions: []string{"read"},
|
||||
Inherits: []string{"non_existent_role"},
|
||||
},
|
||||
expectError: true,
|
||||
description: "Roles with non-existent inherited roles should be rejected",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := pm.AddRole(tt.role)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
} else {
|
||||
assert.NoError(t, err, tt.description)
|
||||
|
||||
// Verify role was added
|
||||
roles := pm.ListRoles()
|
||||
found := false
|
||||
for _, role := range roles {
|
||||
if role.Name == tt.role.Name {
|
||||
found = true
|
||||
assert.Equal(t, tt.role.Description, role.Description)
|
||||
assert.Equal(t, tt.role.Permissions, role.Permissions)
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Added role should be found in the list")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_ListPermissions(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
permissions := pm.ListPermissions()
|
||||
|
||||
// Should have default permissions
|
||||
assert.NotEmpty(t, permissions)
|
||||
|
||||
// Should be sorted by level and name
|
||||
for i := 1; i < len(permissions); i++ {
|
||||
prev := permissions[i-1]
|
||||
curr := permissions[i]
|
||||
|
||||
if prev.Level == curr.Level {
|
||||
assert.True(t, prev.Name <= curr.Name, "Permissions at same level should be sorted by name")
|
||||
} else {
|
||||
assert.True(t, prev.Level <= curr.Level, "Permissions should be sorted by level")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify hierarchy structure
|
||||
for _, perm := range permissions {
|
||||
if perm.Parent != "" {
|
||||
// Find parent permission
|
||||
parentFound := false
|
||||
for _, parent := range permissions {
|
||||
if parent.Name == perm.Parent {
|
||||
parentFound = true
|
||||
assert.True(t, parent.Level < perm.Level, "Parent should have lower level than child")
|
||||
assert.Contains(t, parent.Children, perm.Name, "Parent should contain child in children list")
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, parentFound, "Parent permission should exist for %s", perm.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_ListRoles(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
roles := pm.ListRoles()
|
||||
|
||||
// Should have default roles
|
||||
assert.NotEmpty(t, roles)
|
||||
|
||||
// Should be sorted by name
|
||||
for i := 1; i < len(roles); i++ {
|
||||
assert.True(t, roles[i-1].Name <= roles[i].Name, "Roles should be sorted by name")
|
||||
}
|
||||
|
||||
// Verify all permissions in roles exist
|
||||
allPermissions := pm.ListPermissions()
|
||||
permissionNames := make(map[string]bool)
|
||||
for _, perm := range allPermissions {
|
||||
permissionNames[perm.Name] = true
|
||||
}
|
||||
|
||||
for _, role := range roles {
|
||||
for _, perm := range role.Permissions {
|
||||
assert.True(t, permissionNames[perm], "Role %s references non-existent permission %s", role.Name, perm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionManager_InvalidatePermissionCache(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
err := pm.InvalidatePermissionCache(ctx, "user123", "app123")
|
||||
|
||||
// Should not error (currently just logs)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPermissionHierarchy_BuildHierarchy(t *testing.T) {
|
||||
hierarchy := auth.NewPermissionHierarchy()
|
||||
|
||||
// Test that parent-child relationships are built correctly
|
||||
permissions := hierarchy.ListPermissions()
|
||||
|
||||
// Find admin permission
|
||||
var adminPerm *auth.Permission
|
||||
for _, perm := range permissions {
|
||||
if perm.Name == "admin" {
|
||||
adminPerm = perm
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, adminPerm, "Admin permission should exist")
|
||||
|
||||
// Admin should have children
|
||||
assert.NotEmpty(t, adminPerm.Children, "Admin permission should have children")
|
||||
|
||||
// Check that app.admin is a child of admin
|
||||
assert.Contains(t, adminPerm.Children, "app.admin", "app.admin should be a child of admin")
|
||||
|
||||
// Find app.write permission
|
||||
var appWritePerm *auth.Permission
|
||||
for _, perm := range permissions {
|
||||
if perm.Name == "app.write" {
|
||||
appWritePerm = perm
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, appWritePerm, "app.write permission should exist")
|
||||
|
||||
// app.write should have children
|
||||
assert.NotEmpty(t, appWritePerm.Children, "app.write permission should have children")
|
||||
assert.Contains(t, appWritePerm.Children, "app.create", "app.create should be a child of app.write")
|
||||
assert.Contains(t, appWritePerm.Children, "app.update", "app.update should be a child of app.write")
|
||||
assert.Contains(t, appWritePerm.Children, "app.delete", "app.delete should be a child of app.write")
|
||||
}
|
||||
|
||||
// Benchmark tests for permission operations
|
||||
func BenchmarkPermissionManager_HasPermission(b *testing.B) {
|
||||
configMock := NewTestConfig()
|
||||
configMock.values["CACHE_ENABLED"] = "false"
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := pm.HasPermission(ctx, "dev@example.com", "test-app", "token.create")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPermissionManager_EvaluateBulkPermissions(b *testing.B) {
|
||||
configMock := NewTestConfig()
|
||||
configMock.values["CACHE_ENABLED"] = "false"
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &auth.BulkPermissionRequest{
|
||||
UserID: "dev@example.com",
|
||||
AppID: "test-app",
|
||||
Permissions: []string{
|
||||
"app.read", "token.create", "token.read", "app.delete", "admin",
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := pm.EvaluateBulkPermissions(ctx, req)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPermissionManager_ListPermissions(b *testing.B) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
permissions := pm.ListPermissions()
|
||||
if len(permissions) == 0 {
|
||||
b.Fatal("No permissions returned")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPermissionManager_ListRoles(b *testing.B) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
roles := pm.ListRoles()
|
||||
if len(roles) == 0 {
|
||||
b.Fatal("No roles returned")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test permission hierarchy traversal
|
||||
func TestPermissionHierarchy_PermissionInheritance(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
configMock.values["CACHE_ENABLED"] = "false"
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
// Test that admin users get hierarchical permissions
|
||||
ctx := context.Background()
|
||||
|
||||
// Admin should have all permissions through hierarchy
|
||||
adminPermissions := []string{
|
||||
"admin",
|
||||
"app.admin",
|
||||
"token.admin",
|
||||
"permission.admin",
|
||||
"user.admin",
|
||||
}
|
||||
|
||||
for _, perm := range adminPermissions {
|
||||
evaluation, err := pm.HasPermission(ctx, "admin@example.com", "test-app", perm)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, evaluation.Granted, "Admin should have %s permission", perm)
|
||||
}
|
||||
}
|
||||
|
||||
// Test role inheritance
|
||||
func TestPermissionManager_RoleInheritance(t *testing.T) {
|
||||
configMock := NewTestConfig()
|
||||
|
||||
logger := zap.NewNop()
|
||||
pm := auth.NewPermissionManager(configMock, logger)
|
||||
|
||||
// Add a role that inherits from another role
|
||||
parentRole := &auth.Role{
|
||||
Name: "base_role",
|
||||
Description: "Base role with basic permissions",
|
||||
Permissions: []string{"read", "app.read"},
|
||||
Metadata: map[string]string{"level": "base"},
|
||||
}
|
||||
|
||||
childRole := &auth.Role{
|
||||
Name: "extended_role",
|
||||
Description: "Extended role that inherits from base",
|
||||
Permissions: []string{"write"},
|
||||
Inherits: []string{"base_role"},
|
||||
Metadata: map[string]string{"level": "extended"},
|
||||
}
|
||||
|
||||
err := pm.AddRole(parentRole)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pm.AddRole(childRole)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify roles were added
|
||||
roles := pm.ListRoles()
|
||||
roleNames := make(map[string]*auth.Role)
|
||||
for _, role := range roles {
|
||||
roleNames[role.Name] = role
|
||||
}
|
||||
|
||||
assert.Contains(t, roleNames, "base_role")
|
||||
assert.Contains(t, roleNames, "extended_role")
|
||||
assert.Equal(t, []string{"base_role"}, roleNames["extended_role"].Inherits)
|
||||
}
|
||||
@ -29,6 +29,10 @@ func (c *TestConfig) GetBool(key string) bool {
|
||||
return boolVal
|
||||
}
|
||||
}
|
||||
// Special handling for cache enabled
|
||||
if key == "CACHE_ENABLED" {
|
||||
return c.values[key] == "true"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@ -86,6 +90,10 @@ func (c *TestConfig) IsProduction() bool {
|
||||
return c.GetString("APP_ENV") == "production"
|
||||
}
|
||||
|
||||
func (c *TestConfig) GetJWTSecret() string {
|
||||
return c.GetString("JWT_SECRET")
|
||||
}
|
||||
|
||||
// NewTestConfig creates a test configuration with default values
|
||||
func NewTestConfig() *TestConfig {
|
||||
return &TestConfig{
|
||||
@ -99,6 +107,12 @@ func NewTestConfig() *TestConfig {
|
||||
"SERVER_HOST": "localhost",
|
||||
"SERVER_PORT": "8080",
|
||||
"APP_ENV": "test",
|
||||
"JWT_SECRET": "test-jwt-secret-for-testing-only",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMockConfig creates a mock configuration (alias for NewTestConfig for backward compatibility)
|
||||
func NewMockConfig() *TestConfig {
|
||||
return NewTestConfig()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user