From 141b1e936d9224f4679f5bce1b32929f59d61b5d Mon Sep 17 00:00:00 2001 From: Ryan Copley Date: Fri, 22 Aug 2025 14:40:59 -0400 Subject: [PATCH] - --- cmd/server/main.go | 22 +- docs/PRODUCTION_ROADMAP.md | 297 +++++++++++++ internal/crypto/token.go | 173 ++++++++ internal/errors/errors.go | 287 ++++++++++++ internal/metrics/metrics.go | 415 ++++++++++++++++++ internal/middleware/validation.go | 265 +++++++++++ .../postgres/permission_repository.go | 262 ++++++++++- .../repository/postgres/token_repository.go | 186 +++++++- internal/services/token_service.go | 97 +++- test/e2e_test.sh | 20 +- test/integration_test.go | 2 +- test/mock_repositories.go | 8 +- 12 files changed, 1973 insertions(+), 61 deletions(-) create mode 100644 docs/PRODUCTION_ROADMAP.md create mode 100644 internal/crypto/token.go create mode 100644 internal/errors/errors.go create mode 100644 internal/metrics/metrics.go create mode 100644 internal/middleware/validation.go diff --git a/cmd/server/main.go b/cmd/server/main.go index 7bbf596..8e97f53 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -15,6 +15,7 @@ import ( "github.com/kms/api-key-service/internal/config" "github.com/kms/api-key-service/internal/database" "github.com/kms/api-key-service/internal/handlers" + "github.com/kms/api-key-service/internal/metrics" "github.com/kms/api-key-service/internal/middleware" "github.com/kms/api-key-service/internal/repository/postgres" "github.com/kms/api-key-service/internal/services" @@ -61,7 +62,7 @@ func main() { // Initialize services appService := services.NewApplicationService(appRepo, logger) - tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, logger) + tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, cfg.GetString("INTERNAL_HMAC_KEY"), logger) authService := services.NewAuthenticationService(cfg, logger) // Initialize handlers @@ -156,6 +157,7 @@ func setupRouter(cfg config.ConfigProvider, logger *zap.Logger, healthHandler *h // Add middleware router.Use(middleware.Logger(logger)) router.Use(middleware.Recovery(logger)) + router.Use(metrics.Middleware(logger)) router.Use(middleware.CORS()) router.Use(middleware.Security()) router.Use(middleware.ValidateContentType()) @@ -226,18 +228,20 @@ func setupRouter(cfg config.ConfigProvider, logger *zap.Logger, healthHandler *h } func startMetricsServer(cfg config.ConfigProvider, logger *zap.Logger) *http.Server { - metricsRouter := gin.New() - metricsRouter.Use(middleware.Logger(logger)) - metricsRouter.Use(middleware.Recovery(logger)) - - // Basic metrics endpoint - metricsRouter.GET("/metrics", func(c *gin.Context) { - c.String(http.StatusOK, "# HELP api_key_service_info Information about the API Key Service\n# TYPE api_key_service_info gauge\napi_key_service_info{version=\"%s\"} 1\n", cfg.GetString("APP_VERSION")) + mux := http.NewServeMux() + + // Prometheus metrics endpoint + mux.HandleFunc("/metrics", metrics.PrometheusHandler()) + + // Health endpoint for metrics server + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) }) srv := &http.Server{ Addr: cfg.GetMetricsAddress(), - Handler: metricsRouter, + Handler: mux, } go func() { diff --git a/docs/PRODUCTION_ROADMAP.md b/docs/PRODUCTION_ROADMAP.md new file mode 100644 index 0000000..719d7a8 --- /dev/null +++ b/docs/PRODUCTION_ROADMAP.md @@ -0,0 +1,297 @@ +# KMS API Service - Production Roadmap + +This document outlines the complete roadmap for making the API Key Management Service fully production-ready. Use the checkboxes to track progress and refer to the implementation notes at the bottom. + +## ๐Ÿ—๏ธ Core Infrastructure (COMPLETED) + +### Repository Layer +- [x] Complete token repository implementation (CRUD operations) +- [x] Complete permission repository implementation (core methods) +- [x] Implement granted permission repository (authorization logic) +- [x] Add database transaction support +- [x] Implement proper error handling in repositories + +### Security & Cryptography +- [x] Implement secure token generation using crypto/rand +- [x] Add bcrypt-based token hashing for storage +- [x] Implement HMAC token signing and verification +- [x] Create token format validation utilities +- [x] Add cryptographic key management + +### Service Layer +- [x] Update token service with secure generation +- [x] Implement permission validation in token creation +- [x] Add application validation before token operations +- [x] Implement proper error propagation +- [x] Add comprehensive logging throughout services + +### Middleware & Validation +- [x] Create comprehensive input validation middleware +- [x] Implement struct-based validation with detailed errors +- [x] Add UUID parameter validation +- [x] Create permission scope format validation +- [x] Implement request sanitization + +### Error Handling +- [x] Create structured error framework with typed codes +- [x] Implement HTTP status code mapping +- [x] Add error context and chaining support +- [x] Create consistent JSON error responses +- [x] Add retry logic indicators + +### Monitoring & Metrics +- [x] Implement comprehensive metrics collection +- [x] Add Prometheus-compatible metrics export +- [x] Create HTTP request monitoring middleware +- [x] Add business metrics tracking +- [x] Implement system health metrics + +## ๐Ÿ” Authentication & Authorization (HIGH PRIORITY) + +### JWT Implementation +- [ ] Complete JWT token generation and validation +- [ ] Implement token expiration and renewal logic +- [ ] Add JWT claims management +- [ ] Create token blacklisting mechanism +- [ ] Implement refresh token rotation + +### SSO Integration +- [ ] Implement OAuth2/OIDC provider integration +- [ ] Add SAML authentication support +- [ ] Create user session management +- [ ] Implement role-based access control (RBAC) +- [ ] Add multi-tenant authentication support + +### Permission System Enhancement +- [ ] Implement hierarchical permission inheritance +- [ ] Add dynamic permission evaluation +- [ ] Create permission caching mechanism +- [ ] Implement permission audit logging +- [ ] Add bulk permission operations + +## ๐Ÿš€ Performance & Scalability (MEDIUM PRIORITY) + +### Caching Layer +- [ ] Implement Redis integration for caching +- [ ] Add permission result caching +- [ ] Create application metadata caching +- [ ] Implement token validation result caching +- [ ] Add cache invalidation strategies + +### Database Optimization +- [ ] Implement database connection pool tuning +- [ ] Add query performance monitoring +- [ ] Create database migration rollback procedures +- [ ] Implement read replica support +- [ ] Add database backup and recovery procedures + +### Load Balancing & Clustering +- [ ] Implement horizontal scaling support +- [ ] Add load balancer health checks +- [ ] Create session affinity handling +- [ ] Implement distributed rate limiting +- [ ] Add circuit breaker patterns + +## ๐Ÿ”’ Security Hardening (HIGH PRIORITY) + +### Advanced Security Features +- [ ] Implement API key rotation mechanisms +- [ ] Add brute force protection +- [ ] Create account lockout mechanisms +- [ ] Implement IP whitelisting/blacklisting +- [ ] Add request signing validation + +### Audit & Compliance +- [ ] Implement comprehensive audit logging +- [ ] Add compliance reporting features +- [ ] Create data retention policies +- [ ] Implement GDPR compliance features +- [ ] Add security event alerting + +### Secrets Management +- [ ] Integrate with HashiCorp Vault or similar +- [ ] Implement automatic key rotation +- [ ] Add encrypted configuration storage +- [ ] Create secure backup procedures +- [ ] Implement key escrow mechanisms + +## ๐Ÿงช Testing & Quality Assurance (MEDIUM PRIORITY) + +### Unit Testing +- [ ] Add comprehensive unit tests for repositories +- [ ] Create service layer unit tests +- [ ] Implement middleware unit tests +- [ ] Add crypto utility unit tests +- [ ] Create error handling unit tests + +### Integration Testing +- [ ] Expand integration test coverage +- [ ] Add database integration tests +- [ ] Create API endpoint integration tests +- [ ] Implement authentication flow tests +- [ ] Add permission validation tests + +### Performance Testing +- [ ] Implement load testing scenarios +- [ ] Add stress testing for concurrent operations +- [ ] Create database performance benchmarks +- [ ] Implement memory leak detection +- [ ] Add latency and throughput testing + +### Security Testing +- [ ] Implement penetration testing scenarios +- [ ] Add vulnerability scanning automation +- [ ] Create security regression tests +- [ ] Implement fuzzing tests +- [ ] Add compliance validation tests + +## ๐Ÿ“ฆ Deployment & Operations (MEDIUM PRIORITY) + +### Containerization & Orchestration +- [ ] Create optimized Docker images +- [ ] Implement Kubernetes manifests +- [ ] Add Helm charts for deployment +- [ ] Create deployment automation scripts +- [ ] Implement blue-green deployment strategy + +### Infrastructure as Code +- [ ] Create Terraform configurations +- [ ] Implement AWS/GCP/Azure resource definitions +- [ ] Add infrastructure testing +- [ ] Create disaster recovery procedures +- [ ] Implement infrastructure monitoring + +### CI/CD Pipeline +- [ ] Implement automated testing pipeline +- [ ] Add security scanning in CI/CD +- [ ] Create automated deployment pipeline +- [ ] Implement rollback mechanisms +- [ ] Add deployment notifications + +## ๐Ÿ“Š Observability & Monitoring (LOW PRIORITY) + +### Advanced Monitoring +- [ ] Implement distributed tracing +- [ ] Add application performance monitoring (APM) +- [ ] Create custom dashboards +- [ ] Implement alerting rules +- [ ] Add log aggregation and analysis + +### Business Intelligence +- [ ] Create usage analytics +- [ ] Implement cost tracking +- [ ] Add capacity planning metrics +- [ ] Create business KPI dashboards +- [ ] Implement trend analysis + +## ๐Ÿ”ง Maintenance & Operations (ONGOING) + +### Documentation +- [ ] Create comprehensive API documentation +- [ ] Add deployment guides +- [ ] Create troubleshooting runbooks +- [ ] Implement architecture decision records (ADRs) +- [ ] Add security best practices guide + +### Maintenance Procedures +- [ ] Create backup and restore procedures +- [ ] Implement log rotation and archival +- [ ] Add database maintenance scripts +- [ ] Create performance tuning guides +- [ ] Implement capacity planning procedures + +--- + +## ๐Ÿ“ Implementation Notes for Future Development + +### Code Organization Principles +1. **Maintain Clean Architecture**: Keep clear separation between domain, service, and infrastructure layers +2. **Interface-First Design**: Always define interfaces before implementations for better testability +3. **Error Handling**: Use the established error framework (`internal/errors`) for consistent error handling +4. **Logging**: Use structured logging with zap throughout the application +5. **Configuration**: Add new config options to `internal/config/config.go` with proper validation + +### Security Guidelines +1. **Input Validation**: Always validate inputs using the validation middleware (`internal/middleware/validation.go`) +2. **Token Security**: Use the crypto utilities (`internal/crypto/token.go`) for all token operations +3. **Permission Checks**: Always validate permissions using the repository layer before operations +4. **Audit Logging**: Log all security-relevant operations with user context +5. **Secrets**: Never hardcode secrets; use environment variables or secret management systems + +### Database Guidelines +1. **Migrations**: Always create both up and down migrations for schema changes +2. **Transactions**: Use database transactions for multi-step operations +3. **Indexing**: Add appropriate indexes for query performance +4. **Connection Management**: Use the existing connection pool configuration +5. **Error Handling**: Wrap database errors with the application error framework + +### Testing Guidelines +1. **Test Structure**: Follow the existing test structure in `test/` directory +2. **Mock Dependencies**: Use interfaces for easy mocking in tests +3. **Test Data**: Use the test helpers for consistent test data creation +4. **Integration Tests**: Test against real database instances when possible +5. **Coverage**: Aim for >80% test coverage for critical paths + +### Performance Guidelines +1. **Metrics**: Use the metrics system (`internal/metrics`) to track performance +2. **Caching**: Implement caching at the service layer, not repository layer +3. **Database Queries**: Optimize queries and use appropriate indexes +4. **Memory Management**: Be mindful of memory allocations in hot paths +5. **Concurrency**: Use proper synchronization for shared resources + +### Deployment Guidelines +1. **Environment Variables**: Use environment-based configuration for all deployments +2. **Health Checks**: Ensure health endpoints are properly configured +3. **Graceful Shutdown**: Implement proper shutdown procedures for all services +4. **Resource Limits**: Set appropriate CPU and memory limits +5. **Monitoring**: Ensure metrics and logging are properly configured + +### Code Quality Standards +1. **Go Standards**: Follow standard Go conventions and best practices +2. **Documentation**: Document all public APIs and complex business logic +3. **Error Messages**: Provide clear, actionable error messages +4. **Code Reviews**: Require code reviews for all changes +5. **Static Analysis**: Use tools like golangci-lint for code quality + +### Security Best Practices +1. **Principle of Least Privilege**: Grant minimum necessary permissions +2. **Defense in Depth**: Implement multiple layers of security +3. **Regular Updates**: Keep dependencies updated for security patches +4. **Secure Defaults**: Use secure configurations by default +5. **Security Testing**: Include security testing in the development process + +### Operational Considerations +1. **Monitoring**: Implement comprehensive monitoring and alerting +2. **Backup Strategy**: Ensure regular backups and test restore procedures +3. **Disaster Recovery**: Have documented disaster recovery procedures +4. **Capacity Planning**: Monitor resource usage and plan for growth +5. **Documentation**: Keep operational documentation up to date + +--- + +## ๐ŸŽฏ Priority Matrix + +### Immediate (Next Sprint) +- Complete JWT implementation +- Add comprehensive unit tests +- Implement caching layer basics + +### Short Term (1-2 Months) +- SSO integration +- Security hardening features +- Performance optimization + +### Medium Term (3-6 Months) +- Advanced monitoring and observability +- Deployment automation +- Compliance features + +### Long Term (6+ Months) +- Advanced analytics +- Multi-region deployment +- Advanced security features + +--- + +*Last Updated: [Current Date]* +*Version: 1.0* diff --git a/internal/crypto/token.go b/internal/crypto/token.go new file mode 100644 index 0000000..a631525 --- /dev/null +++ b/internal/crypto/token.go @@ -0,0 +1,173 @@ +package crypto + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "strings" + "time" + + "golang.org/x/crypto/bcrypt" +) + +const ( + // TokenLength defines the length of generated tokens in bytes + TokenLength = 32 + // TokenPrefix is prepended to all tokens for identification + TokenPrefix = "kms_" +) + +// TokenGenerator provides secure token generation and validation +type TokenGenerator struct { + hmacKey []byte +} + +// NewTokenGenerator creates a new token generator with the provided HMAC key +func NewTokenGenerator(hmacKey string) *TokenGenerator { + return &TokenGenerator{ + hmacKey: []byte(hmacKey), + } +} + +// GenerateSecureToken generates a cryptographically secure random token +func (tg *TokenGenerator) GenerateSecureToken() (string, error) { + // Generate random bytes + tokenBytes := make([]byte, TokenLength) + if _, err := rand.Read(tokenBytes); err != nil { + return "", fmt.Errorf("failed to generate random token: %w", err) + } + + // Encode to base64 for safe transmission + tokenData := base64.URLEncoding.EncodeToString(tokenBytes) + + // Add prefix for identification + token := TokenPrefix + tokenData + + return token, nil +} + +// HashToken creates a secure hash of the token for storage +func (tg *TokenGenerator) HashToken(token string) (string, error) { + // Use bcrypt for secure password-like hashing + hash, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost) + if err != nil { + return "", fmt.Errorf("failed to hash token: %w", err) + } + + return string(hash), nil +} + +// VerifyToken verifies a token against its stored hash +func (tg *TokenGenerator) VerifyToken(token, hash string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(token)) + return err == nil +} + +// GenerateHMACKey generates a new HMAC key for token signing +func GenerateHMACKey() (string, error) { + key := make([]byte, 32) // 256-bit key + if _, err := rand.Read(key); err != nil { + return "", fmt.Errorf("failed to generate HMAC key: %w", err) + } + + return hex.EncodeToString(key), nil +} + +// SignToken creates an HMAC signature for a token +func (tg *TokenGenerator) SignToken(token string, timestamp time.Time) string { + h := hmac.New(sha256.New, tg.hmacKey) + h.Write([]byte(token)) + h.Write([]byte(timestamp.Format(time.RFC3339))) + + signature := h.Sum(nil) + return hex.EncodeToString(signature) +} + +// VerifyTokenSignature verifies an HMAC signature for a token +func (tg *TokenGenerator) VerifyTokenSignature(token, signature string, timestamp time.Time) bool { + expectedSignature := tg.SignToken(token, timestamp) + return hmac.Equal([]byte(signature), []byte(expectedSignature)) +} + +// ExtractTokenFromHeader extracts a token from an Authorization header +func ExtractTokenFromHeader(authHeader string) string { + // Support both "Bearer token" and "token" formats + if strings.HasPrefix(authHeader, "Bearer ") { + return strings.TrimPrefix(authHeader, "Bearer ") + } + return authHeader +} + +// IsValidTokenFormat checks if a token has the expected format +func IsValidTokenFormat(token string) bool { + if !strings.HasPrefix(token, TokenPrefix) { + return false + } + + // Remove prefix and check if remaining part is valid base64 + tokenData := strings.TrimPrefix(token, TokenPrefix) + if len(tokenData) == 0 { + return false + } + + // Try to decode base64 + _, err := base64.URLEncoding.DecodeString(tokenData) + return err == nil +} + +// TokenInfo holds information about a token +type TokenInfo struct { + Token string + Hash string + Signature string + CreatedAt time.Time +} + +// GenerateTokenWithInfo generates a complete token with hash and signature +func (tg *TokenGenerator) GenerateTokenWithInfo() (*TokenInfo, error) { + // Generate the token + token, err := tg.GenerateSecureToken() + if err != nil { + return nil, fmt.Errorf("failed to generate token: %w", err) + } + + // Hash the token for storage + hash, err := tg.HashToken(token) + if err != nil { + return nil, fmt.Errorf("failed to hash token: %w", err) + } + + // Create timestamp and signature + now := time.Now() + signature := tg.SignToken(token, now) + + return &TokenInfo{ + Token: token, + Hash: hash, + Signature: signature, + CreatedAt: now, + }, nil +} + +// ValidateTokenInfo validates a complete token with all its components +func (tg *TokenGenerator) ValidateTokenInfo(token, hash, signature string, createdAt time.Time) error { + // Check token format + if !IsValidTokenFormat(token) { + return fmt.Errorf("invalid token format") + } + + // Verify token against hash + if !tg.VerifyToken(token, hash) { + return fmt.Errorf("token verification failed") + } + + // Verify signature + if !tg.VerifyTokenSignature(token, signature, createdAt) { + return fmt.Errorf("token signature verification failed") + } + + return nil +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go new file mode 100644 index 0000000..22e950a --- /dev/null +++ b/internal/errors/errors.go @@ -0,0 +1,287 @@ +package errors + +import ( + "fmt" + "net/http" +) + +// ErrorCode represents different types of errors in the system +type ErrorCode string + +const ( + // Authentication and Authorization errors + ErrUnauthorized ErrorCode = "UNAUTHORIZED" + ErrForbidden ErrorCode = "FORBIDDEN" + ErrInvalidToken ErrorCode = "INVALID_TOKEN" + ErrTokenExpired ErrorCode = "TOKEN_EXPIRED" + ErrInvalidCredentials ErrorCode = "INVALID_CREDENTIALS" + + // Validation errors + ErrValidationFailed ErrorCode = "VALIDATION_FAILED" + ErrInvalidInput ErrorCode = "INVALID_INPUT" + ErrMissingField ErrorCode = "MISSING_FIELD" + ErrInvalidFormat ErrorCode = "INVALID_FORMAT" + + // Resource errors + ErrNotFound ErrorCode = "NOT_FOUND" + ErrAlreadyExists ErrorCode = "ALREADY_EXISTS" + ErrConflict ErrorCode = "CONFLICT" + + // System errors + ErrInternal ErrorCode = "INTERNAL_ERROR" + ErrDatabase ErrorCode = "DATABASE_ERROR" + ErrExternal ErrorCode = "EXTERNAL_SERVICE_ERROR" + ErrTimeout ErrorCode = "TIMEOUT" + ErrRateLimit ErrorCode = "RATE_LIMIT_EXCEEDED" + + // Business logic errors + ErrInsufficientPermissions ErrorCode = "INSUFFICIENT_PERMISSIONS" + ErrApplicationNotFound ErrorCode = "APPLICATION_NOT_FOUND" + ErrTokenNotFound ErrorCode = "TOKEN_NOT_FOUND" + ErrPermissionNotFound ErrorCode = "PERMISSION_NOT_FOUND" + ErrInvalidApplication ErrorCode = "INVALID_APPLICATION" + ErrTokenCreationFailed ErrorCode = "TOKEN_CREATION_FAILED" +) + +// AppError represents an application error with context +type AppError struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` + Details string `json:"details,omitempty"` + StatusCode int `json:"-"` + Internal error `json:"-"` + Context map[string]interface{} `json:"context,omitempty"` +} + +// Error implements the error interface +func (e *AppError) Error() string { + if e.Internal != nil { + return fmt.Sprintf("%s: %s (internal: %v)", e.Code, e.Message, e.Internal) + } + return fmt.Sprintf("%s: %s", e.Code, e.Message) +} + +// WithContext adds context information to the error +func (e *AppError) WithContext(key string, value interface{}) *AppError { + if e.Context == nil { + e.Context = make(map[string]interface{}) + } + e.Context[key] = value + return e +} + +// WithDetails adds additional details to the error +func (e *AppError) WithDetails(details string) *AppError { + e.Details = details + return e +} + +// WithInternal adds the underlying error +func (e *AppError) WithInternal(err error) *AppError { + e.Internal = err + return e +} + +// New creates a new application error +func New(code ErrorCode, message string) *AppError { + return &AppError{ + Code: code, + Message: message, + StatusCode: getHTTPStatusCode(code), + } +} + +// Wrap wraps an existing error with application error context +func Wrap(err error, code ErrorCode, message string) *AppError { + return &AppError{ + Code: code, + Message: message, + StatusCode: getHTTPStatusCode(code), + Internal: err, + } +} + +// getHTTPStatusCode maps error codes to HTTP status codes +func getHTTPStatusCode(code ErrorCode) int { + switch code { + case ErrUnauthorized, ErrInvalidToken, ErrTokenExpired, ErrInvalidCredentials: + return http.StatusUnauthorized + case ErrForbidden, ErrInsufficientPermissions: + return http.StatusForbidden + case ErrValidationFailed, ErrInvalidInput, ErrMissingField, ErrInvalidFormat: + return http.StatusBadRequest + case ErrNotFound, ErrApplicationNotFound, ErrTokenNotFound, ErrPermissionNotFound: + return http.StatusNotFound + case ErrAlreadyExists, ErrConflict: + return http.StatusConflict + case ErrRateLimit: + return http.StatusTooManyRequests + case ErrTimeout: + return http.StatusRequestTimeout + case ErrInternal, ErrDatabase, ErrExternal, ErrTokenCreationFailed: + return http.StatusInternalServerError + default: + return http.StatusInternalServerError + } +} + +// IsRetryable determines if an error is retryable +func (e *AppError) IsRetryable() bool { + switch e.Code { + case ErrTimeout, ErrExternal, ErrDatabase: + return true + default: + return false + } +} + +// IsClientError determines if an error is a client error (4xx) +func (e *AppError) IsClientError() bool { + return e.StatusCode >= 400 && e.StatusCode < 500 +} + +// IsServerError determines if an error is a server error (5xx) +func (e *AppError) IsServerError() bool { + return e.StatusCode >= 500 +} + +// Common error constructors for frequently used errors + +// NewUnauthorizedError creates an unauthorized error +func NewUnauthorizedError(message string) *AppError { + return New(ErrUnauthorized, message) +} + +// NewForbiddenError creates a forbidden error +func NewForbiddenError(message string) *AppError { + return New(ErrForbidden, message) +} + +// NewValidationError creates a validation error +func NewValidationError(message string) *AppError { + return New(ErrValidationFailed, message) +} + +// NewNotFoundError creates a not found error +func NewNotFoundError(resource string) *AppError { + return New(ErrNotFound, fmt.Sprintf("%s not found", resource)) +} + +// NewAlreadyExistsError creates an already exists error +func NewAlreadyExistsError(resource string) *AppError { + return New(ErrAlreadyExists, fmt.Sprintf("%s already exists", resource)) +} + +// NewInternalError creates an internal server error +func NewInternalError(message string) *AppError { + return New(ErrInternal, message) +} + +// NewDatabaseError creates a database error +func NewDatabaseError(operation string, err error) *AppError { + return Wrap(err, ErrDatabase, fmt.Sprintf("Database operation failed: %s", operation)) +} + +// NewTokenError creates a token-related error +func NewTokenError(message string) *AppError { + return New(ErrInvalidToken, message) +} + +// NewApplicationError creates an application-related error +func NewApplicationError(message string) *AppError { + return New(ErrInvalidApplication, message) +} + +// NewPermissionError creates a permission-related error +func NewPermissionError(message string) *AppError { + return New(ErrInsufficientPermissions, message) +} + +// ErrorResponse represents the JSON error response format +type ErrorResponse struct { + Error string `json:"error"` + Message string `json:"message"` + Code ErrorCode `json:"code"` + Details string `json:"details,omitempty"` + Context map[string]interface{} `json:"context,omitempty"` +} + +// ToResponse converts an AppError to an ErrorResponse +func (e *AppError) ToResponse() ErrorResponse { + return ErrorResponse{ + Error: string(e.Code), + Message: e.Message, + Code: e.Code, + Details: e.Details, + Context: e.Context, + } +} + +// Recovery handles panic recovery and converts to appropriate errors +func Recovery(recovered interface{}) *AppError { + switch v := recovered.(type) { + case *AppError: + return v + case error: + return Wrap(v, ErrInternal, "Internal server error occurred") + case string: + return New(ErrInternal, v) + default: + return New(ErrInternal, "Unknown internal error occurred") + } +} + +// Chain represents a chain of errors for better error tracking +type Chain struct { + errors []*AppError +} + +// NewChain creates a new error chain +func NewChain() *Chain { + return &Chain{ + errors: make([]*AppError, 0), + } +} + +// Add adds an error to the chain +func (c *Chain) Add(err *AppError) *Chain { + c.errors = append(c.errors, err) + return c +} + +// HasErrors returns true if the chain has any errors +func (c *Chain) HasErrors() bool { + return len(c.errors) > 0 +} + +// First returns the first error in the chain +func (c *Chain) First() *AppError { + if len(c.errors) == 0 { + return nil + } + return c.errors[0] +} + +// Last returns the last error in the chain +func (c *Chain) Last() *AppError { + if len(c.errors) == 0 { + return nil + } + return c.errors[len(c.errors)-1] +} + +// All returns all errors in the chain +func (c *Chain) All() []*AppError { + return c.errors +} + +// Error implements the error interface for the chain +func (c *Chain) Error() string { + if len(c.errors) == 0 { + return "no errors" + } + if len(c.errors) == 1 { + return c.errors[0].Error() + } + return fmt.Sprintf("multiple errors: %s (and %d more)", c.errors[0].Error(), len(c.errors)-1) +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go new file mode 100644 index 0000000..139a981 --- /dev/null +++ b/internal/metrics/metrics.go @@ -0,0 +1,415 @@ +package metrics + +import ( + "context" + "net/http" + "strconv" + "sync" + "time" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// Metrics holds all application metrics +type Metrics struct { + // HTTP metrics + RequestsTotal *Counter + RequestDuration *Histogram + RequestsInFlight *Gauge + ResponseSize *Histogram + + // Business metrics + TokensCreated *Counter + TokensVerified *Counter + TokensRevoked *Counter + ApplicationsTotal *Gauge + PermissionsTotal *Gauge + + // System metrics + DatabaseConnections *Gauge + DatabaseQueries *Counter + DatabaseErrors *Counter + CacheHits *Counter + CacheMisses *Counter + + // Error metrics + ErrorsTotal *Counter + + mu sync.RWMutex +} + +// Counter represents a monotonically increasing counter +type Counter struct { + value float64 + labels map[string]string + mu sync.RWMutex +} + +// Gauge represents a value that can go up and down +type Gauge struct { + value float64 + labels map[string]string + mu sync.RWMutex +} + +// Histogram represents a distribution of values +type Histogram struct { + buckets map[float64]float64 + sum float64 + count float64 + labels map[string]string + mu sync.RWMutex +} + +// NewMetrics creates a new metrics instance +func NewMetrics() *Metrics { + return &Metrics{ + // HTTP metrics + RequestsTotal: NewCounter("http_requests_total", map[string]string{}), + RequestDuration: NewHistogram("http_request_duration_seconds", map[string]string{}), + RequestsInFlight: NewGauge("http_requests_in_flight", map[string]string{}), + ResponseSize: NewHistogram("http_response_size_bytes", map[string]string{}), + + // Business metrics + TokensCreated: NewCounter("tokens_created_total", map[string]string{}), + TokensVerified: NewCounter("tokens_verified_total", map[string]string{}), + TokensRevoked: NewCounter("tokens_revoked_total", map[string]string{}), + ApplicationsTotal: NewGauge("applications_total", map[string]string{}), + PermissionsTotal: NewGauge("permissions_total", map[string]string{}), + + // System metrics + DatabaseConnections: NewGauge("database_connections", map[string]string{}), + DatabaseQueries: NewCounter("database_queries_total", map[string]string{}), + DatabaseErrors: NewCounter("database_errors_total", map[string]string{}), + CacheHits: NewCounter("cache_hits_total", map[string]string{}), + CacheMisses: NewCounter("cache_misses_total", map[string]string{}), + + // Error metrics + ErrorsTotal: NewCounter("errors_total", map[string]string{}), + } +} + +// NewCounter creates a new counter +func NewCounter(name string, labels map[string]string) *Counter { + return &Counter{ + value: 0, + labels: labels, + } +} + +// NewGauge creates a new gauge +func NewGauge(name string, labels map[string]string) *Gauge { + return &Gauge{ + value: 0, + labels: labels, + } +} + +// NewHistogram creates a new histogram +func NewHistogram(name string, labels map[string]string) *Histogram { + return &Histogram{ + buckets: make(map[float64]float64), + sum: 0, + count: 0, + labels: labels, + } +} + +// Counter methods +func (c *Counter) Inc() { + c.mu.Lock() + defer c.mu.Unlock() + c.value++ +} + +func (c *Counter) Add(value float64) { + c.mu.Lock() + defer c.mu.Unlock() + c.value += value +} + +func (c *Counter) Value() float64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.value +} + +// Gauge methods +func (g *Gauge) Set(value float64) { + g.mu.Lock() + defer g.mu.Unlock() + g.value = value +} + +func (g *Gauge) Inc() { + g.mu.Lock() + defer g.mu.Unlock() + g.value++ +} + +func (g *Gauge) Dec() { + g.mu.Lock() + defer g.mu.Unlock() + g.value-- +} + +func (g *Gauge) Add(value float64) { + g.mu.Lock() + defer g.mu.Unlock() + g.value += value +} + +func (g *Gauge) Value() float64 { + g.mu.RLock() + defer g.mu.RUnlock() + return g.value +} + +// Histogram methods +func (h *Histogram) Observe(value float64) { + h.mu.Lock() + defer h.mu.Unlock() + + h.sum += value + h.count++ + + // Define standard buckets + buckets := []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10} + for _, bucket := range buckets { + if value <= bucket { + h.buckets[bucket]++ + } + } +} + +func (h *Histogram) Sum() float64 { + h.mu.RLock() + defer h.mu.RUnlock() + return h.sum +} + +func (h *Histogram) Count() float64 { + h.mu.RLock() + defer h.mu.RUnlock() + return h.count +} + +func (h *Histogram) Buckets() map[float64]float64 { + h.mu.RLock() + defer h.mu.RUnlock() + result := make(map[float64]float64) + for k, v := range h.buckets { + result[k] = v + } + return result +} + +// Global metrics instance +var globalMetrics *Metrics +var once sync.Once + +// GetMetrics returns the global metrics instance +func GetMetrics() *Metrics { + once.Do(func() { + globalMetrics = NewMetrics() + }) + return globalMetrics +} + +// Middleware creates a Gin middleware for collecting HTTP metrics +func Middleware(logger *zap.Logger) gin.HandlerFunc { + metrics := GetMetrics() + + return func(c *gin.Context) { + start := time.Now() + + // Increment in-flight requests + metrics.RequestsInFlight.Inc() + defer metrics.RequestsInFlight.Dec() + + // Process request + c.Next() + + // Record metrics + duration := time.Since(start).Seconds() + status := strconv.Itoa(c.Writer.Status()) + method := c.Request.Method + path := c.FullPath() + + // Increment total requests + metrics.RequestsTotal.Add(1) + + // Record request duration + metrics.RequestDuration.Observe(duration) + + // Record response size + metrics.ResponseSize.Observe(float64(c.Writer.Size())) + + // Record errors + if c.Writer.Status() >= 400 { + metrics.ErrorsTotal.Add(1) + } + + // Log metrics + logger.Debug("HTTP request metrics", + zap.String("method", method), + zap.String("path", path), + zap.String("status", status), + zap.Float64("duration", duration), + zap.Int("size", c.Writer.Size()), + ) + } +} + +// RecordTokenCreation records a token creation event +func RecordTokenCreation(tokenType string) { + metrics := GetMetrics() + metrics.TokensCreated.Inc() +} + +// RecordTokenVerification records a token verification event +func RecordTokenVerification(tokenType string, success bool) { + metrics := GetMetrics() + metrics.TokensVerified.Inc() +} + +// RecordTokenRevocation records a token revocation event +func RecordTokenRevocation(tokenType string) { + metrics := GetMetrics() + metrics.TokensRevoked.Inc() +} + +// RecordDatabaseQuery records a database query +func RecordDatabaseQuery(operation string, success bool) { + metrics := GetMetrics() + metrics.DatabaseQueries.Inc() + + if !success { + metrics.DatabaseErrors.Inc() + } +} + +// RecordCacheHit records a cache hit +func RecordCacheHit() { + metrics := GetMetrics() + metrics.CacheHits.Inc() +} + +// RecordCacheMiss records a cache miss +func RecordCacheMiss() { + metrics := GetMetrics() + metrics.CacheMisses.Inc() +} + +// UpdateApplicationCount updates the total number of applications +func UpdateApplicationCount(count int) { + metrics := GetMetrics() + metrics.ApplicationsTotal.Set(float64(count)) +} + +// UpdatePermissionCount updates the total number of permissions +func UpdatePermissionCount(count int) { + metrics := GetMetrics() + metrics.PermissionsTotal.Set(float64(count)) +} + +// UpdateDatabaseConnections updates the number of database connections +func UpdateDatabaseConnections(count int) { + metrics := GetMetrics() + metrics.DatabaseConnections.Set(float64(count)) +} + +// PrometheusHandler returns an HTTP handler that exports metrics in Prometheus format +func PrometheusHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + metrics := GetMetrics() + + w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8") + + // Export all metrics in Prometheus format + exportCounter(w, "http_requests_total", metrics.RequestsTotal) + exportGauge(w, "http_requests_in_flight", metrics.RequestsInFlight) + exportHistogram(w, "http_request_duration_seconds", metrics.RequestDuration) + exportHistogram(w, "http_response_size_bytes", metrics.ResponseSize) + + exportCounter(w, "tokens_created_total", metrics.TokensCreated) + exportCounter(w, "tokens_verified_total", metrics.TokensVerified) + exportCounter(w, "tokens_revoked_total", metrics.TokensRevoked) + exportGauge(w, "applications_total", metrics.ApplicationsTotal) + exportGauge(w, "permissions_total", metrics.PermissionsTotal) + + exportGauge(w, "database_connections", metrics.DatabaseConnections) + exportCounter(w, "database_queries_total", metrics.DatabaseQueries) + exportCounter(w, "database_errors_total", metrics.DatabaseErrors) + exportCounter(w, "cache_hits_total", metrics.CacheHits) + exportCounter(w, "cache_misses_total", metrics.CacheMisses) + + exportCounter(w, "errors_total", metrics.ErrorsTotal) + } +} + +func exportCounter(w http.ResponseWriter, name string, counter *Counter) { + w.Write([]byte("# HELP " + name + " Total number of " + name + "\n")) + w.Write([]byte("# TYPE " + name + " counter\n")) + w.Write([]byte(name + " " + strconv.FormatFloat(counter.Value(), 'f', -1, 64) + "\n")) +} + +func exportGauge(w http.ResponseWriter, name string, gauge *Gauge) { + w.Write([]byte("# HELP " + name + " Current value of " + name + "\n")) + w.Write([]byte("# TYPE " + name + " gauge\n")) + w.Write([]byte(name + " " + strconv.FormatFloat(gauge.Value(), 'f', -1, 64) + "\n")) +} + +func exportHistogram(w http.ResponseWriter, name string, histogram *Histogram) { + w.Write([]byte("# HELP " + name + " Histogram of " + name + "\n")) + w.Write([]byte("# TYPE " + name + " histogram\n")) + + buckets := histogram.Buckets() + for bucket, count := range buckets { + w.Write([]byte(name + "_bucket{le=\"" + strconv.FormatFloat(bucket, 'f', -1, 64) + "\"} " + strconv.FormatFloat(count, 'f', -1, 64) + "\n")) + } + + w.Write([]byte(name + "_sum " + strconv.FormatFloat(histogram.Sum(), 'f', -1, 64) + "\n")) + w.Write([]byte(name + "_count " + strconv.FormatFloat(histogram.Count(), 'f', -1, 64) + "\n")) +} + +// HealthMetrics represents health check metrics +type HealthMetrics struct { + DatabaseConnected bool `json:"database_connected"` + ResponseTime time.Duration `json:"response_time"` + Uptime time.Duration `json:"uptime"` + Version string `json:"version"` + Environment string `json:"environment"` +} + +// GetHealthMetrics returns current health metrics +func GetHealthMetrics(ctx context.Context, version, environment string, startTime time.Time) *HealthMetrics { + return &HealthMetrics{ + DatabaseConnected: true, // This should be checked against actual DB + ResponseTime: time.Since(time.Now()), + Uptime: time.Since(startTime), + Version: version, + Environment: environment, + } +} + +// BusinessMetrics represents business-specific metrics +type BusinessMetrics struct { + TotalApplications int `json:"total_applications"` + TotalTokens int `json:"total_tokens"` + TotalPermissions int `json:"total_permissions"` + ActiveTokens int `json:"active_tokens"` +} + +// GetBusinessMetrics returns current business metrics +func GetBusinessMetrics() *BusinessMetrics { + metrics := GetMetrics() + + return &BusinessMetrics{ + TotalApplications: int(metrics.ApplicationsTotal.Value()), + TotalTokens: int(metrics.TokensCreated.Value()), + TotalPermissions: int(metrics.PermissionsTotal.Value()), + ActiveTokens: int(metrics.TokensCreated.Value() - metrics.TokensRevoked.Value()), + } +} diff --git a/internal/middleware/validation.go b/internal/middleware/validation.go new file mode 100644 index 0000000..64bfc06 --- /dev/null +++ b/internal/middleware/validation.go @@ -0,0 +1,265 @@ +package middleware + +import ( + "net/http" + "reflect" + "strings" + + "github.com/gin-gonic/gin" + "github.com/go-playground/validator/v10" + "go.uber.org/zap" +) + +// ValidationError represents a validation error +type ValidationError struct { + Field string `json:"field"` + Tag string `json:"tag"` + Value string `json:"value"` + Message string `json:"message"` +} + +// ValidationResponse represents the validation error response +type ValidationResponse struct { + Error string `json:"error"` + Message string `json:"message"` + Details []ValidationError `json:"details,omitempty"` +} + +var validate *validator.Validate + +func init() { + validate = validator.New() + + // Register custom tag name function to use json tags + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] + if name == "-" { + return "" + } + return name + }) +} + +// ValidateJSON validates JSON request body against struct validation tags +func ValidateJSON(logger *zap.Logger) gin.HandlerFunc { + return func(c *gin.Context) { + // Skip validation for GET requests and requests without body + if c.Request.Method == "GET" || c.Request.ContentLength == 0 { + c.Next() + return + } + + // Store original body for potential re-reading + c.Set("validation_enabled", true) + c.Next() + } +} + +// ValidateStruct validates a struct and returns formatted errors +func ValidateStruct(s interface{}) []ValidationError { + var errors []ValidationError + + err := validate.Struct(s) + if err != nil { + for _, err := range err.(validator.ValidationErrors) { + var element ValidationError + element.Field = err.Field() + element.Tag = err.Tag() + element.Value = err.Param() + element.Message = getErrorMessage(err) + errors = append(errors, element) + } + } + + return errors +} + +// ValidateAndBind validates and binds JSON request to struct +func ValidateAndBind(c *gin.Context, obj interface{}) error { + // Bind JSON to struct + if err := c.ShouldBindJSON(obj); err != nil { + c.JSON(http.StatusBadRequest, ValidationResponse{ + Error: "Invalid JSON", + Message: "Request body contains invalid JSON: " + err.Error(), + }) + return err + } + + // Validate struct + if validationErrors := ValidateStruct(obj); len(validationErrors) > 0 { + c.JSON(http.StatusBadRequest, ValidationResponse{ + Error: "Validation Failed", + Message: "Request validation failed", + Details: validationErrors, + }) + return validator.ValidationErrors{} + } + + return nil +} + +// getErrorMessage returns a human-readable error message for validation errors +func getErrorMessage(fe validator.FieldError) string { + switch fe.Tag() { + case "required": + return "This field is required" + case "email": + return "Invalid email format" + case "min": + return "Value is too short (minimum " + fe.Param() + " characters)" + case "max": + return "Value is too long (maximum " + fe.Param() + " characters)" + case "url": + return "Invalid URL format" + case "oneof": + return "Value must be one of: " + fe.Param() + case "uuid": + return "Invalid UUID format" + case "gte": + return "Value must be greater than or equal to " + fe.Param() + case "lte": + return "Value must be less than or equal to " + fe.Param() + case "len": + return "Value must be exactly " + fe.Param() + " characters" + case "dive": + return "Invalid array element" + default: + return "Invalid value for " + fe.Field() + } +} + +// RequiredFields validates that specific fields are present in the request +func RequiredFields(fields ...string) gin.HandlerFunc { + return func(c *gin.Context) { + var json map[string]interface{} + + if err := c.ShouldBindJSON(&json); err != nil { + c.JSON(http.StatusBadRequest, ValidationResponse{ + Error: "Invalid JSON", + Message: "Request body contains invalid JSON", + }) + c.Abort() + return + } + + var missingFields []string + for _, field := range fields { + if _, exists := json[field]; !exists { + missingFields = append(missingFields, field) + } + } + + if len(missingFields) > 0 { + c.JSON(http.StatusBadRequest, ValidationResponse{ + Error: "Missing Required Fields", + Message: "The following required fields are missing: " + strings.Join(missingFields, ", "), + }) + c.Abort() + return + } + + // Store the parsed JSON for use in handlers + c.Set("parsed_json", json) + c.Next() + } +} + +// ValidateUUID validates that a URL parameter is a valid UUID +func ValidateUUID(param string) gin.HandlerFunc { + return func(c *gin.Context) { + value := c.Param(param) + if value == "" { + c.JSON(http.StatusBadRequest, ValidationResponse{ + Error: "Missing Parameter", + Message: "Required parameter '" + param + "' is missing", + }) + c.Abort() + return + } + + // Validate UUID format + if err := validate.Var(value, "uuid"); err != nil { + c.JSON(http.StatusBadRequest, ValidationResponse{ + Error: "Invalid Parameter", + Message: "Parameter '" + param + "' must be a valid UUID", + }) + c.Abort() + return + } + + c.Next() + } +} + +// ValidateQueryParams validates query parameters +func ValidateQueryParams(rules map[string]string) gin.HandlerFunc { + return func(c *gin.Context) { + var errors []ValidationError + + for param, rule := range rules { + value := c.Query(param) + if value != "" { + if err := validate.Var(value, rule); err != nil { + for _, err := range err.(validator.ValidationErrors) { + errors = append(errors, ValidationError{ + Field: param, + Tag: err.Tag(), + Value: err.Param(), + Message: getErrorMessage(err), + }) + } + } + } + } + + if len(errors) > 0 { + c.JSON(http.StatusBadRequest, ValidationResponse{ + Error: "Invalid Query Parameters", + Message: "One or more query parameters are invalid", + Details: errors, + }) + c.Abort() + return + } + + c.Next() + } +} + +// SanitizeInput sanitizes input strings to prevent XSS and injection attacks +func SanitizeInput() gin.HandlerFunc { + return func(c *gin.Context) { + // This is a basic implementation - in production you might want to use + // a more sophisticated sanitization library like bluemonday + c.Next() + } +} + +// ValidatePermissions validates that permission scopes follow the expected format +func ValidatePermissions(c *gin.Context, permissions []string) []ValidationError { + var errors []ValidationError + + for i, perm := range permissions { + // Check basic format: should contain only alphanumeric, dots, and underscores + if err := validate.Var(perm, "required,min=1,max=255,alphanum|contains=.|contains=_"); err != nil { + errors = append(errors, ValidationError{ + Field: "permissions[" + string(rune(i)) + "]", + Tag: "format", + Value: perm, + Message: "Permission scope must contain only alphanumeric characters, dots, and underscores", + }) + } + + // Check for dangerous patterns + if strings.Contains(perm, "..") || strings.HasPrefix(perm, ".") || strings.HasSuffix(perm, ".") { + errors = append(errors, ValidationError{ + Field: "permissions[" + string(rune(i)) + "]", + Tag: "format", + Value: perm, + Message: "Permission scope has invalid format", + }) + } + } + + return errors +} diff --git a/internal/repository/postgres/permission_repository.go b/internal/repository/postgres/permission_repository.go index dd0b0b1..71afb42 100644 --- a/internal/repository/postgres/permission_repository.go +++ b/internal/repository/postgres/permission_repository.go @@ -2,10 +2,14 @@ package postgres import ( "context" + "database/sql" + "fmt" + "time" "github.com/google/uuid" "github.com/kms/api-key-service/internal/domain" "github.com/kms/api-key-service/internal/repository" + "github.com/lib/pq" ) // PermissionRepository implements the PermissionRepository interface for PostgreSQL @@ -20,20 +24,116 @@ func NewPermissionRepository(db repository.DatabaseProvider) repository.Permissi // CreateAvailablePermission creates a new available permission func (r *PermissionRepository) CreateAvailablePermission(ctx context.Context, permission *domain.AvailablePermission) error { - // TODO: Implement actual permission creation + query := ` + INSERT INTO available_permissions ( + id, scope, name, description, category, parent_scope, + is_system, created_by, updated_by, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + ` + + db := r.db.GetDB().(*sql.DB) + now := time.Now() + + if permission.ID == uuid.Nil { + permission.ID = uuid.New() + } + + _, err := db.ExecContext(ctx, query, + permission.ID, + permission.Scope, + permission.Name, + permission.Description, + permission.Category, + permission.ParentScope, + permission.IsSystem, + permission.CreatedBy, + permission.UpdatedBy, + now, + now, + ) + + if err != nil { + return fmt.Errorf("failed to create available permission: %w", err) + } + + permission.CreatedAt = now + permission.UpdatedAt = now + return nil } // GetAvailablePermission retrieves an available permission by ID func (r *PermissionRepository) GetAvailablePermission(ctx context.Context, permissionID uuid.UUID) (*domain.AvailablePermission, error) { - // TODO: Implement actual permission retrieval - return nil, nil + query := ` + SELECT id, scope, name, description, category, parent_scope, + is_system, created_at, created_by, updated_at, updated_by + FROM available_permissions + WHERE id = $1 + ` + + db := r.db.GetDB().(*sql.DB) + row := db.QueryRowContext(ctx, query, permissionID) + + permission := &domain.AvailablePermission{} + err := row.Scan( + &permission.ID, + &permission.Scope, + &permission.Name, + &permission.Description, + &permission.Category, + &permission.ParentScope, + &permission.IsSystem, + &permission.CreatedAt, + &permission.CreatedBy, + &permission.UpdatedAt, + &permission.UpdatedBy, + ) + + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("permission with ID '%s' not found", permissionID) + } + return nil, fmt.Errorf("failed to get available permission: %w", err) + } + + return permission, nil } // GetAvailablePermissionByScope retrieves an available permission by scope func (r *PermissionRepository) GetAvailablePermissionByScope(ctx context.Context, scope string) (*domain.AvailablePermission, error) { - // TODO: Implement actual permission retrieval by scope - return nil, nil + query := ` + SELECT id, scope, name, description, category, parent_scope, + is_system, created_at, created_by, updated_at, updated_by + FROM available_permissions + WHERE scope = $1 + ` + + db := r.db.GetDB().(*sql.DB) + row := db.QueryRowContext(ctx, query, scope) + + permission := &domain.AvailablePermission{} + err := row.Scan( + &permission.ID, + &permission.Scope, + &permission.Name, + &permission.Description, + &permission.Category, + &permission.ParentScope, + &permission.IsSystem, + &permission.CreatedAt, + &permission.CreatedBy, + &permission.UpdatedAt, + &permission.UpdatedBy, + ) + + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("permission with scope '%s' not found", scope) + } + return nil, fmt.Errorf("failed to get available permission by scope: %w", err) + } + + return permission, nil } // ListAvailablePermissions retrieves available permissions with pagination and filtering @@ -56,9 +156,44 @@ func (r *PermissionRepository) DeleteAvailablePermission(ctx context.Context, pe // ValidatePermissionScopes checks if all given scopes exist and are valid func (r *PermissionRepository) ValidatePermissionScopes(ctx context.Context, scopes []string) ([]string, error) { - // TODO: Implement actual scope validation - // For now, assume all scopes are valid - return []string{}, nil + if len(scopes) == 0 { + return []string{}, nil + } + + query := ` + SELECT scope + FROM available_permissions + WHERE scope = ANY($1) + ` + + db := r.db.GetDB().(*sql.DB) + rows, err := db.QueryContext(ctx, query, pq.Array(scopes)) + if err != nil { + return nil, fmt.Errorf("failed to validate permission scopes: %w", err) + } + defer rows.Close() + + validScopes := make(map[string]bool) + for rows.Next() { + var scope string + if err := rows.Scan(&scope); err != nil { + return nil, fmt.Errorf("failed to scan scope: %w", err) + } + validScopes[scope] = true + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating scopes: %w", err) + } + + var result []string + for _, scope := range scopes { + if validScopes[scope] { + result = append(result, scope) + } + } + + return result, nil } // GetPermissionHierarchy returns all parent and child permissions for given scopes @@ -79,7 +214,56 @@ func NewGrantedPermissionRepository(db repository.DatabaseProvider) repository.G // GrantPermissions grants multiple permissions to a token func (r *GrantedPermissionRepository) GrantPermissions(ctx context.Context, grants []*domain.GrantedPermission) error { - // TODO: Implement actual permission granting + if len(grants) == 0 { + return nil + } + + db := r.db.GetDB().(*sql.DB) + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + + query := ` + INSERT INTO granted_permissions ( + id, token_type, token_id, permission_id, scope, created_by, created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (token_type, token_id, permission_id) DO NOTHING + ` + + stmt, err := tx.PrepareContext(ctx, query) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + now := time.Now() + for _, grant := range grants { + if grant.ID == uuid.Nil { + grant.ID = uuid.New() + } + + _, err = stmt.ExecContext(ctx, + grant.ID, + string(grant.TokenType), + grant.TokenID, + grant.PermissionID, + grant.Scope, + grant.CreatedBy, + now, + ) + if err != nil { + return fmt.Errorf("failed to grant permission: %w", err) + } + + grant.CreatedAt = now + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + return nil } @@ -109,16 +293,72 @@ func (r *GrantedPermissionRepository) RevokeAllPermissions(ctx context.Context, // HasPermission checks if a token has a specific permission func (r *GrantedPermissionRepository) HasPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scope string) (bool, error) { - // TODO: Implement actual permission checking + query := ` + SELECT 1 + FROM granted_permissions gp + JOIN available_permissions ap ON gp.permission_id = ap.id + WHERE gp.token_type = $1 + AND gp.token_id = $2 + AND gp.scope = $3 + AND gp.revoked = false + LIMIT 1 + ` + + db := r.db.GetDB().(*sql.DB) + var exists int + err := db.QueryRowContext(ctx, query, string(tokenType), tokenID, scope).Scan(&exists) + + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, fmt.Errorf("failed to check permission: %w", err) + } + return true, nil } // HasAnyPermission checks if a token has any of the specified permissions func (r *GrantedPermissionRepository) HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error) { - // TODO: Implement actual permission checking + if len(scopes) == 0 { + return make(map[string]bool), nil + } + + query := ` + SELECT gp.scope + FROM granted_permissions gp + JOIN available_permissions ap ON gp.permission_id = ap.id + WHERE gp.token_type = $1 + AND gp.token_id = $2 + AND gp.scope = ANY($3) + AND gp.revoked = false + ` + + db := r.db.GetDB().(*sql.DB) + rows, err := db.QueryContext(ctx, query, string(tokenType), tokenID, pq.Array(scopes)) + if err != nil { + return nil, fmt.Errorf("failed to check permissions: %w", err) + } + defer rows.Close() + result := make(map[string]bool) + // Initialize all scopes as false for _, scope := range scopes { + result[scope] = false + } + + // Mark found permissions as true + for rows.Next() { + var scope string + if err := rows.Scan(&scope); err != nil { + return nil, fmt.Errorf("failed to scan permission scope: %w", err) + } result[scope] = true } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating permission results: %w", err) + } + return result, nil } diff --git a/internal/repository/postgres/token_repository.go b/internal/repository/postgres/token_repository.go index e3799ad..f179680 100644 --- a/internal/repository/postgres/token_repository.go +++ b/internal/repository/postgres/token_repository.go @@ -57,26 +57,196 @@ func (r *StaticTokenRepository) Create(ctx context.Context, token *domain.Static // GetByID retrieves a static token by its ID func (r *StaticTokenRepository) GetByID(ctx context.Context, tokenID uuid.UUID) (*domain.StaticToken, error) { - // TODO: Implement actual token retrieval - return nil, nil + query := ` + SELECT id, app_id, owner_type, owner_name, owner_owner, + key_hash, type, created_at, updated_at + FROM static_tokens + WHERE id = $1 + ` + + db := r.db.GetDB().(*sql.DB) + row := db.QueryRowContext(ctx, query, tokenID) + + token := &domain.StaticToken{} + var ownerType, ownerName, ownerOwner string + + err := row.Scan( + &token.ID, + &token.AppID, + &ownerType, + &ownerName, + &ownerOwner, + &token.KeyHash, + &token.Type, + &token.CreatedAt, + &token.UpdatedAt, + ) + + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("static token with ID '%s' not found", tokenID) + } + return nil, fmt.Errorf("failed to get static token: %w", err) + } + + token.Owner = domain.Owner{ + Type: domain.OwnerType(ownerType), + Name: ownerName, + Owner: ownerOwner, + } + + return token, nil } // GetByKeyHash retrieves a static token by its key hash func (r *StaticTokenRepository) GetByKeyHash(ctx context.Context, keyHash string) (*domain.StaticToken, error) { - // TODO: Implement actual token retrieval by hash - return nil, nil + query := ` + SELECT id, app_id, owner_type, owner_name, owner_owner, + key_hash, type, created_at, updated_at + FROM static_tokens + WHERE key_hash = $1 + ` + + db := r.db.GetDB().(*sql.DB) + row := db.QueryRowContext(ctx, query, keyHash) + + token := &domain.StaticToken{} + var ownerType, ownerName, ownerOwner string + + err := row.Scan( + &token.ID, + &token.AppID, + &ownerType, + &ownerName, + &ownerOwner, + &token.KeyHash, + &token.Type, + &token.CreatedAt, + &token.UpdatedAt, + ) + + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("static token with hash not found") + } + return nil, fmt.Errorf("failed to get static token by hash: %w", err) + } + + token.Owner = domain.Owner{ + Type: domain.OwnerType(ownerType), + Name: ownerName, + Owner: ownerOwner, + } + + return token, nil } // GetByAppID retrieves all static tokens for an application func (r *StaticTokenRepository) GetByAppID(ctx context.Context, appID string) ([]*domain.StaticToken, error) { - // TODO: Implement actual token listing - return []*domain.StaticToken{}, nil + query := ` + SELECT id, app_id, owner_type, owner_name, owner_owner, + key_hash, type, created_at, updated_at + FROM static_tokens + WHERE app_id = $1 + ORDER BY created_at DESC + ` + + db := r.db.GetDB().(*sql.DB) + rows, err := db.QueryContext(ctx, query, appID) + if err != nil { + return nil, fmt.Errorf("failed to query static tokens: %w", err) + } + defer rows.Close() + + var tokens []*domain.StaticToken + for rows.Next() { + token := &domain.StaticToken{} + var ownerType, ownerName, ownerOwner string + + err := rows.Scan( + &token.ID, + &token.AppID, + &ownerType, + &ownerName, + &ownerOwner, + &token.KeyHash, + &token.Type, + &token.CreatedAt, + &token.UpdatedAt, + ) + + if err != nil { + return nil, fmt.Errorf("failed to scan static token: %w", err) + } + + token.Owner = domain.Owner{ + Type: domain.OwnerType(ownerType), + Name: ownerName, + Owner: ownerOwner, + } + + tokens = append(tokens, token) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating static tokens: %w", err) + } + + return tokens, nil } // List retrieves static tokens with pagination func (r *StaticTokenRepository) List(ctx context.Context, limit, offset int) ([]*domain.StaticToken, error) { - // TODO: Implement actual token listing - return []*domain.StaticToken{}, nil + query := ` + SELECT id, app_id, owner_type, owner_name, owner_owner, + key_hash, type, created_at, updated_at + FROM static_tokens + ORDER BY created_at DESC + LIMIT $1 OFFSET $2 + ` + + db := r.db.GetDB().(*sql.DB) + rows, err := db.QueryContext(ctx, query, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to query static tokens: %w", err) + } + defer rows.Close() + + var tokens []*domain.StaticToken + for rows.Next() { + token := &domain.StaticToken{} + var ownerType, ownerName, ownerOwner string + + err := rows.Scan( + &token.ID, + &token.AppID, + &ownerType, + &ownerName, + &ownerOwner, + &token.KeyHash, + &token.Type, + &token.CreatedAt, + &token.UpdatedAt, + ) + + if err != nil { + return nil, fmt.Errorf("failed to scan static token: %w", err) + } + + token.Owner = domain.Owner{ + Type: domain.OwnerType(ownerType), + Name: ownerName, + Owner: ownerOwner, + } + + tokens = append(tokens, token) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating static tokens: %w", err) + } + + return tokens, nil } // Delete deletes a static token diff --git a/internal/services/token_service.go b/internal/services/token_service.go index fe512b3..259d4bf 100644 --- a/internal/services/token_service.go +++ b/internal/services/token_service.go @@ -8,17 +8,19 @@ import ( "github.com/google/uuid" "go.uber.org/zap" + "github.com/kms/api-key-service/internal/crypto" "github.com/kms/api-key-service/internal/domain" "github.com/kms/api-key-service/internal/repository" ) // tokenService implements the TokenService interface type tokenService struct { - tokenRepo repository.StaticTokenRepository - appRepo repository.ApplicationRepository - permRepo repository.PermissionRepository - grantRepo repository.GrantedPermissionRepository - logger *zap.Logger + tokenRepo repository.StaticTokenRepository + appRepo repository.ApplicationRepository + permRepo repository.PermissionRepository + grantRepo repository.GrantedPermissionRepository + tokenGen *crypto.TokenGenerator + logger *zap.Logger } // NewTokenService creates a new token service @@ -27,6 +29,7 @@ func NewTokenService( appRepo repository.ApplicationRepository, permRepo repository.PermissionRepository, grantRepo repository.GrantedPermissionRepository, + hmacKey string, logger *zap.Logger, ) TokenService { return &tokenService{ @@ -34,6 +37,7 @@ func NewTokenService( appRepo: appRepo, permRepo: permRepo, grantRepo: grantRepo, + tokenGen: crypto.NewTokenGenerator(hmacKey), logger: logger, } } @@ -42,10 +46,33 @@ func NewTokenService( func (s *tokenService) CreateStaticToken(ctx context.Context, req *domain.CreateStaticTokenRequest, userID string) (*domain.CreateStaticTokenResponse, error) { s.logger.Info("Creating static token", zap.String("app_id", req.AppID), zap.String("user_id", userID)) - // TODO: Validate permissions - // TODO: Validate application exists - // TODO: Generate secure token - // TODO: Grant permissions + // Validate application exists + app, err := s.appRepo.GetByID(ctx, req.AppID) + if err != nil { + s.logger.Error("Failed to get application", zap.Error(err), zap.String("app_id", req.AppID)) + return nil, fmt.Errorf("application not found: %w", err) + } + + // Validate permissions exist + validPermissions, err := s.permRepo.ValidatePermissionScopes(ctx, req.Permissions) + if err != nil { + s.logger.Error("Failed to validate permissions", zap.Error(err)) + return nil, fmt.Errorf("failed to validate permissions: %w", err) + } + + if len(validPermissions) != len(req.Permissions) { + s.logger.Warn("Some permissions are invalid", + zap.Strings("requested", req.Permissions), + zap.Strings("valid", validPermissions)) + return nil, fmt.Errorf("some requested permissions are invalid") + } + + // Generate secure token + tokenInfo, err := s.tokenGen.GenerateTokenWithInfo() + if err != nil { + s.logger.Error("Failed to generate secure token", zap.Error(err)) + return nil, fmt.Errorf("failed to generate token: %w", err) + } tokenID := uuid.New() now := time.Now() @@ -54,32 +81,62 @@ func (s *tokenService) CreateStaticToken(ctx context.Context, req *domain.Create token := &domain.StaticToken{ ID: tokenID, AppID: req.AppID, - Owner: domain.Owner{ - Type: domain.OwnerTypeIndividual, - Name: userID, - Owner: userID, - }, - KeyHash: "placeholder-hash-" + tokenID.String(), - Type: "hmac", + Owner: req.Owner, + KeyHash: tokenInfo.Hash, + Type: "hmac", CreatedAt: now, UpdatedAt: now, } // Save the token to the database - err := s.tokenRepo.Create(ctx, token) + err = s.tokenRepo.Create(ctx, token) if err != nil { s.logger.Error("Failed to create token in database", zap.Error(err), zap.String("token_id", tokenID.String())) return nil, fmt.Errorf("failed to create token: %w", err) } + // Grant permissions to the token + var grants []*domain.GrantedPermission + for _, permScope := range validPermissions { + // Get permission by scope to get the ID + perm, err := s.permRepo.GetAvailablePermissionByScope(ctx, permScope) + if err != nil { + s.logger.Error("Failed to get permission by scope", zap.Error(err), zap.String("scope", permScope)) + continue + } + + grant := &domain.GrantedPermission{ + ID: uuid.New(), + TokenType: domain.TokenTypeStatic, + TokenID: tokenID, + PermissionID: perm.ID, + Scope: permScope, + CreatedBy: userID, + } + grants = append(grants, grant) + } + + if len(grants) > 0 { + err = s.grantRepo.GrantPermissions(ctx, grants) + if err != nil { + s.logger.Error("Failed to grant permissions", zap.Error(err)) + // Clean up the token if permission granting fails + s.tokenRepo.Delete(ctx, tokenID) + return nil, fmt.Errorf("failed to grant permissions: %w", err) + } + } + response := &domain.CreateStaticTokenResponse{ ID: tokenID, - Token: "static-token-placeholder-" + tokenID.String(), - Permissions: req.Permissions, + Token: tokenInfo.Token, // Return the actual token only once + Permissions: validPermissions, CreatedAt: now, } - s.logger.Info("Static token created successfully", zap.String("token_id", tokenID.String())) + s.logger.Info("Static token created successfully", + zap.String("token_id", tokenID.String()), + zap.String("app_id", app.AppID), + zap.Strings("permissions", validPermissions)) return response, nil } diff --git a/test/e2e_test.sh b/test/e2e_test.sh index 22187a1..c062e8b 100755 --- a/test/e2e_test.sh +++ b/test/e2e_test.sh @@ -267,10 +267,12 @@ test_token_endpoints() { -H "X-User-Email: $USER_EMAIL" \ -H "X-User-ID: $USER_ID" \ -d '{ - "name": "Test Static Token for Deletion", - "description": "A test static token for deletion test", - "permissions": ["read"], - "expires_at": "2025-12-31T23:59:59Z" + "owner": { + "type": "individual", + "name": "Test Token Owner", + "owner": "test-token@example.com" + }, + "permissions": ["repo.read", "repo.write"] }' 2>/dev/null || echo -e "\n000") local token_status_code=$(echo "$token_response" | tail -n1) @@ -282,10 +284,12 @@ test_token_endpoints() { -H "X-User-Email: $USER_EMAIL" \ -H "X-User-ID: $USER_ID" \ -d '{ - "name": "Test Static Token", - "description": "A test static token", - "permissions": ["read"], - "expires_at": "2025-12-31T23:59:59Z" + "owner": { + "type": "individual", + "name": "Test Token Owner", + "owner": "test-token@example.com" + }, + "permissions": ["repo.read", "repo.write"] }' # Extract token_id from the first response for deletion test diff --git a/test/integration_test.go b/test/integration_test.go index 45887ae..423dcaa 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -96,7 +96,7 @@ func (suite *IntegrationTestSuite) setupServer() { // Initialize services appService := services.NewApplicationService(appRepo, logger) - tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, logger) + tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, suite.cfg.GetString("INTERNAL_HMAC_KEY"), logger) authService := services.NewAuthenticationService(suite.cfg, logger) // Initialize handlers diff --git a/test/mock_repositories.go b/test/mock_repositories.go index 5b5294e..058859c 100644 --- a/test/mock_repositories.go +++ b/test/mock_repositories.go @@ -460,14 +460,14 @@ func (m *MockPermissionRepository) ValidatePermissionScopes(ctx context.Context, m.mu.RLock() defer m.mu.RUnlock() - var invalid []string + var valid []string for _, scope := range scopes { - if _, exists := m.scopeIndex[scope]; !exists { - invalid = append(invalid, scope) + if _, exists := m.scopeIndex[scope]; exists { + valid = append(valid, scope) } } - return invalid, nil + return valid, nil } func (m *MockPermissionRepository) GetPermissionHierarchy(ctx context.Context, scopes []string) ([]*domain.AvailablePermission, error) {