-
This commit is contained in:
@ -15,6 +15,7 @@ import (
|
||||
"github.com/kms/api-key-service/internal/config"
|
||||
"github.com/kms/api-key-service/internal/database"
|
||||
"github.com/kms/api-key-service/internal/handlers"
|
||||
"github.com/kms/api-key-service/internal/metrics"
|
||||
"github.com/kms/api-key-service/internal/middleware"
|
||||
"github.com/kms/api-key-service/internal/repository/postgres"
|
||||
"github.com/kms/api-key-service/internal/services"
|
||||
@ -61,7 +62,7 @@ func main() {
|
||||
|
||||
// Initialize services
|
||||
appService := services.NewApplicationService(appRepo, logger)
|
||||
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, logger)
|
||||
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, cfg.GetString("INTERNAL_HMAC_KEY"), logger)
|
||||
authService := services.NewAuthenticationService(cfg, logger)
|
||||
|
||||
// Initialize handlers
|
||||
@ -156,6 +157,7 @@ func setupRouter(cfg config.ConfigProvider, logger *zap.Logger, healthHandler *h
|
||||
// Add middleware
|
||||
router.Use(middleware.Logger(logger))
|
||||
router.Use(middleware.Recovery(logger))
|
||||
router.Use(metrics.Middleware(logger))
|
||||
router.Use(middleware.CORS())
|
||||
router.Use(middleware.Security())
|
||||
router.Use(middleware.ValidateContentType())
|
||||
@ -226,18 +228,20 @@ func setupRouter(cfg config.ConfigProvider, logger *zap.Logger, healthHandler *h
|
||||
}
|
||||
|
||||
func startMetricsServer(cfg config.ConfigProvider, logger *zap.Logger) *http.Server {
|
||||
metricsRouter := gin.New()
|
||||
metricsRouter.Use(middleware.Logger(logger))
|
||||
metricsRouter.Use(middleware.Recovery(logger))
|
||||
|
||||
// Basic metrics endpoint
|
||||
metricsRouter.GET("/metrics", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "# HELP api_key_service_info Information about the API Key Service\n# TYPE api_key_service_info gauge\napi_key_service_info{version=\"%s\"} 1\n", cfg.GetString("APP_VERSION"))
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Prometheus metrics endpoint
|
||||
mux.HandleFunc("/metrics", metrics.PrometheusHandler())
|
||||
|
||||
// Health endpoint for metrics server
|
||||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: cfg.GetMetricsAddress(),
|
||||
Handler: metricsRouter,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
||||
297
docs/PRODUCTION_ROADMAP.md
Normal file
297
docs/PRODUCTION_ROADMAP.md
Normal file
@ -0,0 +1,297 @@
|
||||
# KMS API Service - Production Roadmap
|
||||
|
||||
This document outlines the complete roadmap for making the API Key Management Service fully production-ready. Use the checkboxes to track progress and refer to the implementation notes at the bottom.
|
||||
|
||||
## 🏗️ Core Infrastructure (COMPLETED)
|
||||
|
||||
### Repository Layer
|
||||
- [x] Complete token repository implementation (CRUD operations)
|
||||
- [x] Complete permission repository implementation (core methods)
|
||||
- [x] Implement granted permission repository (authorization logic)
|
||||
- [x] Add database transaction support
|
||||
- [x] Implement proper error handling in repositories
|
||||
|
||||
### Security & Cryptography
|
||||
- [x] Implement secure token generation using crypto/rand
|
||||
- [x] Add bcrypt-based token hashing for storage
|
||||
- [x] Implement HMAC token signing and verification
|
||||
- [x] Create token format validation utilities
|
||||
- [x] Add cryptographic key management
|
||||
|
||||
### Service Layer
|
||||
- [x] Update token service with secure generation
|
||||
- [x] Implement permission validation in token creation
|
||||
- [x] Add application validation before token operations
|
||||
- [x] Implement proper error propagation
|
||||
- [x] Add comprehensive logging throughout services
|
||||
|
||||
### Middleware & Validation
|
||||
- [x] Create comprehensive input validation middleware
|
||||
- [x] Implement struct-based validation with detailed errors
|
||||
- [x] Add UUID parameter validation
|
||||
- [x] Create permission scope format validation
|
||||
- [x] Implement request sanitization
|
||||
|
||||
### Error Handling
|
||||
- [x] Create structured error framework with typed codes
|
||||
- [x] Implement HTTP status code mapping
|
||||
- [x] Add error context and chaining support
|
||||
- [x] Create consistent JSON error responses
|
||||
- [x] Add retry logic indicators
|
||||
|
||||
### Monitoring & Metrics
|
||||
- [x] Implement comprehensive metrics collection
|
||||
- [x] Add Prometheus-compatible metrics export
|
||||
- [x] Create HTTP request monitoring middleware
|
||||
- [x] Add business metrics tracking
|
||||
- [x] Implement system health metrics
|
||||
|
||||
## 🔐 Authentication & Authorization (HIGH PRIORITY)
|
||||
|
||||
### JWT Implementation
|
||||
- [ ] Complete JWT token generation and validation
|
||||
- [ ] Implement token expiration and renewal logic
|
||||
- [ ] Add JWT claims management
|
||||
- [ ] Create token blacklisting mechanism
|
||||
- [ ] Implement refresh token rotation
|
||||
|
||||
### SSO Integration
|
||||
- [ ] Implement OAuth2/OIDC provider integration
|
||||
- [ ] Add SAML authentication support
|
||||
- [ ] Create user session management
|
||||
- [ ] Implement role-based access control (RBAC)
|
||||
- [ ] Add multi-tenant authentication support
|
||||
|
||||
### Permission System Enhancement
|
||||
- [ ] Implement hierarchical permission inheritance
|
||||
- [ ] Add dynamic permission evaluation
|
||||
- [ ] Create permission caching mechanism
|
||||
- [ ] Implement permission audit logging
|
||||
- [ ] Add bulk permission operations
|
||||
|
||||
## 🚀 Performance & Scalability (MEDIUM PRIORITY)
|
||||
|
||||
### Caching Layer
|
||||
- [ ] Implement Redis integration for caching
|
||||
- [ ] Add permission result caching
|
||||
- [ ] Create application metadata caching
|
||||
- [ ] Implement token validation result caching
|
||||
- [ ] Add cache invalidation strategies
|
||||
|
||||
### Database Optimization
|
||||
- [ ] Implement database connection pool tuning
|
||||
- [ ] Add query performance monitoring
|
||||
- [ ] Create database migration rollback procedures
|
||||
- [ ] Implement read replica support
|
||||
- [ ] Add database backup and recovery procedures
|
||||
|
||||
### Load Balancing & Clustering
|
||||
- [ ] Implement horizontal scaling support
|
||||
- [ ] Add load balancer health checks
|
||||
- [ ] Create session affinity handling
|
||||
- [ ] Implement distributed rate limiting
|
||||
- [ ] Add circuit breaker patterns
|
||||
|
||||
## 🔒 Security Hardening (HIGH PRIORITY)
|
||||
|
||||
### Advanced Security Features
|
||||
- [ ] Implement API key rotation mechanisms
|
||||
- [ ] Add brute force protection
|
||||
- [ ] Create account lockout mechanisms
|
||||
- [ ] Implement IP whitelisting/blacklisting
|
||||
- [ ] Add request signing validation
|
||||
|
||||
### Audit & Compliance
|
||||
- [ ] Implement comprehensive audit logging
|
||||
- [ ] Add compliance reporting features
|
||||
- [ ] Create data retention policies
|
||||
- [ ] Implement GDPR compliance features
|
||||
- [ ] Add security event alerting
|
||||
|
||||
### Secrets Management
|
||||
- [ ] Integrate with HashiCorp Vault or similar
|
||||
- [ ] Implement automatic key rotation
|
||||
- [ ] Add encrypted configuration storage
|
||||
- [ ] Create secure backup procedures
|
||||
- [ ] Implement key escrow mechanisms
|
||||
|
||||
## 🧪 Testing & Quality Assurance (MEDIUM PRIORITY)
|
||||
|
||||
### Unit Testing
|
||||
- [ ] Add comprehensive unit tests for repositories
|
||||
- [ ] Create service layer unit tests
|
||||
- [ ] Implement middleware unit tests
|
||||
- [ ] Add crypto utility unit tests
|
||||
- [ ] Create error handling unit tests
|
||||
|
||||
### Integration Testing
|
||||
- [ ] Expand integration test coverage
|
||||
- [ ] Add database integration tests
|
||||
- [ ] Create API endpoint integration tests
|
||||
- [ ] Implement authentication flow tests
|
||||
- [ ] Add permission validation tests
|
||||
|
||||
### Performance Testing
|
||||
- [ ] Implement load testing scenarios
|
||||
- [ ] Add stress testing for concurrent operations
|
||||
- [ ] Create database performance benchmarks
|
||||
- [ ] Implement memory leak detection
|
||||
- [ ] Add latency and throughput testing
|
||||
|
||||
### Security Testing
|
||||
- [ ] Implement penetration testing scenarios
|
||||
- [ ] Add vulnerability scanning automation
|
||||
- [ ] Create security regression tests
|
||||
- [ ] Implement fuzzing tests
|
||||
- [ ] Add compliance validation tests
|
||||
|
||||
## 📦 Deployment & Operations (MEDIUM PRIORITY)
|
||||
|
||||
### Containerization & Orchestration
|
||||
- [ ] Create optimized Docker images
|
||||
- [ ] Implement Kubernetes manifests
|
||||
- [ ] Add Helm charts for deployment
|
||||
- [ ] Create deployment automation scripts
|
||||
- [ ] Implement blue-green deployment strategy
|
||||
|
||||
### Infrastructure as Code
|
||||
- [ ] Create Terraform configurations
|
||||
- [ ] Implement AWS/GCP/Azure resource definitions
|
||||
- [ ] Add infrastructure testing
|
||||
- [ ] Create disaster recovery procedures
|
||||
- [ ] Implement infrastructure monitoring
|
||||
|
||||
### CI/CD Pipeline
|
||||
- [ ] Implement automated testing pipeline
|
||||
- [ ] Add security scanning in CI/CD
|
||||
- [ ] Create automated deployment pipeline
|
||||
- [ ] Implement rollback mechanisms
|
||||
- [ ] Add deployment notifications
|
||||
|
||||
## 📊 Observability & Monitoring (LOW PRIORITY)
|
||||
|
||||
### Advanced Monitoring
|
||||
- [ ] Implement distributed tracing
|
||||
- [ ] Add application performance monitoring (APM)
|
||||
- [ ] Create custom dashboards
|
||||
- [ ] Implement alerting rules
|
||||
- [ ] Add log aggregation and analysis
|
||||
|
||||
### Business Intelligence
|
||||
- [ ] Create usage analytics
|
||||
- [ ] Implement cost tracking
|
||||
- [ ] Add capacity planning metrics
|
||||
- [ ] Create business KPI dashboards
|
||||
- [ ] Implement trend analysis
|
||||
|
||||
## 🔧 Maintenance & Operations (ONGOING)
|
||||
|
||||
### Documentation
|
||||
- [ ] Create comprehensive API documentation
|
||||
- [ ] Add deployment guides
|
||||
- [ ] Create troubleshooting runbooks
|
||||
- [ ] Implement architecture decision records (ADRs)
|
||||
- [ ] Add security best practices guide
|
||||
|
||||
### Maintenance Procedures
|
||||
- [ ] Create backup and restore procedures
|
||||
- [ ] Implement log rotation and archival
|
||||
- [ ] Add database maintenance scripts
|
||||
- [ ] Create performance tuning guides
|
||||
- [ ] Implement capacity planning procedures
|
||||
|
||||
---
|
||||
|
||||
## 📝 Implementation Notes for Future Development
|
||||
|
||||
### Code Organization Principles
|
||||
1. **Maintain Clean Architecture**: Keep clear separation between domain, service, and infrastructure layers
|
||||
2. **Interface-First Design**: Always define interfaces before implementations for better testability
|
||||
3. **Error Handling**: Use the established error framework (`internal/errors`) for consistent error handling
|
||||
4. **Logging**: Use structured logging with zap throughout the application
|
||||
5. **Configuration**: Add new config options to `internal/config/config.go` with proper validation
|
||||
|
||||
### Security Guidelines
|
||||
1. **Input Validation**: Always validate inputs using the validation middleware (`internal/middleware/validation.go`)
|
||||
2. **Token Security**: Use the crypto utilities (`internal/crypto/token.go`) for all token operations
|
||||
3. **Permission Checks**: Always validate permissions using the repository layer before operations
|
||||
4. **Audit Logging**: Log all security-relevant operations with user context
|
||||
5. **Secrets**: Never hardcode secrets; use environment variables or secret management systems
|
||||
|
||||
### Database Guidelines
|
||||
1. **Migrations**: Always create both up and down migrations for schema changes
|
||||
2. **Transactions**: Use database transactions for multi-step operations
|
||||
3. **Indexing**: Add appropriate indexes for query performance
|
||||
4. **Connection Management**: Use the existing connection pool configuration
|
||||
5. **Error Handling**: Wrap database errors with the application error framework
|
||||
|
||||
### Testing Guidelines
|
||||
1. **Test Structure**: Follow the existing test structure in `test/` directory
|
||||
2. **Mock Dependencies**: Use interfaces for easy mocking in tests
|
||||
3. **Test Data**: Use the test helpers for consistent test data creation
|
||||
4. **Integration Tests**: Test against real database instances when possible
|
||||
5. **Coverage**: Aim for >80% test coverage for critical paths
|
||||
|
||||
### Performance Guidelines
|
||||
1. **Metrics**: Use the metrics system (`internal/metrics`) to track performance
|
||||
2. **Caching**: Implement caching at the service layer, not repository layer
|
||||
3. **Database Queries**: Optimize queries and use appropriate indexes
|
||||
4. **Memory Management**: Be mindful of memory allocations in hot paths
|
||||
5. **Concurrency**: Use proper synchronization for shared resources
|
||||
|
||||
### Deployment Guidelines
|
||||
1. **Environment Variables**: Use environment-based configuration for all deployments
|
||||
2. **Health Checks**: Ensure health endpoints are properly configured
|
||||
3. **Graceful Shutdown**: Implement proper shutdown procedures for all services
|
||||
4. **Resource Limits**: Set appropriate CPU and memory limits
|
||||
5. **Monitoring**: Ensure metrics and logging are properly configured
|
||||
|
||||
### Code Quality Standards
|
||||
1. **Go Standards**: Follow standard Go conventions and best practices
|
||||
2. **Documentation**: Document all public APIs and complex business logic
|
||||
3. **Error Messages**: Provide clear, actionable error messages
|
||||
4. **Code Reviews**: Require code reviews for all changes
|
||||
5. **Static Analysis**: Use tools like golangci-lint for code quality
|
||||
|
||||
### Security Best Practices
|
||||
1. **Principle of Least Privilege**: Grant minimum necessary permissions
|
||||
2. **Defense in Depth**: Implement multiple layers of security
|
||||
3. **Regular Updates**: Keep dependencies updated for security patches
|
||||
4. **Secure Defaults**: Use secure configurations by default
|
||||
5. **Security Testing**: Include security testing in the development process
|
||||
|
||||
### Operational Considerations
|
||||
1. **Monitoring**: Implement comprehensive monitoring and alerting
|
||||
2. **Backup Strategy**: Ensure regular backups and test restore procedures
|
||||
3. **Disaster Recovery**: Have documented disaster recovery procedures
|
||||
4. **Capacity Planning**: Monitor resource usage and plan for growth
|
||||
5. **Documentation**: Keep operational documentation up to date
|
||||
|
||||
---
|
||||
|
||||
## 🎯 Priority Matrix
|
||||
|
||||
### Immediate (Next Sprint)
|
||||
- Complete JWT implementation
|
||||
- Add comprehensive unit tests
|
||||
- Implement caching layer basics
|
||||
|
||||
### Short Term (1-2 Months)
|
||||
- SSO integration
|
||||
- Security hardening features
|
||||
- Performance optimization
|
||||
|
||||
### Medium Term (3-6 Months)
|
||||
- Advanced monitoring and observability
|
||||
- Deployment automation
|
||||
- Compliance features
|
||||
|
||||
### Long Term (6+ Months)
|
||||
- Advanced analytics
|
||||
- Multi-region deployment
|
||||
- Advanced security features
|
||||
|
||||
---
|
||||
|
||||
*Last Updated: [Current Date]*
|
||||
*Version: 1.0*
|
||||
173
internal/crypto/token.go
Normal file
173
internal/crypto/token.go
Normal file
@ -0,0 +1,173 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
// TokenLength defines the length of generated tokens in bytes
|
||||
TokenLength = 32
|
||||
// TokenPrefix is prepended to all tokens for identification
|
||||
TokenPrefix = "kms_"
|
||||
)
|
||||
|
||||
// TokenGenerator provides secure token generation and validation
|
||||
type TokenGenerator struct {
|
||||
hmacKey []byte
|
||||
}
|
||||
|
||||
// NewTokenGenerator creates a new token generator with the provided HMAC key
|
||||
func NewTokenGenerator(hmacKey string) *TokenGenerator {
|
||||
return &TokenGenerator{
|
||||
hmacKey: []byte(hmacKey),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSecureToken generates a cryptographically secure random token
|
||||
func (tg *TokenGenerator) GenerateSecureToken() (string, error) {
|
||||
// Generate random bytes
|
||||
tokenBytes := make([]byte, TokenLength)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||
}
|
||||
|
||||
// Encode to base64 for safe transmission
|
||||
tokenData := base64.URLEncoding.EncodeToString(tokenBytes)
|
||||
|
||||
// Add prefix for identification
|
||||
token := TokenPrefix + tokenData
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// HashToken creates a secure hash of the token for storage
|
||||
func (tg *TokenGenerator) HashToken(token string) (string, error) {
|
||||
// Use bcrypt for secure password-like hashing
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to hash token: %w", err)
|
||||
}
|
||||
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
// VerifyToken verifies a token against its stored hash
|
||||
func (tg *TokenGenerator) VerifyToken(token, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(token))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// GenerateHMACKey generates a new HMAC key for token signing
|
||||
func GenerateHMACKey() (string, error) {
|
||||
key := make([]byte, 32) // 256-bit key
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return "", fmt.Errorf("failed to generate HMAC key: %w", err)
|
||||
}
|
||||
|
||||
return hex.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
// SignToken creates an HMAC signature for a token
|
||||
func (tg *TokenGenerator) SignToken(token string, timestamp time.Time) string {
|
||||
h := hmac.New(sha256.New, tg.hmacKey)
|
||||
h.Write([]byte(token))
|
||||
h.Write([]byte(timestamp.Format(time.RFC3339)))
|
||||
|
||||
signature := h.Sum(nil)
|
||||
return hex.EncodeToString(signature)
|
||||
}
|
||||
|
||||
// VerifyTokenSignature verifies an HMAC signature for a token
|
||||
func (tg *TokenGenerator) VerifyTokenSignature(token, signature string, timestamp time.Time) bool {
|
||||
expectedSignature := tg.SignToken(token, timestamp)
|
||||
return hmac.Equal([]byte(signature), []byte(expectedSignature))
|
||||
}
|
||||
|
||||
// ExtractTokenFromHeader extracts a token from an Authorization header
|
||||
func ExtractTokenFromHeader(authHeader string) string {
|
||||
// Support both "Bearer token" and "token" formats
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
return authHeader
|
||||
}
|
||||
|
||||
// IsValidTokenFormat checks if a token has the expected format
|
||||
func IsValidTokenFormat(token string) bool {
|
||||
if !strings.HasPrefix(token, TokenPrefix) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Remove prefix and check if remaining part is valid base64
|
||||
tokenData := strings.TrimPrefix(token, TokenPrefix)
|
||||
if len(tokenData) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Try to decode base64
|
||||
_, err := base64.URLEncoding.DecodeString(tokenData)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// TokenInfo holds information about a token
|
||||
type TokenInfo struct {
|
||||
Token string
|
||||
Hash string
|
||||
Signature string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// GenerateTokenWithInfo generates a complete token with hash and signature
|
||||
func (tg *TokenGenerator) GenerateTokenWithInfo() (*TokenInfo, error) {
|
||||
// Generate the token
|
||||
token, err := tg.GenerateSecureToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
// Hash the token for storage
|
||||
hash, err := tg.HashToken(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash token: %w", err)
|
||||
}
|
||||
|
||||
// Create timestamp and signature
|
||||
now := time.Now()
|
||||
signature := tg.SignToken(token, now)
|
||||
|
||||
return &TokenInfo{
|
||||
Token: token,
|
||||
Hash: hash,
|
||||
Signature: signature,
|
||||
CreatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateTokenInfo validates a complete token with all its components
|
||||
func (tg *TokenGenerator) ValidateTokenInfo(token, hash, signature string, createdAt time.Time) error {
|
||||
// Check token format
|
||||
if !IsValidTokenFormat(token) {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
// Verify token against hash
|
||||
if !tg.VerifyToken(token, hash) {
|
||||
return fmt.Errorf("token verification failed")
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !tg.VerifyTokenSignature(token, signature, createdAt) {
|
||||
return fmt.Errorf("token signature verification failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
287
internal/errors/errors.go
Normal file
287
internal/errors/errors.go
Normal file
@ -0,0 +1,287 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ErrorCode represents different types of errors in the system
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
// Authentication and Authorization errors
|
||||
ErrUnauthorized ErrorCode = "UNAUTHORIZED"
|
||||
ErrForbidden ErrorCode = "FORBIDDEN"
|
||||
ErrInvalidToken ErrorCode = "INVALID_TOKEN"
|
||||
ErrTokenExpired ErrorCode = "TOKEN_EXPIRED"
|
||||
ErrInvalidCredentials ErrorCode = "INVALID_CREDENTIALS"
|
||||
|
||||
// Validation errors
|
||||
ErrValidationFailed ErrorCode = "VALIDATION_FAILED"
|
||||
ErrInvalidInput ErrorCode = "INVALID_INPUT"
|
||||
ErrMissingField ErrorCode = "MISSING_FIELD"
|
||||
ErrInvalidFormat ErrorCode = "INVALID_FORMAT"
|
||||
|
||||
// Resource errors
|
||||
ErrNotFound ErrorCode = "NOT_FOUND"
|
||||
ErrAlreadyExists ErrorCode = "ALREADY_EXISTS"
|
||||
ErrConflict ErrorCode = "CONFLICT"
|
||||
|
||||
// System errors
|
||||
ErrInternal ErrorCode = "INTERNAL_ERROR"
|
||||
ErrDatabase ErrorCode = "DATABASE_ERROR"
|
||||
ErrExternal ErrorCode = "EXTERNAL_SERVICE_ERROR"
|
||||
ErrTimeout ErrorCode = "TIMEOUT"
|
||||
ErrRateLimit ErrorCode = "RATE_LIMIT_EXCEEDED"
|
||||
|
||||
// Business logic errors
|
||||
ErrInsufficientPermissions ErrorCode = "INSUFFICIENT_PERMISSIONS"
|
||||
ErrApplicationNotFound ErrorCode = "APPLICATION_NOT_FOUND"
|
||||
ErrTokenNotFound ErrorCode = "TOKEN_NOT_FOUND"
|
||||
ErrPermissionNotFound ErrorCode = "PERMISSION_NOT_FOUND"
|
||||
ErrInvalidApplication ErrorCode = "INVALID_APPLICATION"
|
||||
ErrTokenCreationFailed ErrorCode = "TOKEN_CREATION_FAILED"
|
||||
)
|
||||
|
||||
// AppError represents an application error with context
|
||||
type AppError struct {
|
||||
Code ErrorCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details string `json:"details,omitempty"`
|
||||
StatusCode int `json:"-"`
|
||||
Internal error `json:"-"`
|
||||
Context map[string]interface{} `json:"context,omitempty"`
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *AppError) Error() string {
|
||||
if e.Internal != nil {
|
||||
return fmt.Sprintf("%s: %s (internal: %v)", e.Code, e.Message, e.Internal)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// WithContext adds context information to the error
|
||||
func (e *AppError) WithContext(key string, value interface{}) *AppError {
|
||||
if e.Context == nil {
|
||||
e.Context = make(map[string]interface{})
|
||||
}
|
||||
e.Context[key] = value
|
||||
return e
|
||||
}
|
||||
|
||||
// WithDetails adds additional details to the error
|
||||
func (e *AppError) WithDetails(details string) *AppError {
|
||||
e.Details = details
|
||||
return e
|
||||
}
|
||||
|
||||
// WithInternal adds the underlying error
|
||||
func (e *AppError) WithInternal(err error) *AppError {
|
||||
e.Internal = err
|
||||
return e
|
||||
}
|
||||
|
||||
// New creates a new application error
|
||||
func New(code ErrorCode, message string) *AppError {
|
||||
return &AppError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
StatusCode: getHTTPStatusCode(code),
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap wraps an existing error with application error context
|
||||
func Wrap(err error, code ErrorCode, message string) *AppError {
|
||||
return &AppError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
StatusCode: getHTTPStatusCode(code),
|
||||
Internal: err,
|
||||
}
|
||||
}
|
||||
|
||||
// getHTTPStatusCode maps error codes to HTTP status codes
|
||||
func getHTTPStatusCode(code ErrorCode) int {
|
||||
switch code {
|
||||
case ErrUnauthorized, ErrInvalidToken, ErrTokenExpired, ErrInvalidCredentials:
|
||||
return http.StatusUnauthorized
|
||||
case ErrForbidden, ErrInsufficientPermissions:
|
||||
return http.StatusForbidden
|
||||
case ErrValidationFailed, ErrInvalidInput, ErrMissingField, ErrInvalidFormat:
|
||||
return http.StatusBadRequest
|
||||
case ErrNotFound, ErrApplicationNotFound, ErrTokenNotFound, ErrPermissionNotFound:
|
||||
return http.StatusNotFound
|
||||
case ErrAlreadyExists, ErrConflict:
|
||||
return http.StatusConflict
|
||||
case ErrRateLimit:
|
||||
return http.StatusTooManyRequests
|
||||
case ErrTimeout:
|
||||
return http.StatusRequestTimeout
|
||||
case ErrInternal, ErrDatabase, ErrExternal, ErrTokenCreationFailed:
|
||||
return http.StatusInternalServerError
|
||||
default:
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
// IsRetryable determines if an error is retryable
|
||||
func (e *AppError) IsRetryable() bool {
|
||||
switch e.Code {
|
||||
case ErrTimeout, ErrExternal, ErrDatabase:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsClientError determines if an error is a client error (4xx)
|
||||
func (e *AppError) IsClientError() bool {
|
||||
return e.StatusCode >= 400 && e.StatusCode < 500
|
||||
}
|
||||
|
||||
// IsServerError determines if an error is a server error (5xx)
|
||||
func (e *AppError) IsServerError() bool {
|
||||
return e.StatusCode >= 500
|
||||
}
|
||||
|
||||
// Common error constructors for frequently used errors
|
||||
|
||||
// NewUnauthorizedError creates an unauthorized error
|
||||
func NewUnauthorizedError(message string) *AppError {
|
||||
return New(ErrUnauthorized, message)
|
||||
}
|
||||
|
||||
// NewForbiddenError creates a forbidden error
|
||||
func NewForbiddenError(message string) *AppError {
|
||||
return New(ErrForbidden, message)
|
||||
}
|
||||
|
||||
// NewValidationError creates a validation error
|
||||
func NewValidationError(message string) *AppError {
|
||||
return New(ErrValidationFailed, message)
|
||||
}
|
||||
|
||||
// NewNotFoundError creates a not found error
|
||||
func NewNotFoundError(resource string) *AppError {
|
||||
return New(ErrNotFound, fmt.Sprintf("%s not found", resource))
|
||||
}
|
||||
|
||||
// NewAlreadyExistsError creates an already exists error
|
||||
func NewAlreadyExistsError(resource string) *AppError {
|
||||
return New(ErrAlreadyExists, fmt.Sprintf("%s already exists", resource))
|
||||
}
|
||||
|
||||
// NewInternalError creates an internal server error
|
||||
func NewInternalError(message string) *AppError {
|
||||
return New(ErrInternal, message)
|
||||
}
|
||||
|
||||
// NewDatabaseError creates a database error
|
||||
func NewDatabaseError(operation string, err error) *AppError {
|
||||
return Wrap(err, ErrDatabase, fmt.Sprintf("Database operation failed: %s", operation))
|
||||
}
|
||||
|
||||
// NewTokenError creates a token-related error
|
||||
func NewTokenError(message string) *AppError {
|
||||
return New(ErrInvalidToken, message)
|
||||
}
|
||||
|
||||
// NewApplicationError creates an application-related error
|
||||
func NewApplicationError(message string) *AppError {
|
||||
return New(ErrInvalidApplication, message)
|
||||
}
|
||||
|
||||
// NewPermissionError creates a permission-related error
|
||||
func NewPermissionError(message string) *AppError {
|
||||
return New(ErrInsufficientPermissions, message)
|
||||
}
|
||||
|
||||
// ErrorResponse represents the JSON error response format
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Code ErrorCode `json:"code"`
|
||||
Details string `json:"details,omitempty"`
|
||||
Context map[string]interface{} `json:"context,omitempty"`
|
||||
}
|
||||
|
||||
// ToResponse converts an AppError to an ErrorResponse
|
||||
func (e *AppError) ToResponse() ErrorResponse {
|
||||
return ErrorResponse{
|
||||
Error: string(e.Code),
|
||||
Message: e.Message,
|
||||
Code: e.Code,
|
||||
Details: e.Details,
|
||||
Context: e.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// Recovery handles panic recovery and converts to appropriate errors
|
||||
func Recovery(recovered interface{}) *AppError {
|
||||
switch v := recovered.(type) {
|
||||
case *AppError:
|
||||
return v
|
||||
case error:
|
||||
return Wrap(v, ErrInternal, "Internal server error occurred")
|
||||
case string:
|
||||
return New(ErrInternal, v)
|
||||
default:
|
||||
return New(ErrInternal, "Unknown internal error occurred")
|
||||
}
|
||||
}
|
||||
|
||||
// Chain represents a chain of errors for better error tracking
|
||||
type Chain struct {
|
||||
errors []*AppError
|
||||
}
|
||||
|
||||
// NewChain creates a new error chain
|
||||
func NewChain() *Chain {
|
||||
return &Chain{
|
||||
errors: make([]*AppError, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds an error to the chain
|
||||
func (c *Chain) Add(err *AppError) *Chain {
|
||||
c.errors = append(c.errors, err)
|
||||
return c
|
||||
}
|
||||
|
||||
// HasErrors returns true if the chain has any errors
|
||||
func (c *Chain) HasErrors() bool {
|
||||
return len(c.errors) > 0
|
||||
}
|
||||
|
||||
// First returns the first error in the chain
|
||||
func (c *Chain) First() *AppError {
|
||||
if len(c.errors) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.errors[0]
|
||||
}
|
||||
|
||||
// Last returns the last error in the chain
|
||||
func (c *Chain) Last() *AppError {
|
||||
if len(c.errors) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.errors[len(c.errors)-1]
|
||||
}
|
||||
|
||||
// All returns all errors in the chain
|
||||
func (c *Chain) All() []*AppError {
|
||||
return c.errors
|
||||
}
|
||||
|
||||
// Error implements the error interface for the chain
|
||||
func (c *Chain) Error() string {
|
||||
if len(c.errors) == 0 {
|
||||
return "no errors"
|
||||
}
|
||||
if len(c.errors) == 1 {
|
||||
return c.errors[0].Error()
|
||||
}
|
||||
return fmt.Sprintf("multiple errors: %s (and %d more)", c.errors[0].Error(), len(c.errors)-1)
|
||||
}
|
||||
415
internal/metrics/metrics.go
Normal file
415
internal/metrics/metrics.go
Normal file
@ -0,0 +1,415 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Metrics holds all application metrics
|
||||
type Metrics struct {
|
||||
// HTTP metrics
|
||||
RequestsTotal *Counter
|
||||
RequestDuration *Histogram
|
||||
RequestsInFlight *Gauge
|
||||
ResponseSize *Histogram
|
||||
|
||||
// Business metrics
|
||||
TokensCreated *Counter
|
||||
TokensVerified *Counter
|
||||
TokensRevoked *Counter
|
||||
ApplicationsTotal *Gauge
|
||||
PermissionsTotal *Gauge
|
||||
|
||||
// System metrics
|
||||
DatabaseConnections *Gauge
|
||||
DatabaseQueries *Counter
|
||||
DatabaseErrors *Counter
|
||||
CacheHits *Counter
|
||||
CacheMisses *Counter
|
||||
|
||||
// Error metrics
|
||||
ErrorsTotal *Counter
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Counter represents a monotonically increasing counter
|
||||
type Counter struct {
|
||||
value float64
|
||||
labels map[string]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Gauge represents a value that can go up and down
|
||||
type Gauge struct {
|
||||
value float64
|
||||
labels map[string]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Histogram represents a distribution of values
|
||||
type Histogram struct {
|
||||
buckets map[float64]float64
|
||||
sum float64
|
||||
count float64
|
||||
labels map[string]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMetrics creates a new metrics instance
|
||||
func NewMetrics() *Metrics {
|
||||
return &Metrics{
|
||||
// HTTP metrics
|
||||
RequestsTotal: NewCounter("http_requests_total", map[string]string{}),
|
||||
RequestDuration: NewHistogram("http_request_duration_seconds", map[string]string{}),
|
||||
RequestsInFlight: NewGauge("http_requests_in_flight", map[string]string{}),
|
||||
ResponseSize: NewHistogram("http_response_size_bytes", map[string]string{}),
|
||||
|
||||
// Business metrics
|
||||
TokensCreated: NewCounter("tokens_created_total", map[string]string{}),
|
||||
TokensVerified: NewCounter("tokens_verified_total", map[string]string{}),
|
||||
TokensRevoked: NewCounter("tokens_revoked_total", map[string]string{}),
|
||||
ApplicationsTotal: NewGauge("applications_total", map[string]string{}),
|
||||
PermissionsTotal: NewGauge("permissions_total", map[string]string{}),
|
||||
|
||||
// System metrics
|
||||
DatabaseConnections: NewGauge("database_connections", map[string]string{}),
|
||||
DatabaseQueries: NewCounter("database_queries_total", map[string]string{}),
|
||||
DatabaseErrors: NewCounter("database_errors_total", map[string]string{}),
|
||||
CacheHits: NewCounter("cache_hits_total", map[string]string{}),
|
||||
CacheMisses: NewCounter("cache_misses_total", map[string]string{}),
|
||||
|
||||
// Error metrics
|
||||
ErrorsTotal: NewCounter("errors_total", map[string]string{}),
|
||||
}
|
||||
}
|
||||
|
||||
// NewCounter creates a new counter
|
||||
func NewCounter(name string, labels map[string]string) *Counter {
|
||||
return &Counter{
|
||||
value: 0,
|
||||
labels: labels,
|
||||
}
|
||||
}
|
||||
|
||||
// NewGauge creates a new gauge
|
||||
func NewGauge(name string, labels map[string]string) *Gauge {
|
||||
return &Gauge{
|
||||
value: 0,
|
||||
labels: labels,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHistogram creates a new histogram
|
||||
func NewHistogram(name string, labels map[string]string) *Histogram {
|
||||
return &Histogram{
|
||||
buckets: make(map[float64]float64),
|
||||
sum: 0,
|
||||
count: 0,
|
||||
labels: labels,
|
||||
}
|
||||
}
|
||||
|
||||
// Counter methods
|
||||
func (c *Counter) Inc() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.value++
|
||||
}
|
||||
|
||||
func (c *Counter) Add(value float64) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.value += value
|
||||
}
|
||||
|
||||
func (c *Counter) Value() float64 {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.value
|
||||
}
|
||||
|
||||
// Gauge methods
|
||||
func (g *Gauge) Set(value float64) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
g.value = value
|
||||
}
|
||||
|
||||
func (g *Gauge) Inc() {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
g.value++
|
||||
}
|
||||
|
||||
func (g *Gauge) Dec() {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
g.value--
|
||||
}
|
||||
|
||||
func (g *Gauge) Add(value float64) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
g.value += value
|
||||
}
|
||||
|
||||
func (g *Gauge) Value() float64 {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return g.value
|
||||
}
|
||||
|
||||
// Histogram methods
|
||||
func (h *Histogram) Observe(value float64) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
h.sum += value
|
||||
h.count++
|
||||
|
||||
// Define standard buckets
|
||||
buckets := []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}
|
||||
for _, bucket := range buckets {
|
||||
if value <= bucket {
|
||||
h.buckets[bucket]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Histogram) Sum() float64 {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.sum
|
||||
}
|
||||
|
||||
func (h *Histogram) Count() float64 {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.count
|
||||
}
|
||||
|
||||
func (h *Histogram) Buckets() map[float64]float64 {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
result := make(map[float64]float64)
|
||||
for k, v := range h.buckets {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Global metrics instance
|
||||
var globalMetrics *Metrics
|
||||
var once sync.Once
|
||||
|
||||
// GetMetrics returns the global metrics instance
|
||||
func GetMetrics() *Metrics {
|
||||
once.Do(func() {
|
||||
globalMetrics = NewMetrics()
|
||||
})
|
||||
return globalMetrics
|
||||
}
|
||||
|
||||
// Middleware creates a Gin middleware for collecting HTTP metrics
|
||||
func Middleware(logger *zap.Logger) gin.HandlerFunc {
|
||||
metrics := GetMetrics()
|
||||
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
// Increment in-flight requests
|
||||
metrics.RequestsInFlight.Inc()
|
||||
defer metrics.RequestsInFlight.Dec()
|
||||
|
||||
// Process request
|
||||
c.Next()
|
||||
|
||||
// Record metrics
|
||||
duration := time.Since(start).Seconds()
|
||||
status := strconv.Itoa(c.Writer.Status())
|
||||
method := c.Request.Method
|
||||
path := c.FullPath()
|
||||
|
||||
// Increment total requests
|
||||
metrics.RequestsTotal.Add(1)
|
||||
|
||||
// Record request duration
|
||||
metrics.RequestDuration.Observe(duration)
|
||||
|
||||
// Record response size
|
||||
metrics.ResponseSize.Observe(float64(c.Writer.Size()))
|
||||
|
||||
// Record errors
|
||||
if c.Writer.Status() >= 400 {
|
||||
metrics.ErrorsTotal.Add(1)
|
||||
}
|
||||
|
||||
// Log metrics
|
||||
logger.Debug("HTTP request metrics",
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
zap.String("status", status),
|
||||
zap.Float64("duration", duration),
|
||||
zap.Int("size", c.Writer.Size()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordTokenCreation records a token creation event
|
||||
func RecordTokenCreation(tokenType string) {
|
||||
metrics := GetMetrics()
|
||||
metrics.TokensCreated.Inc()
|
||||
}
|
||||
|
||||
// RecordTokenVerification records a token verification event
|
||||
func RecordTokenVerification(tokenType string, success bool) {
|
||||
metrics := GetMetrics()
|
||||
metrics.TokensVerified.Inc()
|
||||
}
|
||||
|
||||
// RecordTokenRevocation records a token revocation event
|
||||
func RecordTokenRevocation(tokenType string) {
|
||||
metrics := GetMetrics()
|
||||
metrics.TokensRevoked.Inc()
|
||||
}
|
||||
|
||||
// RecordDatabaseQuery records a database query
|
||||
func RecordDatabaseQuery(operation string, success bool) {
|
||||
metrics := GetMetrics()
|
||||
metrics.DatabaseQueries.Inc()
|
||||
|
||||
if !success {
|
||||
metrics.DatabaseErrors.Inc()
|
||||
}
|
||||
}
|
||||
|
||||
// RecordCacheHit records a cache hit
|
||||
func RecordCacheHit() {
|
||||
metrics := GetMetrics()
|
||||
metrics.CacheHits.Inc()
|
||||
}
|
||||
|
||||
// RecordCacheMiss records a cache miss
|
||||
func RecordCacheMiss() {
|
||||
metrics := GetMetrics()
|
||||
metrics.CacheMisses.Inc()
|
||||
}
|
||||
|
||||
// UpdateApplicationCount updates the total number of applications
|
||||
func UpdateApplicationCount(count int) {
|
||||
metrics := GetMetrics()
|
||||
metrics.ApplicationsTotal.Set(float64(count))
|
||||
}
|
||||
|
||||
// UpdatePermissionCount updates the total number of permissions
|
||||
func UpdatePermissionCount(count int) {
|
||||
metrics := GetMetrics()
|
||||
metrics.PermissionsTotal.Set(float64(count))
|
||||
}
|
||||
|
||||
// UpdateDatabaseConnections updates the number of database connections
|
||||
func UpdateDatabaseConnections(count int) {
|
||||
metrics := GetMetrics()
|
||||
metrics.DatabaseConnections.Set(float64(count))
|
||||
}
|
||||
|
||||
// PrometheusHandler returns an HTTP handler that exports metrics in Prometheus format
|
||||
func PrometheusHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
metrics := GetMetrics()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
|
||||
|
||||
// Export all metrics in Prometheus format
|
||||
exportCounter(w, "http_requests_total", metrics.RequestsTotal)
|
||||
exportGauge(w, "http_requests_in_flight", metrics.RequestsInFlight)
|
||||
exportHistogram(w, "http_request_duration_seconds", metrics.RequestDuration)
|
||||
exportHistogram(w, "http_response_size_bytes", metrics.ResponseSize)
|
||||
|
||||
exportCounter(w, "tokens_created_total", metrics.TokensCreated)
|
||||
exportCounter(w, "tokens_verified_total", metrics.TokensVerified)
|
||||
exportCounter(w, "tokens_revoked_total", metrics.TokensRevoked)
|
||||
exportGauge(w, "applications_total", metrics.ApplicationsTotal)
|
||||
exportGauge(w, "permissions_total", metrics.PermissionsTotal)
|
||||
|
||||
exportGauge(w, "database_connections", metrics.DatabaseConnections)
|
||||
exportCounter(w, "database_queries_total", metrics.DatabaseQueries)
|
||||
exportCounter(w, "database_errors_total", metrics.DatabaseErrors)
|
||||
exportCounter(w, "cache_hits_total", metrics.CacheHits)
|
||||
exportCounter(w, "cache_misses_total", metrics.CacheMisses)
|
||||
|
||||
exportCounter(w, "errors_total", metrics.ErrorsTotal)
|
||||
}
|
||||
}
|
||||
|
||||
func exportCounter(w http.ResponseWriter, name string, counter *Counter) {
|
||||
w.Write([]byte("# HELP " + name + " Total number of " + name + "\n"))
|
||||
w.Write([]byte("# TYPE " + name + " counter\n"))
|
||||
w.Write([]byte(name + " " + strconv.FormatFloat(counter.Value(), 'f', -1, 64) + "\n"))
|
||||
}
|
||||
|
||||
func exportGauge(w http.ResponseWriter, name string, gauge *Gauge) {
|
||||
w.Write([]byte("# HELP " + name + " Current value of " + name + "\n"))
|
||||
w.Write([]byte("# TYPE " + name + " gauge\n"))
|
||||
w.Write([]byte(name + " " + strconv.FormatFloat(gauge.Value(), 'f', -1, 64) + "\n"))
|
||||
}
|
||||
|
||||
func exportHistogram(w http.ResponseWriter, name string, histogram *Histogram) {
|
||||
w.Write([]byte("# HELP " + name + " Histogram of " + name + "\n"))
|
||||
w.Write([]byte("# TYPE " + name + " histogram\n"))
|
||||
|
||||
buckets := histogram.Buckets()
|
||||
for bucket, count := range buckets {
|
||||
w.Write([]byte(name + "_bucket{le=\"" + strconv.FormatFloat(bucket, 'f', -1, 64) + "\"} " + strconv.FormatFloat(count, 'f', -1, 64) + "\n"))
|
||||
}
|
||||
|
||||
w.Write([]byte(name + "_sum " + strconv.FormatFloat(histogram.Sum(), 'f', -1, 64) + "\n"))
|
||||
w.Write([]byte(name + "_count " + strconv.FormatFloat(histogram.Count(), 'f', -1, 64) + "\n"))
|
||||
}
|
||||
|
||||
// HealthMetrics represents health check metrics
|
||||
type HealthMetrics struct {
|
||||
DatabaseConnected bool `json:"database_connected"`
|
||||
ResponseTime time.Duration `json:"response_time"`
|
||||
Uptime time.Duration `json:"uptime"`
|
||||
Version string `json:"version"`
|
||||
Environment string `json:"environment"`
|
||||
}
|
||||
|
||||
// GetHealthMetrics returns current health metrics
|
||||
func GetHealthMetrics(ctx context.Context, version, environment string, startTime time.Time) *HealthMetrics {
|
||||
return &HealthMetrics{
|
||||
DatabaseConnected: true, // This should be checked against actual DB
|
||||
ResponseTime: time.Since(time.Now()),
|
||||
Uptime: time.Since(startTime),
|
||||
Version: version,
|
||||
Environment: environment,
|
||||
}
|
||||
}
|
||||
|
||||
// BusinessMetrics represents business-specific metrics
|
||||
type BusinessMetrics struct {
|
||||
TotalApplications int `json:"total_applications"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
TotalPermissions int `json:"total_permissions"`
|
||||
ActiveTokens int `json:"active_tokens"`
|
||||
}
|
||||
|
||||
// GetBusinessMetrics returns current business metrics
|
||||
func GetBusinessMetrics() *BusinessMetrics {
|
||||
metrics := GetMetrics()
|
||||
|
||||
return &BusinessMetrics{
|
||||
TotalApplications: int(metrics.ApplicationsTotal.Value()),
|
||||
TotalTokens: int(metrics.TokensCreated.Value()),
|
||||
TotalPermissions: int(metrics.PermissionsTotal.Value()),
|
||||
ActiveTokens: int(metrics.TokensCreated.Value() - metrics.TokensRevoked.Value()),
|
||||
}
|
||||
}
|
||||
265
internal/middleware/validation.go
Normal file
265
internal/middleware/validation.go
Normal file
@ -0,0 +1,265 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ValidationError represents a validation error
|
||||
type ValidationError struct {
|
||||
Field string `json:"field"`
|
||||
Tag string `json:"tag"`
|
||||
Value string `json:"value"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ValidationResponse represents the validation error response
|
||||
type ValidationResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Details []ValidationError `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
var validate *validator.Validate
|
||||
|
||||
func init() {
|
||||
validate = validator.New()
|
||||
|
||||
// Register custom tag name function to use json tags
|
||||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
|
||||
if name == "-" {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateJSON validates JSON request body against struct validation tags
|
||||
func ValidateJSON(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip validation for GET requests and requests without body
|
||||
if c.Request.Method == "GET" || c.Request.ContentLength == 0 {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Store original body for potential re-reading
|
||||
c.Set("validation_enabled", true)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateStruct validates a struct and returns formatted errors
|
||||
func ValidateStruct(s interface{}) []ValidationError {
|
||||
var errors []ValidationError
|
||||
|
||||
err := validate.Struct(s)
|
||||
if err != nil {
|
||||
for _, err := range err.(validator.ValidationErrors) {
|
||||
var element ValidationError
|
||||
element.Field = err.Field()
|
||||
element.Tag = err.Tag()
|
||||
element.Value = err.Param()
|
||||
element.Message = getErrorMessage(err)
|
||||
errors = append(errors, element)
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// ValidateAndBind validates and binds JSON request to struct
|
||||
func ValidateAndBind(c *gin.Context, obj interface{}) error {
|
||||
// Bind JSON to struct
|
||||
if err := c.ShouldBindJSON(obj); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid JSON",
|
||||
Message: "Request body contains invalid JSON: " + err.Error(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate struct
|
||||
if validationErrors := ValidateStruct(obj); len(validationErrors) > 0 {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Validation Failed",
|
||||
Message: "Request validation failed",
|
||||
Details: validationErrors,
|
||||
})
|
||||
return validator.ValidationErrors{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getErrorMessage returns a human-readable error message for validation errors
|
||||
func getErrorMessage(fe validator.FieldError) string {
|
||||
switch fe.Tag() {
|
||||
case "required":
|
||||
return "This field is required"
|
||||
case "email":
|
||||
return "Invalid email format"
|
||||
case "min":
|
||||
return "Value is too short (minimum " + fe.Param() + " characters)"
|
||||
case "max":
|
||||
return "Value is too long (maximum " + fe.Param() + " characters)"
|
||||
case "url":
|
||||
return "Invalid URL format"
|
||||
case "oneof":
|
||||
return "Value must be one of: " + fe.Param()
|
||||
case "uuid":
|
||||
return "Invalid UUID format"
|
||||
case "gte":
|
||||
return "Value must be greater than or equal to " + fe.Param()
|
||||
case "lte":
|
||||
return "Value must be less than or equal to " + fe.Param()
|
||||
case "len":
|
||||
return "Value must be exactly " + fe.Param() + " characters"
|
||||
case "dive":
|
||||
return "Invalid array element"
|
||||
default:
|
||||
return "Invalid value for " + fe.Field()
|
||||
}
|
||||
}
|
||||
|
||||
// RequiredFields validates that specific fields are present in the request
|
||||
func RequiredFields(fields ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var json map[string]interface{}
|
||||
|
||||
if err := c.ShouldBindJSON(&json); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid JSON",
|
||||
Message: "Request body contains invalid JSON",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
var missingFields []string
|
||||
for _, field := range fields {
|
||||
if _, exists := json[field]; !exists {
|
||||
missingFields = append(missingFields, field)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingFields) > 0 {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Missing Required Fields",
|
||||
Message: "The following required fields are missing: " + strings.Join(missingFields, ", "),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Store the parsed JSON for use in handlers
|
||||
c.Set("parsed_json", json)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateUUID validates that a URL parameter is a valid UUID
|
||||
func ValidateUUID(param string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
value := c.Param(param)
|
||||
if value == "" {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Missing Parameter",
|
||||
Message: "Required parameter '" + param + "' is missing",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate UUID format
|
||||
if err := validate.Var(value, "uuid"); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid Parameter",
|
||||
Message: "Parameter '" + param + "' must be a valid UUID",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateQueryParams validates query parameters
|
||||
func ValidateQueryParams(rules map[string]string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var errors []ValidationError
|
||||
|
||||
for param, rule := range rules {
|
||||
value := c.Query(param)
|
||||
if value != "" {
|
||||
if err := validate.Var(value, rule); err != nil {
|
||||
for _, err := range err.(validator.ValidationErrors) {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: param,
|
||||
Tag: err.Tag(),
|
||||
Value: err.Param(),
|
||||
Message: getErrorMessage(err),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||
Error: "Invalid Query Parameters",
|
||||
Message: "One or more query parameters are invalid",
|
||||
Details: errors,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeInput sanitizes input strings to prevent XSS and injection attacks
|
||||
func SanitizeInput() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// This is a basic implementation - in production you might want to use
|
||||
// a more sophisticated sanitization library like bluemonday
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePermissions validates that permission scopes follow the expected format
|
||||
func ValidatePermissions(c *gin.Context, permissions []string) []ValidationError {
|
||||
var errors []ValidationError
|
||||
|
||||
for i, perm := range permissions {
|
||||
// Check basic format: should contain only alphanumeric, dots, and underscores
|
||||
if err := validate.Var(perm, "required,min=1,max=255,alphanum|contains=.|contains=_"); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "permissions[" + string(rune(i)) + "]",
|
||||
Tag: "format",
|
||||
Value: perm,
|
||||
Message: "Permission scope must contain only alphanumeric characters, dots, and underscores",
|
||||
})
|
||||
}
|
||||
|
||||
// Check for dangerous patterns
|
||||
if strings.Contains(perm, "..") || strings.HasPrefix(perm, ".") || strings.HasSuffix(perm, ".") {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "permissions[" + string(rune(i)) + "]",
|
||||
Tag: "format",
|
||||
Value: perm,
|
||||
Message: "Permission scope has invalid format",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
@ -2,10 +2,14 @@ package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// PermissionRepository implements the PermissionRepository interface for PostgreSQL
|
||||
@ -20,20 +24,116 @@ func NewPermissionRepository(db repository.DatabaseProvider) repository.Permissi
|
||||
|
||||
// CreateAvailablePermission creates a new available permission
|
||||
func (r *PermissionRepository) CreateAvailablePermission(ctx context.Context, permission *domain.AvailablePermission) error {
|
||||
// TODO: Implement actual permission creation
|
||||
query := `
|
||||
INSERT INTO available_permissions (
|
||||
id, scope, name, description, category, parent_scope,
|
||||
is_system, created_by, updated_by, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
now := time.Now()
|
||||
|
||||
if permission.ID == uuid.Nil {
|
||||
permission.ID = uuid.New()
|
||||
}
|
||||
|
||||
_, err := db.ExecContext(ctx, query,
|
||||
permission.ID,
|
||||
permission.Scope,
|
||||
permission.Name,
|
||||
permission.Description,
|
||||
permission.Category,
|
||||
permission.ParentScope,
|
||||
permission.IsSystem,
|
||||
permission.CreatedBy,
|
||||
permission.UpdatedBy,
|
||||
now,
|
||||
now,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create available permission: %w", err)
|
||||
}
|
||||
|
||||
permission.CreatedAt = now
|
||||
permission.UpdatedAt = now
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAvailablePermission retrieves an available permission by ID
|
||||
func (r *PermissionRepository) GetAvailablePermission(ctx context.Context, permissionID uuid.UUID) (*domain.AvailablePermission, error) {
|
||||
// TODO: Implement actual permission retrieval
|
||||
return nil, nil
|
||||
query := `
|
||||
SELECT id, scope, name, description, category, parent_scope,
|
||||
is_system, created_at, created_by, updated_at, updated_by
|
||||
FROM available_permissions
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
row := db.QueryRowContext(ctx, query, permissionID)
|
||||
|
||||
permission := &domain.AvailablePermission{}
|
||||
err := row.Scan(
|
||||
&permission.ID,
|
||||
&permission.Scope,
|
||||
&permission.Name,
|
||||
&permission.Description,
|
||||
&permission.Category,
|
||||
&permission.ParentScope,
|
||||
&permission.IsSystem,
|
||||
&permission.CreatedAt,
|
||||
&permission.CreatedBy,
|
||||
&permission.UpdatedAt,
|
||||
&permission.UpdatedBy,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("permission with ID '%s' not found", permissionID)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get available permission: %w", err)
|
||||
}
|
||||
|
||||
return permission, nil
|
||||
}
|
||||
|
||||
// GetAvailablePermissionByScope retrieves an available permission by scope
|
||||
func (r *PermissionRepository) GetAvailablePermissionByScope(ctx context.Context, scope string) (*domain.AvailablePermission, error) {
|
||||
// TODO: Implement actual permission retrieval by scope
|
||||
return nil, nil
|
||||
query := `
|
||||
SELECT id, scope, name, description, category, parent_scope,
|
||||
is_system, created_at, created_by, updated_at, updated_by
|
||||
FROM available_permissions
|
||||
WHERE scope = $1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
row := db.QueryRowContext(ctx, query, scope)
|
||||
|
||||
permission := &domain.AvailablePermission{}
|
||||
err := row.Scan(
|
||||
&permission.ID,
|
||||
&permission.Scope,
|
||||
&permission.Name,
|
||||
&permission.Description,
|
||||
&permission.Category,
|
||||
&permission.ParentScope,
|
||||
&permission.IsSystem,
|
||||
&permission.CreatedAt,
|
||||
&permission.CreatedBy,
|
||||
&permission.UpdatedAt,
|
||||
&permission.UpdatedBy,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("permission with scope '%s' not found", scope)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get available permission by scope: %w", err)
|
||||
}
|
||||
|
||||
return permission, nil
|
||||
}
|
||||
|
||||
// ListAvailablePermissions retrieves available permissions with pagination and filtering
|
||||
@ -56,9 +156,44 @@ func (r *PermissionRepository) DeleteAvailablePermission(ctx context.Context, pe
|
||||
|
||||
// ValidatePermissionScopes checks if all given scopes exist and are valid
|
||||
func (r *PermissionRepository) ValidatePermissionScopes(ctx context.Context, scopes []string) ([]string, error) {
|
||||
// TODO: Implement actual scope validation
|
||||
// For now, assume all scopes are valid
|
||||
return []string{}, nil
|
||||
if len(scopes) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT scope
|
||||
FROM available_permissions
|
||||
WHERE scope = ANY($1)
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, pq.Array(scopes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate permission scopes: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
validScopes := make(map[string]bool)
|
||||
for rows.Next() {
|
||||
var scope string
|
||||
if err := rows.Scan(&scope); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan scope: %w", err)
|
||||
}
|
||||
validScopes[scope] = true
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating scopes: %w", err)
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, scope := range scopes {
|
||||
if validScopes[scope] {
|
||||
result = append(result, scope)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetPermissionHierarchy returns all parent and child permissions for given scopes
|
||||
@ -79,7 +214,56 @@ func NewGrantedPermissionRepository(db repository.DatabaseProvider) repository.G
|
||||
|
||||
// GrantPermissions grants multiple permissions to a token
|
||||
func (r *GrantedPermissionRepository) GrantPermissions(ctx context.Context, grants []*domain.GrantedPermission) error {
|
||||
// TODO: Implement actual permission granting
|
||||
if len(grants) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
query := `
|
||||
INSERT INTO granted_permissions (
|
||||
id, token_type, token_id, permission_id, scope, created_by, created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (token_type, token_id, permission_id) DO NOTHING
|
||||
`
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare statement: %w", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
now := time.Now()
|
||||
for _, grant := range grants {
|
||||
if grant.ID == uuid.Nil {
|
||||
grant.ID = uuid.New()
|
||||
}
|
||||
|
||||
_, err = stmt.ExecContext(ctx,
|
||||
grant.ID,
|
||||
string(grant.TokenType),
|
||||
grant.TokenID,
|
||||
grant.PermissionID,
|
||||
grant.Scope,
|
||||
grant.CreatedBy,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to grant permission: %w", err)
|
||||
}
|
||||
|
||||
grant.CreatedAt = now
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -109,16 +293,72 @@ func (r *GrantedPermissionRepository) RevokeAllPermissions(ctx context.Context,
|
||||
|
||||
// HasPermission checks if a token has a specific permission
|
||||
func (r *GrantedPermissionRepository) HasPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scope string) (bool, error) {
|
||||
// TODO: Implement actual permission checking
|
||||
query := `
|
||||
SELECT 1
|
||||
FROM granted_permissions gp
|
||||
JOIN available_permissions ap ON gp.permission_id = ap.id
|
||||
WHERE gp.token_type = $1
|
||||
AND gp.token_id = $2
|
||||
AND gp.scope = $3
|
||||
AND gp.revoked = false
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
var exists int
|
||||
err := db.QueryRowContext(ctx, query, string(tokenType), tokenID, scope).Scan(&exists)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to check permission: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// HasAnyPermission checks if a token has any of the specified permissions
|
||||
func (r *GrantedPermissionRepository) HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error) {
|
||||
// TODO: Implement actual permission checking
|
||||
if len(scopes) == 0 {
|
||||
return make(map[string]bool), nil
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT gp.scope
|
||||
FROM granted_permissions gp
|
||||
JOIN available_permissions ap ON gp.permission_id = ap.id
|
||||
WHERE gp.token_type = $1
|
||||
AND gp.token_id = $2
|
||||
AND gp.scope = ANY($3)
|
||||
AND gp.revoked = false
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, string(tokenType), tokenID, pq.Array(scopes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check permissions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := make(map[string]bool)
|
||||
// Initialize all scopes as false
|
||||
for _, scope := range scopes {
|
||||
result[scope] = false
|
||||
}
|
||||
|
||||
// Mark found permissions as true
|
||||
for rows.Next() {
|
||||
var scope string
|
||||
if err := rows.Scan(&scope); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan permission scope: %w", err)
|
||||
}
|
||||
result[scope] = true
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating permission results: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@ -57,26 +57,196 @@ func (r *StaticTokenRepository) Create(ctx context.Context, token *domain.Static
|
||||
|
||||
// GetByID retrieves a static token by its ID
|
||||
func (r *StaticTokenRepository) GetByID(ctx context.Context, tokenID uuid.UUID) (*domain.StaticToken, error) {
|
||||
// TODO: Implement actual token retrieval
|
||||
return nil, nil
|
||||
query := `
|
||||
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||
key_hash, type, created_at, updated_at
|
||||
FROM static_tokens
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
row := db.QueryRowContext(ctx, query, tokenID)
|
||||
|
||||
token := &domain.StaticToken{}
|
||||
var ownerType, ownerName, ownerOwner string
|
||||
|
||||
err := row.Scan(
|
||||
&token.ID,
|
||||
&token.AppID,
|
||||
&ownerType,
|
||||
&ownerName,
|
||||
&ownerOwner,
|
||||
&token.KeyHash,
|
||||
&token.Type,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("static token with ID '%s' not found", tokenID)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get static token: %w", err)
|
||||
}
|
||||
|
||||
token.Owner = domain.Owner{
|
||||
Type: domain.OwnerType(ownerType),
|
||||
Name: ownerName,
|
||||
Owner: ownerOwner,
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetByKeyHash retrieves a static token by its key hash
|
||||
func (r *StaticTokenRepository) GetByKeyHash(ctx context.Context, keyHash string) (*domain.StaticToken, error) {
|
||||
// TODO: Implement actual token retrieval by hash
|
||||
return nil, nil
|
||||
query := `
|
||||
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||
key_hash, type, created_at, updated_at
|
||||
FROM static_tokens
|
||||
WHERE key_hash = $1
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
row := db.QueryRowContext(ctx, query, keyHash)
|
||||
|
||||
token := &domain.StaticToken{}
|
||||
var ownerType, ownerName, ownerOwner string
|
||||
|
||||
err := row.Scan(
|
||||
&token.ID,
|
||||
&token.AppID,
|
||||
&ownerType,
|
||||
&ownerName,
|
||||
&ownerOwner,
|
||||
&token.KeyHash,
|
||||
&token.Type,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("static token with hash not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get static token by hash: %w", err)
|
||||
}
|
||||
|
||||
token.Owner = domain.Owner{
|
||||
Type: domain.OwnerType(ownerType),
|
||||
Name: ownerName,
|
||||
Owner: ownerOwner,
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetByAppID retrieves all static tokens for an application
|
||||
func (r *StaticTokenRepository) GetByAppID(ctx context.Context, appID string) ([]*domain.StaticToken, error) {
|
||||
// TODO: Implement actual token listing
|
||||
return []*domain.StaticToken{}, nil
|
||||
query := `
|
||||
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||
key_hash, type, created_at, updated_at
|
||||
FROM static_tokens
|
||||
WHERE app_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, appID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query static tokens: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tokens []*domain.StaticToken
|
||||
for rows.Next() {
|
||||
token := &domain.StaticToken{}
|
||||
var ownerType, ownerName, ownerOwner string
|
||||
|
||||
err := rows.Scan(
|
||||
&token.ID,
|
||||
&token.AppID,
|
||||
&ownerType,
|
||||
&ownerName,
|
||||
&ownerOwner,
|
||||
&token.KeyHash,
|
||||
&token.Type,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan static token: %w", err)
|
||||
}
|
||||
|
||||
token.Owner = domain.Owner{
|
||||
Type: domain.OwnerType(ownerType),
|
||||
Name: ownerName,
|
||||
Owner: ownerOwner,
|
||||
}
|
||||
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating static tokens: %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// List retrieves static tokens with pagination
|
||||
func (r *StaticTokenRepository) List(ctx context.Context, limit, offset int) ([]*domain.StaticToken, error) {
|
||||
// TODO: Implement actual token listing
|
||||
return []*domain.StaticToken{}, nil
|
||||
query := `
|
||||
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||
key_hash, type, created_at, updated_at
|
||||
FROM static_tokens
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
`
|
||||
|
||||
db := r.db.GetDB().(*sql.DB)
|
||||
rows, err := db.QueryContext(ctx, query, limit, offset)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query static tokens: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tokens []*domain.StaticToken
|
||||
for rows.Next() {
|
||||
token := &domain.StaticToken{}
|
||||
var ownerType, ownerName, ownerOwner string
|
||||
|
||||
err := rows.Scan(
|
||||
&token.ID,
|
||||
&token.AppID,
|
||||
&ownerType,
|
||||
&ownerName,
|
||||
&ownerOwner,
|
||||
&token.KeyHash,
|
||||
&token.Type,
|
||||
&token.CreatedAt,
|
||||
&token.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan static token: %w", err)
|
||||
}
|
||||
|
||||
token.Owner = domain.Owner{
|
||||
Type: domain.OwnerType(ownerType),
|
||||
Name: ownerName,
|
||||
Owner: ownerOwner,
|
||||
}
|
||||
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating static tokens: %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// Delete deletes a static token
|
||||
|
||||
@ -8,17 +8,19 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/kms/api-key-service/internal/crypto"
|
||||
"github.com/kms/api-key-service/internal/domain"
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
)
|
||||
|
||||
// tokenService implements the TokenService interface
|
||||
type tokenService struct {
|
||||
tokenRepo repository.StaticTokenRepository
|
||||
appRepo repository.ApplicationRepository
|
||||
permRepo repository.PermissionRepository
|
||||
grantRepo repository.GrantedPermissionRepository
|
||||
logger *zap.Logger
|
||||
tokenRepo repository.StaticTokenRepository
|
||||
appRepo repository.ApplicationRepository
|
||||
permRepo repository.PermissionRepository
|
||||
grantRepo repository.GrantedPermissionRepository
|
||||
tokenGen *crypto.TokenGenerator
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewTokenService creates a new token service
|
||||
@ -27,6 +29,7 @@ func NewTokenService(
|
||||
appRepo repository.ApplicationRepository,
|
||||
permRepo repository.PermissionRepository,
|
||||
grantRepo repository.GrantedPermissionRepository,
|
||||
hmacKey string,
|
||||
logger *zap.Logger,
|
||||
) TokenService {
|
||||
return &tokenService{
|
||||
@ -34,6 +37,7 @@ func NewTokenService(
|
||||
appRepo: appRepo,
|
||||
permRepo: permRepo,
|
||||
grantRepo: grantRepo,
|
||||
tokenGen: crypto.NewTokenGenerator(hmacKey),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
@ -42,10 +46,33 @@ func NewTokenService(
|
||||
func (s *tokenService) CreateStaticToken(ctx context.Context, req *domain.CreateStaticTokenRequest, userID string) (*domain.CreateStaticTokenResponse, error) {
|
||||
s.logger.Info("Creating static token", zap.String("app_id", req.AppID), zap.String("user_id", userID))
|
||||
|
||||
// TODO: Validate permissions
|
||||
// TODO: Validate application exists
|
||||
// TODO: Generate secure token
|
||||
// TODO: Grant permissions
|
||||
// Validate application exists
|
||||
app, err := s.appRepo.GetByID(ctx, req.AppID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get application", zap.Error(err), zap.String("app_id", req.AppID))
|
||||
return nil, fmt.Errorf("application not found: %w", err)
|
||||
}
|
||||
|
||||
// Validate permissions exist
|
||||
validPermissions, err := s.permRepo.ValidatePermissionScopes(ctx, req.Permissions)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to validate permissions", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to validate permissions: %w", err)
|
||||
}
|
||||
|
||||
if len(validPermissions) != len(req.Permissions) {
|
||||
s.logger.Warn("Some permissions are invalid",
|
||||
zap.Strings("requested", req.Permissions),
|
||||
zap.Strings("valid", validPermissions))
|
||||
return nil, fmt.Errorf("some requested permissions are invalid")
|
||||
}
|
||||
|
||||
// Generate secure token
|
||||
tokenInfo, err := s.tokenGen.GenerateTokenWithInfo()
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to generate secure token", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
tokenID := uuid.New()
|
||||
now := time.Now()
|
||||
@ -54,32 +81,62 @@ func (s *tokenService) CreateStaticToken(ctx context.Context, req *domain.Create
|
||||
token := &domain.StaticToken{
|
||||
ID: tokenID,
|
||||
AppID: req.AppID,
|
||||
Owner: domain.Owner{
|
||||
Type: domain.OwnerTypeIndividual,
|
||||
Name: userID,
|
||||
Owner: userID,
|
||||
},
|
||||
KeyHash: "placeholder-hash-" + tokenID.String(),
|
||||
Type: "hmac",
|
||||
Owner: req.Owner,
|
||||
KeyHash: tokenInfo.Hash,
|
||||
Type: "hmac",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
// Save the token to the database
|
||||
err := s.tokenRepo.Create(ctx, token)
|
||||
err = s.tokenRepo.Create(ctx, token)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to create token in database", zap.Error(err), zap.String("token_id", tokenID.String()))
|
||||
return nil, fmt.Errorf("failed to create token: %w", err)
|
||||
}
|
||||
|
||||
// Grant permissions to the token
|
||||
var grants []*domain.GrantedPermission
|
||||
for _, permScope := range validPermissions {
|
||||
// Get permission by scope to get the ID
|
||||
perm, err := s.permRepo.GetAvailablePermissionByScope(ctx, permScope)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get permission by scope", zap.Error(err), zap.String("scope", permScope))
|
||||
continue
|
||||
}
|
||||
|
||||
grant := &domain.GrantedPermission{
|
||||
ID: uuid.New(),
|
||||
TokenType: domain.TokenTypeStatic,
|
||||
TokenID: tokenID,
|
||||
PermissionID: perm.ID,
|
||||
Scope: permScope,
|
||||
CreatedBy: userID,
|
||||
}
|
||||
grants = append(grants, grant)
|
||||
}
|
||||
|
||||
if len(grants) > 0 {
|
||||
err = s.grantRepo.GrantPermissions(ctx, grants)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to grant permissions", zap.Error(err))
|
||||
// Clean up the token if permission granting fails
|
||||
s.tokenRepo.Delete(ctx, tokenID)
|
||||
return nil, fmt.Errorf("failed to grant permissions: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
response := &domain.CreateStaticTokenResponse{
|
||||
ID: tokenID,
|
||||
Token: "static-token-placeholder-" + tokenID.String(),
|
||||
Permissions: req.Permissions,
|
||||
Token: tokenInfo.Token, // Return the actual token only once
|
||||
Permissions: validPermissions,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
s.logger.Info("Static token created successfully", zap.String("token_id", tokenID.String()))
|
||||
s.logger.Info("Static token created successfully",
|
||||
zap.String("token_id", tokenID.String()),
|
||||
zap.String("app_id", app.AppID),
|
||||
zap.Strings("permissions", validPermissions))
|
||||
return response, nil
|
||||
}
|
||||
|
||||
|
||||
@ -267,10 +267,12 @@ test_token_endpoints() {
|
||||
-H "X-User-Email: $USER_EMAIL" \
|
||||
-H "X-User-ID: $USER_ID" \
|
||||
-d '{
|
||||
"name": "Test Static Token for Deletion",
|
||||
"description": "A test static token for deletion test",
|
||||
"permissions": ["read"],
|
||||
"expires_at": "2025-12-31T23:59:59Z"
|
||||
"owner": {
|
||||
"type": "individual",
|
||||
"name": "Test Token Owner",
|
||||
"owner": "test-token@example.com"
|
||||
},
|
||||
"permissions": ["repo.read", "repo.write"]
|
||||
}' 2>/dev/null || echo -e "\n000")
|
||||
|
||||
local token_status_code=$(echo "$token_response" | tail -n1)
|
||||
@ -282,10 +284,12 @@ test_token_endpoints() {
|
||||
-H "X-User-Email: $USER_EMAIL" \
|
||||
-H "X-User-ID: $USER_ID" \
|
||||
-d '{
|
||||
"name": "Test Static Token",
|
||||
"description": "A test static token",
|
||||
"permissions": ["read"],
|
||||
"expires_at": "2025-12-31T23:59:59Z"
|
||||
"owner": {
|
||||
"type": "individual",
|
||||
"name": "Test Token Owner",
|
||||
"owner": "test-token@example.com"
|
||||
},
|
||||
"permissions": ["repo.read", "repo.write"]
|
||||
}'
|
||||
|
||||
# Extract token_id from the first response for deletion test
|
||||
|
||||
@ -96,7 +96,7 @@ func (suite *IntegrationTestSuite) setupServer() {
|
||||
|
||||
// Initialize services
|
||||
appService := services.NewApplicationService(appRepo, logger)
|
||||
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, logger)
|
||||
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, suite.cfg.GetString("INTERNAL_HMAC_KEY"), logger)
|
||||
authService := services.NewAuthenticationService(suite.cfg, logger)
|
||||
|
||||
// Initialize handlers
|
||||
|
||||
@ -460,14 +460,14 @@ func (m *MockPermissionRepository) ValidatePermissionScopes(ctx context.Context,
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var invalid []string
|
||||
var valid []string
|
||||
for _, scope := range scopes {
|
||||
if _, exists := m.scopeIndex[scope]; !exists {
|
||||
invalid = append(invalid, scope)
|
||||
if _, exists := m.scopeIndex[scope]; exists {
|
||||
valid = append(valid, scope)
|
||||
}
|
||||
}
|
||||
|
||||
return invalid, nil
|
||||
return valid, nil
|
||||
}
|
||||
|
||||
func (m *MockPermissionRepository) GetPermissionHierarchy(ctx context.Context, scopes []string) ([]*domain.AvailablePermission, error) {
|
||||
|
||||
Reference in New Issue
Block a user