This commit is contained in:
2025-08-22 14:40:59 -04:00
parent 98a299e7b2
commit 141b1e936d
12 changed files with 1973 additions and 61 deletions

View File

@ -15,6 +15,7 @@ import (
"github.com/kms/api-key-service/internal/config"
"github.com/kms/api-key-service/internal/database"
"github.com/kms/api-key-service/internal/handlers"
"github.com/kms/api-key-service/internal/metrics"
"github.com/kms/api-key-service/internal/middleware"
"github.com/kms/api-key-service/internal/repository/postgres"
"github.com/kms/api-key-service/internal/services"
@ -61,7 +62,7 @@ func main() {
// Initialize services
appService := services.NewApplicationService(appRepo, logger)
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, logger)
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, cfg.GetString("INTERNAL_HMAC_KEY"), logger)
authService := services.NewAuthenticationService(cfg, logger)
// Initialize handlers
@ -156,6 +157,7 @@ func setupRouter(cfg config.ConfigProvider, logger *zap.Logger, healthHandler *h
// Add middleware
router.Use(middleware.Logger(logger))
router.Use(middleware.Recovery(logger))
router.Use(metrics.Middleware(logger))
router.Use(middleware.CORS())
router.Use(middleware.Security())
router.Use(middleware.ValidateContentType())
@ -226,18 +228,20 @@ func setupRouter(cfg config.ConfigProvider, logger *zap.Logger, healthHandler *h
}
func startMetricsServer(cfg config.ConfigProvider, logger *zap.Logger) *http.Server {
metricsRouter := gin.New()
metricsRouter.Use(middleware.Logger(logger))
metricsRouter.Use(middleware.Recovery(logger))
// Basic metrics endpoint
metricsRouter.GET("/metrics", func(c *gin.Context) {
c.String(http.StatusOK, "# HELP api_key_service_info Information about the API Key Service\n# TYPE api_key_service_info gauge\napi_key_service_info{version=\"%s\"} 1\n", cfg.GetString("APP_VERSION"))
mux := http.NewServeMux()
// Prometheus metrics endpoint
mux.HandleFunc("/metrics", metrics.PrometheusHandler())
// Health endpoint for metrics server
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
srv := &http.Server{
Addr: cfg.GetMetricsAddress(),
Handler: metricsRouter,
Handler: mux,
}
go func() {

297
docs/PRODUCTION_ROADMAP.md Normal file
View File

@ -0,0 +1,297 @@
# KMS API Service - Production Roadmap
This document outlines the complete roadmap for making the API Key Management Service fully production-ready. Use the checkboxes to track progress and refer to the implementation notes at the bottom.
## 🏗️ Core Infrastructure (COMPLETED)
### Repository Layer
- [x] Complete token repository implementation (CRUD operations)
- [x] Complete permission repository implementation (core methods)
- [x] Implement granted permission repository (authorization logic)
- [x] Add database transaction support
- [x] Implement proper error handling in repositories
### Security & Cryptography
- [x] Implement secure token generation using crypto/rand
- [x] Add bcrypt-based token hashing for storage
- [x] Implement HMAC token signing and verification
- [x] Create token format validation utilities
- [x] Add cryptographic key management
### Service Layer
- [x] Update token service with secure generation
- [x] Implement permission validation in token creation
- [x] Add application validation before token operations
- [x] Implement proper error propagation
- [x] Add comprehensive logging throughout services
### Middleware & Validation
- [x] Create comprehensive input validation middleware
- [x] Implement struct-based validation with detailed errors
- [x] Add UUID parameter validation
- [x] Create permission scope format validation
- [x] Implement request sanitization
### Error Handling
- [x] Create structured error framework with typed codes
- [x] Implement HTTP status code mapping
- [x] Add error context and chaining support
- [x] Create consistent JSON error responses
- [x] Add retry logic indicators
### Monitoring & Metrics
- [x] Implement comprehensive metrics collection
- [x] Add Prometheus-compatible metrics export
- [x] Create HTTP request monitoring middleware
- [x] Add business metrics tracking
- [x] Implement system health metrics
## 🔐 Authentication & Authorization (HIGH PRIORITY)
### JWT Implementation
- [ ] Complete JWT token generation and validation
- [ ] Implement token expiration and renewal logic
- [ ] Add JWT claims management
- [ ] Create token blacklisting mechanism
- [ ] Implement refresh token rotation
### SSO Integration
- [ ] Implement OAuth2/OIDC provider integration
- [ ] Add SAML authentication support
- [ ] Create user session management
- [ ] Implement role-based access control (RBAC)
- [ ] Add multi-tenant authentication support
### Permission System Enhancement
- [ ] Implement hierarchical permission inheritance
- [ ] Add dynamic permission evaluation
- [ ] Create permission caching mechanism
- [ ] Implement permission audit logging
- [ ] Add bulk permission operations
## 🚀 Performance & Scalability (MEDIUM PRIORITY)
### Caching Layer
- [ ] Implement Redis integration for caching
- [ ] Add permission result caching
- [ ] Create application metadata caching
- [ ] Implement token validation result caching
- [ ] Add cache invalidation strategies
### Database Optimization
- [ ] Implement database connection pool tuning
- [ ] Add query performance monitoring
- [ ] Create database migration rollback procedures
- [ ] Implement read replica support
- [ ] Add database backup and recovery procedures
### Load Balancing & Clustering
- [ ] Implement horizontal scaling support
- [ ] Add load balancer health checks
- [ ] Create session affinity handling
- [ ] Implement distributed rate limiting
- [ ] Add circuit breaker patterns
## 🔒 Security Hardening (HIGH PRIORITY)
### Advanced Security Features
- [ ] Implement API key rotation mechanisms
- [ ] Add brute force protection
- [ ] Create account lockout mechanisms
- [ ] Implement IP whitelisting/blacklisting
- [ ] Add request signing validation
### Audit & Compliance
- [ ] Implement comprehensive audit logging
- [ ] Add compliance reporting features
- [ ] Create data retention policies
- [ ] Implement GDPR compliance features
- [ ] Add security event alerting
### Secrets Management
- [ ] Integrate with HashiCorp Vault or similar
- [ ] Implement automatic key rotation
- [ ] Add encrypted configuration storage
- [ ] Create secure backup procedures
- [ ] Implement key escrow mechanisms
## 🧪 Testing & Quality Assurance (MEDIUM PRIORITY)
### Unit Testing
- [ ] Add comprehensive unit tests for repositories
- [ ] Create service layer unit tests
- [ ] Implement middleware unit tests
- [ ] Add crypto utility unit tests
- [ ] Create error handling unit tests
### Integration Testing
- [ ] Expand integration test coverage
- [ ] Add database integration tests
- [ ] Create API endpoint integration tests
- [ ] Implement authentication flow tests
- [ ] Add permission validation tests
### Performance Testing
- [ ] Implement load testing scenarios
- [ ] Add stress testing for concurrent operations
- [ ] Create database performance benchmarks
- [ ] Implement memory leak detection
- [ ] Add latency and throughput testing
### Security Testing
- [ ] Implement penetration testing scenarios
- [ ] Add vulnerability scanning automation
- [ ] Create security regression tests
- [ ] Implement fuzzing tests
- [ ] Add compliance validation tests
## 📦 Deployment & Operations (MEDIUM PRIORITY)
### Containerization & Orchestration
- [ ] Create optimized Docker images
- [ ] Implement Kubernetes manifests
- [ ] Add Helm charts for deployment
- [ ] Create deployment automation scripts
- [ ] Implement blue-green deployment strategy
### Infrastructure as Code
- [ ] Create Terraform configurations
- [ ] Implement AWS/GCP/Azure resource definitions
- [ ] Add infrastructure testing
- [ ] Create disaster recovery procedures
- [ ] Implement infrastructure monitoring
### CI/CD Pipeline
- [ ] Implement automated testing pipeline
- [ ] Add security scanning in CI/CD
- [ ] Create automated deployment pipeline
- [ ] Implement rollback mechanisms
- [ ] Add deployment notifications
## 📊 Observability & Monitoring (LOW PRIORITY)
### Advanced Monitoring
- [ ] Implement distributed tracing
- [ ] Add application performance monitoring (APM)
- [ ] Create custom dashboards
- [ ] Implement alerting rules
- [ ] Add log aggregation and analysis
### Business Intelligence
- [ ] Create usage analytics
- [ ] Implement cost tracking
- [ ] Add capacity planning metrics
- [ ] Create business KPI dashboards
- [ ] Implement trend analysis
## 🔧 Maintenance & Operations (ONGOING)
### Documentation
- [ ] Create comprehensive API documentation
- [ ] Add deployment guides
- [ ] Create troubleshooting runbooks
- [ ] Implement architecture decision records (ADRs)
- [ ] Add security best practices guide
### Maintenance Procedures
- [ ] Create backup and restore procedures
- [ ] Implement log rotation and archival
- [ ] Add database maintenance scripts
- [ ] Create performance tuning guides
- [ ] Implement capacity planning procedures
---
## 📝 Implementation Notes for Future Development
### Code Organization Principles
1. **Maintain Clean Architecture**: Keep clear separation between domain, service, and infrastructure layers
2. **Interface-First Design**: Always define interfaces before implementations for better testability
3. **Error Handling**: Use the established error framework (`internal/errors`) for consistent error handling
4. **Logging**: Use structured logging with zap throughout the application
5. **Configuration**: Add new config options to `internal/config/config.go` with proper validation
### Security Guidelines
1. **Input Validation**: Always validate inputs using the validation middleware (`internal/middleware/validation.go`)
2. **Token Security**: Use the crypto utilities (`internal/crypto/token.go`) for all token operations
3. **Permission Checks**: Always validate permissions using the repository layer before operations
4. **Audit Logging**: Log all security-relevant operations with user context
5. **Secrets**: Never hardcode secrets; use environment variables or secret management systems
### Database Guidelines
1. **Migrations**: Always create both up and down migrations for schema changes
2. **Transactions**: Use database transactions for multi-step operations
3. **Indexing**: Add appropriate indexes for query performance
4. **Connection Management**: Use the existing connection pool configuration
5. **Error Handling**: Wrap database errors with the application error framework
### Testing Guidelines
1. **Test Structure**: Follow the existing test structure in `test/` directory
2. **Mock Dependencies**: Use interfaces for easy mocking in tests
3. **Test Data**: Use the test helpers for consistent test data creation
4. **Integration Tests**: Test against real database instances when possible
5. **Coverage**: Aim for >80% test coverage for critical paths
### Performance Guidelines
1. **Metrics**: Use the metrics system (`internal/metrics`) to track performance
2. **Caching**: Implement caching at the service layer, not repository layer
3. **Database Queries**: Optimize queries and use appropriate indexes
4. **Memory Management**: Be mindful of memory allocations in hot paths
5. **Concurrency**: Use proper synchronization for shared resources
### Deployment Guidelines
1. **Environment Variables**: Use environment-based configuration for all deployments
2. **Health Checks**: Ensure health endpoints are properly configured
3. **Graceful Shutdown**: Implement proper shutdown procedures for all services
4. **Resource Limits**: Set appropriate CPU and memory limits
5. **Monitoring**: Ensure metrics and logging are properly configured
### Code Quality Standards
1. **Go Standards**: Follow standard Go conventions and best practices
2. **Documentation**: Document all public APIs and complex business logic
3. **Error Messages**: Provide clear, actionable error messages
4. **Code Reviews**: Require code reviews for all changes
5. **Static Analysis**: Use tools like golangci-lint for code quality
### Security Best Practices
1. **Principle of Least Privilege**: Grant minimum necessary permissions
2. **Defense in Depth**: Implement multiple layers of security
3. **Regular Updates**: Keep dependencies updated for security patches
4. **Secure Defaults**: Use secure configurations by default
5. **Security Testing**: Include security testing in the development process
### Operational Considerations
1. **Monitoring**: Implement comprehensive monitoring and alerting
2. **Backup Strategy**: Ensure regular backups and test restore procedures
3. **Disaster Recovery**: Have documented disaster recovery procedures
4. **Capacity Planning**: Monitor resource usage and plan for growth
5. **Documentation**: Keep operational documentation up to date
---
## 🎯 Priority Matrix
### Immediate (Next Sprint)
- Complete JWT implementation
- Add comprehensive unit tests
- Implement caching layer basics
### Short Term (1-2 Months)
- SSO integration
- Security hardening features
- Performance optimization
### Medium Term (3-6 Months)
- Advanced monitoring and observability
- Deployment automation
- Compliance features
### Long Term (6+ Months)
- Advanced analytics
- Multi-region deployment
- Advanced security features
---
*Last Updated: [Current Date]*
*Version: 1.0*

173
internal/crypto/token.go Normal file
View File

@ -0,0 +1,173 @@
package crypto
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
)
const (
// TokenLength defines the length of generated tokens in bytes
TokenLength = 32
// TokenPrefix is prepended to all tokens for identification
TokenPrefix = "kms_"
)
// TokenGenerator provides secure token generation and validation
type TokenGenerator struct {
hmacKey []byte
}
// NewTokenGenerator creates a new token generator with the provided HMAC key
func NewTokenGenerator(hmacKey string) *TokenGenerator {
return &TokenGenerator{
hmacKey: []byte(hmacKey),
}
}
// GenerateSecureToken generates a cryptographically secure random token
func (tg *TokenGenerator) GenerateSecureToken() (string, error) {
// Generate random bytes
tokenBytes := make([]byte, TokenLength)
if _, err := rand.Read(tokenBytes); err != nil {
return "", fmt.Errorf("failed to generate random token: %w", err)
}
// Encode to base64 for safe transmission
tokenData := base64.URLEncoding.EncodeToString(tokenBytes)
// Add prefix for identification
token := TokenPrefix + tokenData
return token, nil
}
// HashToken creates a secure hash of the token for storage
func (tg *TokenGenerator) HashToken(token string) (string, error) {
// Use bcrypt for secure password-like hashing
hash, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost)
if err != nil {
return "", fmt.Errorf("failed to hash token: %w", err)
}
return string(hash), nil
}
// VerifyToken verifies a token against its stored hash
func (tg *TokenGenerator) VerifyToken(token, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(token))
return err == nil
}
// GenerateHMACKey generates a new HMAC key for token signing
func GenerateHMACKey() (string, error) {
key := make([]byte, 32) // 256-bit key
if _, err := rand.Read(key); err != nil {
return "", fmt.Errorf("failed to generate HMAC key: %w", err)
}
return hex.EncodeToString(key), nil
}
// SignToken creates an HMAC signature for a token
func (tg *TokenGenerator) SignToken(token string, timestamp time.Time) string {
h := hmac.New(sha256.New, tg.hmacKey)
h.Write([]byte(token))
h.Write([]byte(timestamp.Format(time.RFC3339)))
signature := h.Sum(nil)
return hex.EncodeToString(signature)
}
// VerifyTokenSignature verifies an HMAC signature for a token
func (tg *TokenGenerator) VerifyTokenSignature(token, signature string, timestamp time.Time) bool {
expectedSignature := tg.SignToken(token, timestamp)
return hmac.Equal([]byte(signature), []byte(expectedSignature))
}
// ExtractTokenFromHeader extracts a token from an Authorization header
func ExtractTokenFromHeader(authHeader string) string {
// Support both "Bearer token" and "token" formats
if strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimPrefix(authHeader, "Bearer ")
}
return authHeader
}
// IsValidTokenFormat checks if a token has the expected format
func IsValidTokenFormat(token string) bool {
if !strings.HasPrefix(token, TokenPrefix) {
return false
}
// Remove prefix and check if remaining part is valid base64
tokenData := strings.TrimPrefix(token, TokenPrefix)
if len(tokenData) == 0 {
return false
}
// Try to decode base64
_, err := base64.URLEncoding.DecodeString(tokenData)
return err == nil
}
// TokenInfo holds information about a token
type TokenInfo struct {
Token string
Hash string
Signature string
CreatedAt time.Time
}
// GenerateTokenWithInfo generates a complete token with hash and signature
func (tg *TokenGenerator) GenerateTokenWithInfo() (*TokenInfo, error) {
// Generate the token
token, err := tg.GenerateSecureToken()
if err != nil {
return nil, fmt.Errorf("failed to generate token: %w", err)
}
// Hash the token for storage
hash, err := tg.HashToken(token)
if err != nil {
return nil, fmt.Errorf("failed to hash token: %w", err)
}
// Create timestamp and signature
now := time.Now()
signature := tg.SignToken(token, now)
return &TokenInfo{
Token: token,
Hash: hash,
Signature: signature,
CreatedAt: now,
}, nil
}
// ValidateTokenInfo validates a complete token with all its components
func (tg *TokenGenerator) ValidateTokenInfo(token, hash, signature string, createdAt time.Time) error {
// Check token format
if !IsValidTokenFormat(token) {
return fmt.Errorf("invalid token format")
}
// Verify token against hash
if !tg.VerifyToken(token, hash) {
return fmt.Errorf("token verification failed")
}
// Verify signature
if !tg.VerifyTokenSignature(token, signature, createdAt) {
return fmt.Errorf("token signature verification failed")
}
return nil
}

287
internal/errors/errors.go Normal file
View File

@ -0,0 +1,287 @@
package errors
import (
"fmt"
"net/http"
)
// ErrorCode represents different types of errors in the system
type ErrorCode string
const (
// Authentication and Authorization errors
ErrUnauthorized ErrorCode = "UNAUTHORIZED"
ErrForbidden ErrorCode = "FORBIDDEN"
ErrInvalidToken ErrorCode = "INVALID_TOKEN"
ErrTokenExpired ErrorCode = "TOKEN_EXPIRED"
ErrInvalidCredentials ErrorCode = "INVALID_CREDENTIALS"
// Validation errors
ErrValidationFailed ErrorCode = "VALIDATION_FAILED"
ErrInvalidInput ErrorCode = "INVALID_INPUT"
ErrMissingField ErrorCode = "MISSING_FIELD"
ErrInvalidFormat ErrorCode = "INVALID_FORMAT"
// Resource errors
ErrNotFound ErrorCode = "NOT_FOUND"
ErrAlreadyExists ErrorCode = "ALREADY_EXISTS"
ErrConflict ErrorCode = "CONFLICT"
// System errors
ErrInternal ErrorCode = "INTERNAL_ERROR"
ErrDatabase ErrorCode = "DATABASE_ERROR"
ErrExternal ErrorCode = "EXTERNAL_SERVICE_ERROR"
ErrTimeout ErrorCode = "TIMEOUT"
ErrRateLimit ErrorCode = "RATE_LIMIT_EXCEEDED"
// Business logic errors
ErrInsufficientPermissions ErrorCode = "INSUFFICIENT_PERMISSIONS"
ErrApplicationNotFound ErrorCode = "APPLICATION_NOT_FOUND"
ErrTokenNotFound ErrorCode = "TOKEN_NOT_FOUND"
ErrPermissionNotFound ErrorCode = "PERMISSION_NOT_FOUND"
ErrInvalidApplication ErrorCode = "INVALID_APPLICATION"
ErrTokenCreationFailed ErrorCode = "TOKEN_CREATION_FAILED"
)
// AppError represents an application error with context
type AppError struct {
Code ErrorCode `json:"code"`
Message string `json:"message"`
Details string `json:"details,omitempty"`
StatusCode int `json:"-"`
Internal error `json:"-"`
Context map[string]interface{} `json:"context,omitempty"`
}
// Error implements the error interface
func (e *AppError) Error() string {
if e.Internal != nil {
return fmt.Sprintf("%s: %s (internal: %v)", e.Code, e.Message, e.Internal)
}
return fmt.Sprintf("%s: %s", e.Code, e.Message)
}
// WithContext adds context information to the error
func (e *AppError) WithContext(key string, value interface{}) *AppError {
if e.Context == nil {
e.Context = make(map[string]interface{})
}
e.Context[key] = value
return e
}
// WithDetails adds additional details to the error
func (e *AppError) WithDetails(details string) *AppError {
e.Details = details
return e
}
// WithInternal adds the underlying error
func (e *AppError) WithInternal(err error) *AppError {
e.Internal = err
return e
}
// New creates a new application error
func New(code ErrorCode, message string) *AppError {
return &AppError{
Code: code,
Message: message,
StatusCode: getHTTPStatusCode(code),
}
}
// Wrap wraps an existing error with application error context
func Wrap(err error, code ErrorCode, message string) *AppError {
return &AppError{
Code: code,
Message: message,
StatusCode: getHTTPStatusCode(code),
Internal: err,
}
}
// getHTTPStatusCode maps error codes to HTTP status codes
func getHTTPStatusCode(code ErrorCode) int {
switch code {
case ErrUnauthorized, ErrInvalidToken, ErrTokenExpired, ErrInvalidCredentials:
return http.StatusUnauthorized
case ErrForbidden, ErrInsufficientPermissions:
return http.StatusForbidden
case ErrValidationFailed, ErrInvalidInput, ErrMissingField, ErrInvalidFormat:
return http.StatusBadRequest
case ErrNotFound, ErrApplicationNotFound, ErrTokenNotFound, ErrPermissionNotFound:
return http.StatusNotFound
case ErrAlreadyExists, ErrConflict:
return http.StatusConflict
case ErrRateLimit:
return http.StatusTooManyRequests
case ErrTimeout:
return http.StatusRequestTimeout
case ErrInternal, ErrDatabase, ErrExternal, ErrTokenCreationFailed:
return http.StatusInternalServerError
default:
return http.StatusInternalServerError
}
}
// IsRetryable determines if an error is retryable
func (e *AppError) IsRetryable() bool {
switch e.Code {
case ErrTimeout, ErrExternal, ErrDatabase:
return true
default:
return false
}
}
// IsClientError determines if an error is a client error (4xx)
func (e *AppError) IsClientError() bool {
return e.StatusCode >= 400 && e.StatusCode < 500
}
// IsServerError determines if an error is a server error (5xx)
func (e *AppError) IsServerError() bool {
return e.StatusCode >= 500
}
// Common error constructors for frequently used errors
// NewUnauthorizedError creates an unauthorized error
func NewUnauthorizedError(message string) *AppError {
return New(ErrUnauthorized, message)
}
// NewForbiddenError creates a forbidden error
func NewForbiddenError(message string) *AppError {
return New(ErrForbidden, message)
}
// NewValidationError creates a validation error
func NewValidationError(message string) *AppError {
return New(ErrValidationFailed, message)
}
// NewNotFoundError creates a not found error
func NewNotFoundError(resource string) *AppError {
return New(ErrNotFound, fmt.Sprintf("%s not found", resource))
}
// NewAlreadyExistsError creates an already exists error
func NewAlreadyExistsError(resource string) *AppError {
return New(ErrAlreadyExists, fmt.Sprintf("%s already exists", resource))
}
// NewInternalError creates an internal server error
func NewInternalError(message string) *AppError {
return New(ErrInternal, message)
}
// NewDatabaseError creates a database error
func NewDatabaseError(operation string, err error) *AppError {
return Wrap(err, ErrDatabase, fmt.Sprintf("Database operation failed: %s", operation))
}
// NewTokenError creates a token-related error
func NewTokenError(message string) *AppError {
return New(ErrInvalidToken, message)
}
// NewApplicationError creates an application-related error
func NewApplicationError(message string) *AppError {
return New(ErrInvalidApplication, message)
}
// NewPermissionError creates a permission-related error
func NewPermissionError(message string) *AppError {
return New(ErrInsufficientPermissions, message)
}
// ErrorResponse represents the JSON error response format
type ErrorResponse struct {
Error string `json:"error"`
Message string `json:"message"`
Code ErrorCode `json:"code"`
Details string `json:"details,omitempty"`
Context map[string]interface{} `json:"context,omitempty"`
}
// ToResponse converts an AppError to an ErrorResponse
func (e *AppError) ToResponse() ErrorResponse {
return ErrorResponse{
Error: string(e.Code),
Message: e.Message,
Code: e.Code,
Details: e.Details,
Context: e.Context,
}
}
// Recovery handles panic recovery and converts to appropriate errors
func Recovery(recovered interface{}) *AppError {
switch v := recovered.(type) {
case *AppError:
return v
case error:
return Wrap(v, ErrInternal, "Internal server error occurred")
case string:
return New(ErrInternal, v)
default:
return New(ErrInternal, "Unknown internal error occurred")
}
}
// Chain represents a chain of errors for better error tracking
type Chain struct {
errors []*AppError
}
// NewChain creates a new error chain
func NewChain() *Chain {
return &Chain{
errors: make([]*AppError, 0),
}
}
// Add adds an error to the chain
func (c *Chain) Add(err *AppError) *Chain {
c.errors = append(c.errors, err)
return c
}
// HasErrors returns true if the chain has any errors
func (c *Chain) HasErrors() bool {
return len(c.errors) > 0
}
// First returns the first error in the chain
func (c *Chain) First() *AppError {
if len(c.errors) == 0 {
return nil
}
return c.errors[0]
}
// Last returns the last error in the chain
func (c *Chain) Last() *AppError {
if len(c.errors) == 0 {
return nil
}
return c.errors[len(c.errors)-1]
}
// All returns all errors in the chain
func (c *Chain) All() []*AppError {
return c.errors
}
// Error implements the error interface for the chain
func (c *Chain) Error() string {
if len(c.errors) == 0 {
return "no errors"
}
if len(c.errors) == 1 {
return c.errors[0].Error()
}
return fmt.Sprintf("multiple errors: %s (and %d more)", c.errors[0].Error(), len(c.errors)-1)
}

415
internal/metrics/metrics.go Normal file
View File

@ -0,0 +1,415 @@
package metrics
import (
"context"
"net/http"
"strconv"
"sync"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// Metrics holds all application metrics
type Metrics struct {
// HTTP metrics
RequestsTotal *Counter
RequestDuration *Histogram
RequestsInFlight *Gauge
ResponseSize *Histogram
// Business metrics
TokensCreated *Counter
TokensVerified *Counter
TokensRevoked *Counter
ApplicationsTotal *Gauge
PermissionsTotal *Gauge
// System metrics
DatabaseConnections *Gauge
DatabaseQueries *Counter
DatabaseErrors *Counter
CacheHits *Counter
CacheMisses *Counter
// Error metrics
ErrorsTotal *Counter
mu sync.RWMutex
}
// Counter represents a monotonically increasing counter
type Counter struct {
value float64
labels map[string]string
mu sync.RWMutex
}
// Gauge represents a value that can go up and down
type Gauge struct {
value float64
labels map[string]string
mu sync.RWMutex
}
// Histogram represents a distribution of values
type Histogram struct {
buckets map[float64]float64
sum float64
count float64
labels map[string]string
mu sync.RWMutex
}
// NewMetrics creates a new metrics instance
func NewMetrics() *Metrics {
return &Metrics{
// HTTP metrics
RequestsTotal: NewCounter("http_requests_total", map[string]string{}),
RequestDuration: NewHistogram("http_request_duration_seconds", map[string]string{}),
RequestsInFlight: NewGauge("http_requests_in_flight", map[string]string{}),
ResponseSize: NewHistogram("http_response_size_bytes", map[string]string{}),
// Business metrics
TokensCreated: NewCounter("tokens_created_total", map[string]string{}),
TokensVerified: NewCounter("tokens_verified_total", map[string]string{}),
TokensRevoked: NewCounter("tokens_revoked_total", map[string]string{}),
ApplicationsTotal: NewGauge("applications_total", map[string]string{}),
PermissionsTotal: NewGauge("permissions_total", map[string]string{}),
// System metrics
DatabaseConnections: NewGauge("database_connections", map[string]string{}),
DatabaseQueries: NewCounter("database_queries_total", map[string]string{}),
DatabaseErrors: NewCounter("database_errors_total", map[string]string{}),
CacheHits: NewCounter("cache_hits_total", map[string]string{}),
CacheMisses: NewCounter("cache_misses_total", map[string]string{}),
// Error metrics
ErrorsTotal: NewCounter("errors_total", map[string]string{}),
}
}
// NewCounter creates a new counter
func NewCounter(name string, labels map[string]string) *Counter {
return &Counter{
value: 0,
labels: labels,
}
}
// NewGauge creates a new gauge
func NewGauge(name string, labels map[string]string) *Gauge {
return &Gauge{
value: 0,
labels: labels,
}
}
// NewHistogram creates a new histogram
func NewHistogram(name string, labels map[string]string) *Histogram {
return &Histogram{
buckets: make(map[float64]float64),
sum: 0,
count: 0,
labels: labels,
}
}
// Counter methods
func (c *Counter) Inc() {
c.mu.Lock()
defer c.mu.Unlock()
c.value++
}
func (c *Counter) Add(value float64) {
c.mu.Lock()
defer c.mu.Unlock()
c.value += value
}
func (c *Counter) Value() float64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.value
}
// Gauge methods
func (g *Gauge) Set(value float64) {
g.mu.Lock()
defer g.mu.Unlock()
g.value = value
}
func (g *Gauge) Inc() {
g.mu.Lock()
defer g.mu.Unlock()
g.value++
}
func (g *Gauge) Dec() {
g.mu.Lock()
defer g.mu.Unlock()
g.value--
}
func (g *Gauge) Add(value float64) {
g.mu.Lock()
defer g.mu.Unlock()
g.value += value
}
func (g *Gauge) Value() float64 {
g.mu.RLock()
defer g.mu.RUnlock()
return g.value
}
// Histogram methods
func (h *Histogram) Observe(value float64) {
h.mu.Lock()
defer h.mu.Unlock()
h.sum += value
h.count++
// Define standard buckets
buckets := []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}
for _, bucket := range buckets {
if value <= bucket {
h.buckets[bucket]++
}
}
}
func (h *Histogram) Sum() float64 {
h.mu.RLock()
defer h.mu.RUnlock()
return h.sum
}
func (h *Histogram) Count() float64 {
h.mu.RLock()
defer h.mu.RUnlock()
return h.count
}
func (h *Histogram) Buckets() map[float64]float64 {
h.mu.RLock()
defer h.mu.RUnlock()
result := make(map[float64]float64)
for k, v := range h.buckets {
result[k] = v
}
return result
}
// Global metrics instance
var globalMetrics *Metrics
var once sync.Once
// GetMetrics returns the global metrics instance
func GetMetrics() *Metrics {
once.Do(func() {
globalMetrics = NewMetrics()
})
return globalMetrics
}
// Middleware creates a Gin middleware for collecting HTTP metrics
func Middleware(logger *zap.Logger) gin.HandlerFunc {
metrics := GetMetrics()
return func(c *gin.Context) {
start := time.Now()
// Increment in-flight requests
metrics.RequestsInFlight.Inc()
defer metrics.RequestsInFlight.Dec()
// Process request
c.Next()
// Record metrics
duration := time.Since(start).Seconds()
status := strconv.Itoa(c.Writer.Status())
method := c.Request.Method
path := c.FullPath()
// Increment total requests
metrics.RequestsTotal.Add(1)
// Record request duration
metrics.RequestDuration.Observe(duration)
// Record response size
metrics.ResponseSize.Observe(float64(c.Writer.Size()))
// Record errors
if c.Writer.Status() >= 400 {
metrics.ErrorsTotal.Add(1)
}
// Log metrics
logger.Debug("HTTP request metrics",
zap.String("method", method),
zap.String("path", path),
zap.String("status", status),
zap.Float64("duration", duration),
zap.Int("size", c.Writer.Size()),
)
}
}
// RecordTokenCreation records a token creation event
func RecordTokenCreation(tokenType string) {
metrics := GetMetrics()
metrics.TokensCreated.Inc()
}
// RecordTokenVerification records a token verification event
func RecordTokenVerification(tokenType string, success bool) {
metrics := GetMetrics()
metrics.TokensVerified.Inc()
}
// RecordTokenRevocation records a token revocation event
func RecordTokenRevocation(tokenType string) {
metrics := GetMetrics()
metrics.TokensRevoked.Inc()
}
// RecordDatabaseQuery records a database query
func RecordDatabaseQuery(operation string, success bool) {
metrics := GetMetrics()
metrics.DatabaseQueries.Inc()
if !success {
metrics.DatabaseErrors.Inc()
}
}
// RecordCacheHit records a cache hit
func RecordCacheHit() {
metrics := GetMetrics()
metrics.CacheHits.Inc()
}
// RecordCacheMiss records a cache miss
func RecordCacheMiss() {
metrics := GetMetrics()
metrics.CacheMisses.Inc()
}
// UpdateApplicationCount updates the total number of applications
func UpdateApplicationCount(count int) {
metrics := GetMetrics()
metrics.ApplicationsTotal.Set(float64(count))
}
// UpdatePermissionCount updates the total number of permissions
func UpdatePermissionCount(count int) {
metrics := GetMetrics()
metrics.PermissionsTotal.Set(float64(count))
}
// UpdateDatabaseConnections updates the number of database connections
func UpdateDatabaseConnections(count int) {
metrics := GetMetrics()
metrics.DatabaseConnections.Set(float64(count))
}
// PrometheusHandler returns an HTTP handler that exports metrics in Prometheus format
func PrometheusHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
metrics := GetMetrics()
w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
// Export all metrics in Prometheus format
exportCounter(w, "http_requests_total", metrics.RequestsTotal)
exportGauge(w, "http_requests_in_flight", metrics.RequestsInFlight)
exportHistogram(w, "http_request_duration_seconds", metrics.RequestDuration)
exportHistogram(w, "http_response_size_bytes", metrics.ResponseSize)
exportCounter(w, "tokens_created_total", metrics.TokensCreated)
exportCounter(w, "tokens_verified_total", metrics.TokensVerified)
exportCounter(w, "tokens_revoked_total", metrics.TokensRevoked)
exportGauge(w, "applications_total", metrics.ApplicationsTotal)
exportGauge(w, "permissions_total", metrics.PermissionsTotal)
exportGauge(w, "database_connections", metrics.DatabaseConnections)
exportCounter(w, "database_queries_total", metrics.DatabaseQueries)
exportCounter(w, "database_errors_total", metrics.DatabaseErrors)
exportCounter(w, "cache_hits_total", metrics.CacheHits)
exportCounter(w, "cache_misses_total", metrics.CacheMisses)
exportCounter(w, "errors_total", metrics.ErrorsTotal)
}
}
func exportCounter(w http.ResponseWriter, name string, counter *Counter) {
w.Write([]byte("# HELP " + name + " Total number of " + name + "\n"))
w.Write([]byte("# TYPE " + name + " counter\n"))
w.Write([]byte(name + " " + strconv.FormatFloat(counter.Value(), 'f', -1, 64) + "\n"))
}
func exportGauge(w http.ResponseWriter, name string, gauge *Gauge) {
w.Write([]byte("# HELP " + name + " Current value of " + name + "\n"))
w.Write([]byte("# TYPE " + name + " gauge\n"))
w.Write([]byte(name + " " + strconv.FormatFloat(gauge.Value(), 'f', -1, 64) + "\n"))
}
func exportHistogram(w http.ResponseWriter, name string, histogram *Histogram) {
w.Write([]byte("# HELP " + name + " Histogram of " + name + "\n"))
w.Write([]byte("# TYPE " + name + " histogram\n"))
buckets := histogram.Buckets()
for bucket, count := range buckets {
w.Write([]byte(name + "_bucket{le=\"" + strconv.FormatFloat(bucket, 'f', -1, 64) + "\"} " + strconv.FormatFloat(count, 'f', -1, 64) + "\n"))
}
w.Write([]byte(name + "_sum " + strconv.FormatFloat(histogram.Sum(), 'f', -1, 64) + "\n"))
w.Write([]byte(name + "_count " + strconv.FormatFloat(histogram.Count(), 'f', -1, 64) + "\n"))
}
// HealthMetrics represents health check metrics
type HealthMetrics struct {
DatabaseConnected bool `json:"database_connected"`
ResponseTime time.Duration `json:"response_time"`
Uptime time.Duration `json:"uptime"`
Version string `json:"version"`
Environment string `json:"environment"`
}
// GetHealthMetrics returns current health metrics
func GetHealthMetrics(ctx context.Context, version, environment string, startTime time.Time) *HealthMetrics {
return &HealthMetrics{
DatabaseConnected: true, // This should be checked against actual DB
ResponseTime: time.Since(time.Now()),
Uptime: time.Since(startTime),
Version: version,
Environment: environment,
}
}
// BusinessMetrics represents business-specific metrics
type BusinessMetrics struct {
TotalApplications int `json:"total_applications"`
TotalTokens int `json:"total_tokens"`
TotalPermissions int `json:"total_permissions"`
ActiveTokens int `json:"active_tokens"`
}
// GetBusinessMetrics returns current business metrics
func GetBusinessMetrics() *BusinessMetrics {
metrics := GetMetrics()
return &BusinessMetrics{
TotalApplications: int(metrics.ApplicationsTotal.Value()),
TotalTokens: int(metrics.TokensCreated.Value()),
TotalPermissions: int(metrics.PermissionsTotal.Value()),
ActiveTokens: int(metrics.TokensCreated.Value() - metrics.TokensRevoked.Value()),
}
}

View File

@ -0,0 +1,265 @@
package middleware
import (
"net/http"
"reflect"
"strings"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"go.uber.org/zap"
)
// ValidationError represents a validation error
type ValidationError struct {
Field string `json:"field"`
Tag string `json:"tag"`
Value string `json:"value"`
Message string `json:"message"`
}
// ValidationResponse represents the validation error response
type ValidationResponse struct {
Error string `json:"error"`
Message string `json:"message"`
Details []ValidationError `json:"details,omitempty"`
}
var validate *validator.Validate
func init() {
validate = validator.New()
// Register custom tag name function to use json tags
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
}
// ValidateJSON validates JSON request body against struct validation tags
func ValidateJSON(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
// Skip validation for GET requests and requests without body
if c.Request.Method == "GET" || c.Request.ContentLength == 0 {
c.Next()
return
}
// Store original body for potential re-reading
c.Set("validation_enabled", true)
c.Next()
}
}
// ValidateStruct validates a struct and returns formatted errors
func ValidateStruct(s interface{}) []ValidationError {
var errors []ValidationError
err := validate.Struct(s)
if err != nil {
for _, err := range err.(validator.ValidationErrors) {
var element ValidationError
element.Field = err.Field()
element.Tag = err.Tag()
element.Value = err.Param()
element.Message = getErrorMessage(err)
errors = append(errors, element)
}
}
return errors
}
// ValidateAndBind validates and binds JSON request to struct
func ValidateAndBind(c *gin.Context, obj interface{}) error {
// Bind JSON to struct
if err := c.ShouldBindJSON(obj); err != nil {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Invalid JSON",
Message: "Request body contains invalid JSON: " + err.Error(),
})
return err
}
// Validate struct
if validationErrors := ValidateStruct(obj); len(validationErrors) > 0 {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Validation Failed",
Message: "Request validation failed",
Details: validationErrors,
})
return validator.ValidationErrors{}
}
return nil
}
// getErrorMessage returns a human-readable error message for validation errors
func getErrorMessage(fe validator.FieldError) string {
switch fe.Tag() {
case "required":
return "This field is required"
case "email":
return "Invalid email format"
case "min":
return "Value is too short (minimum " + fe.Param() + " characters)"
case "max":
return "Value is too long (maximum " + fe.Param() + " characters)"
case "url":
return "Invalid URL format"
case "oneof":
return "Value must be one of: " + fe.Param()
case "uuid":
return "Invalid UUID format"
case "gte":
return "Value must be greater than or equal to " + fe.Param()
case "lte":
return "Value must be less than or equal to " + fe.Param()
case "len":
return "Value must be exactly " + fe.Param() + " characters"
case "dive":
return "Invalid array element"
default:
return "Invalid value for " + fe.Field()
}
}
// RequiredFields validates that specific fields are present in the request
func RequiredFields(fields ...string) gin.HandlerFunc {
return func(c *gin.Context) {
var json map[string]interface{}
if err := c.ShouldBindJSON(&json); err != nil {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Invalid JSON",
Message: "Request body contains invalid JSON",
})
c.Abort()
return
}
var missingFields []string
for _, field := range fields {
if _, exists := json[field]; !exists {
missingFields = append(missingFields, field)
}
}
if len(missingFields) > 0 {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Missing Required Fields",
Message: "The following required fields are missing: " + strings.Join(missingFields, ", "),
})
c.Abort()
return
}
// Store the parsed JSON for use in handlers
c.Set("parsed_json", json)
c.Next()
}
}
// ValidateUUID validates that a URL parameter is a valid UUID
func ValidateUUID(param string) gin.HandlerFunc {
return func(c *gin.Context) {
value := c.Param(param)
if value == "" {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Missing Parameter",
Message: "Required parameter '" + param + "' is missing",
})
c.Abort()
return
}
// Validate UUID format
if err := validate.Var(value, "uuid"); err != nil {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Invalid Parameter",
Message: "Parameter '" + param + "' must be a valid UUID",
})
c.Abort()
return
}
c.Next()
}
}
// ValidateQueryParams validates query parameters
func ValidateQueryParams(rules map[string]string) gin.HandlerFunc {
return func(c *gin.Context) {
var errors []ValidationError
for param, rule := range rules {
value := c.Query(param)
if value != "" {
if err := validate.Var(value, rule); err != nil {
for _, err := range err.(validator.ValidationErrors) {
errors = append(errors, ValidationError{
Field: param,
Tag: err.Tag(),
Value: err.Param(),
Message: getErrorMessage(err),
})
}
}
}
}
if len(errors) > 0 {
c.JSON(http.StatusBadRequest, ValidationResponse{
Error: "Invalid Query Parameters",
Message: "One or more query parameters are invalid",
Details: errors,
})
c.Abort()
return
}
c.Next()
}
}
// SanitizeInput sanitizes input strings to prevent XSS and injection attacks
func SanitizeInput() gin.HandlerFunc {
return func(c *gin.Context) {
// This is a basic implementation - in production you might want to use
// a more sophisticated sanitization library like bluemonday
c.Next()
}
}
// ValidatePermissions validates that permission scopes follow the expected format
func ValidatePermissions(c *gin.Context, permissions []string) []ValidationError {
var errors []ValidationError
for i, perm := range permissions {
// Check basic format: should contain only alphanumeric, dots, and underscores
if err := validate.Var(perm, "required,min=1,max=255,alphanum|contains=.|contains=_"); err != nil {
errors = append(errors, ValidationError{
Field: "permissions[" + string(rune(i)) + "]",
Tag: "format",
Value: perm,
Message: "Permission scope must contain only alphanumeric characters, dots, and underscores",
})
}
// Check for dangerous patterns
if strings.Contains(perm, "..") || strings.HasPrefix(perm, ".") || strings.HasSuffix(perm, ".") {
errors = append(errors, ValidationError{
Field: "permissions[" + string(rune(i)) + "]",
Tag: "format",
Value: perm,
Message: "Permission scope has invalid format",
})
}
}
return errors
}

View File

@ -2,10 +2,14 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/repository"
"github.com/lib/pq"
)
// PermissionRepository implements the PermissionRepository interface for PostgreSQL
@ -20,20 +24,116 @@ func NewPermissionRepository(db repository.DatabaseProvider) repository.Permissi
// CreateAvailablePermission creates a new available permission
func (r *PermissionRepository) CreateAvailablePermission(ctx context.Context, permission *domain.AvailablePermission) error {
// TODO: Implement actual permission creation
query := `
INSERT INTO available_permissions (
id, scope, name, description, category, parent_scope,
is_system, created_by, updated_by, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`
db := r.db.GetDB().(*sql.DB)
now := time.Now()
if permission.ID == uuid.Nil {
permission.ID = uuid.New()
}
_, err := db.ExecContext(ctx, query,
permission.ID,
permission.Scope,
permission.Name,
permission.Description,
permission.Category,
permission.ParentScope,
permission.IsSystem,
permission.CreatedBy,
permission.UpdatedBy,
now,
now,
)
if err != nil {
return fmt.Errorf("failed to create available permission: %w", err)
}
permission.CreatedAt = now
permission.UpdatedAt = now
return nil
}
// GetAvailablePermission retrieves an available permission by ID
func (r *PermissionRepository) GetAvailablePermission(ctx context.Context, permissionID uuid.UUID) (*domain.AvailablePermission, error) {
// TODO: Implement actual permission retrieval
return nil, nil
query := `
SELECT id, scope, name, description, category, parent_scope,
is_system, created_at, created_by, updated_at, updated_by
FROM available_permissions
WHERE id = $1
`
db := r.db.GetDB().(*sql.DB)
row := db.QueryRowContext(ctx, query, permissionID)
permission := &domain.AvailablePermission{}
err := row.Scan(
&permission.ID,
&permission.Scope,
&permission.Name,
&permission.Description,
&permission.Category,
&permission.ParentScope,
&permission.IsSystem,
&permission.CreatedAt,
&permission.CreatedBy,
&permission.UpdatedAt,
&permission.UpdatedBy,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("permission with ID '%s' not found", permissionID)
}
return nil, fmt.Errorf("failed to get available permission: %w", err)
}
return permission, nil
}
// GetAvailablePermissionByScope retrieves an available permission by scope
func (r *PermissionRepository) GetAvailablePermissionByScope(ctx context.Context, scope string) (*domain.AvailablePermission, error) {
// TODO: Implement actual permission retrieval by scope
return nil, nil
query := `
SELECT id, scope, name, description, category, parent_scope,
is_system, created_at, created_by, updated_at, updated_by
FROM available_permissions
WHERE scope = $1
`
db := r.db.GetDB().(*sql.DB)
row := db.QueryRowContext(ctx, query, scope)
permission := &domain.AvailablePermission{}
err := row.Scan(
&permission.ID,
&permission.Scope,
&permission.Name,
&permission.Description,
&permission.Category,
&permission.ParentScope,
&permission.IsSystem,
&permission.CreatedAt,
&permission.CreatedBy,
&permission.UpdatedAt,
&permission.UpdatedBy,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("permission with scope '%s' not found", scope)
}
return nil, fmt.Errorf("failed to get available permission by scope: %w", err)
}
return permission, nil
}
// ListAvailablePermissions retrieves available permissions with pagination and filtering
@ -56,9 +156,44 @@ func (r *PermissionRepository) DeleteAvailablePermission(ctx context.Context, pe
// ValidatePermissionScopes checks if all given scopes exist and are valid
func (r *PermissionRepository) ValidatePermissionScopes(ctx context.Context, scopes []string) ([]string, error) {
// TODO: Implement actual scope validation
// For now, assume all scopes are valid
return []string{}, nil
if len(scopes) == 0 {
return []string{}, nil
}
query := `
SELECT scope
FROM available_permissions
WHERE scope = ANY($1)
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, pq.Array(scopes))
if err != nil {
return nil, fmt.Errorf("failed to validate permission scopes: %w", err)
}
defer rows.Close()
validScopes := make(map[string]bool)
for rows.Next() {
var scope string
if err := rows.Scan(&scope); err != nil {
return nil, fmt.Errorf("failed to scan scope: %w", err)
}
validScopes[scope] = true
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating scopes: %w", err)
}
var result []string
for _, scope := range scopes {
if validScopes[scope] {
result = append(result, scope)
}
}
return result, nil
}
// GetPermissionHierarchy returns all parent and child permissions for given scopes
@ -79,7 +214,56 @@ func NewGrantedPermissionRepository(db repository.DatabaseProvider) repository.G
// GrantPermissions grants multiple permissions to a token
func (r *GrantedPermissionRepository) GrantPermissions(ctx context.Context, grants []*domain.GrantedPermission) error {
// TODO: Implement actual permission granting
if len(grants) == 0 {
return nil
}
db := r.db.GetDB().(*sql.DB)
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
query := `
INSERT INTO granted_permissions (
id, token_type, token_id, permission_id, scope, created_by, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (token_type, token_id, permission_id) DO NOTHING
`
stmt, err := tx.PrepareContext(ctx, query)
if err != nil {
return fmt.Errorf("failed to prepare statement: %w", err)
}
defer stmt.Close()
now := time.Now()
for _, grant := range grants {
if grant.ID == uuid.Nil {
grant.ID = uuid.New()
}
_, err = stmt.ExecContext(ctx,
grant.ID,
string(grant.TokenType),
grant.TokenID,
grant.PermissionID,
grant.Scope,
grant.CreatedBy,
now,
)
if err != nil {
return fmt.Errorf("failed to grant permission: %w", err)
}
grant.CreatedAt = now
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
@ -109,16 +293,72 @@ func (r *GrantedPermissionRepository) RevokeAllPermissions(ctx context.Context,
// HasPermission checks if a token has a specific permission
func (r *GrantedPermissionRepository) HasPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scope string) (bool, error) {
// TODO: Implement actual permission checking
query := `
SELECT 1
FROM granted_permissions gp
JOIN available_permissions ap ON gp.permission_id = ap.id
WHERE gp.token_type = $1
AND gp.token_id = $2
AND gp.scope = $3
AND gp.revoked = false
LIMIT 1
`
db := r.db.GetDB().(*sql.DB)
var exists int
err := db.QueryRowContext(ctx, query, string(tokenType), tokenID, scope).Scan(&exists)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("failed to check permission: %w", err)
}
return true, nil
}
// HasAnyPermission checks if a token has any of the specified permissions
func (r *GrantedPermissionRepository) HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error) {
// TODO: Implement actual permission checking
if len(scopes) == 0 {
return make(map[string]bool), nil
}
query := `
SELECT gp.scope
FROM granted_permissions gp
JOIN available_permissions ap ON gp.permission_id = ap.id
WHERE gp.token_type = $1
AND gp.token_id = $2
AND gp.scope = ANY($3)
AND gp.revoked = false
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, string(tokenType), tokenID, pq.Array(scopes))
if err != nil {
return nil, fmt.Errorf("failed to check permissions: %w", err)
}
defer rows.Close()
result := make(map[string]bool)
// Initialize all scopes as false
for _, scope := range scopes {
result[scope] = false
}
// Mark found permissions as true
for rows.Next() {
var scope string
if err := rows.Scan(&scope); err != nil {
return nil, fmt.Errorf("failed to scan permission scope: %w", err)
}
result[scope] = true
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating permission results: %w", err)
}
return result, nil
}

View File

@ -57,26 +57,196 @@ func (r *StaticTokenRepository) Create(ctx context.Context, token *domain.Static
// GetByID retrieves a static token by its ID
func (r *StaticTokenRepository) GetByID(ctx context.Context, tokenID uuid.UUID) (*domain.StaticToken, error) {
// TODO: Implement actual token retrieval
return nil, nil
query := `
SELECT id, app_id, owner_type, owner_name, owner_owner,
key_hash, type, created_at, updated_at
FROM static_tokens
WHERE id = $1
`
db := r.db.GetDB().(*sql.DB)
row := db.QueryRowContext(ctx, query, tokenID)
token := &domain.StaticToken{}
var ownerType, ownerName, ownerOwner string
err := row.Scan(
&token.ID,
&token.AppID,
&ownerType,
&ownerName,
&ownerOwner,
&token.KeyHash,
&token.Type,
&token.CreatedAt,
&token.UpdatedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("static token with ID '%s' not found", tokenID)
}
return nil, fmt.Errorf("failed to get static token: %w", err)
}
token.Owner = domain.Owner{
Type: domain.OwnerType(ownerType),
Name: ownerName,
Owner: ownerOwner,
}
return token, nil
}
// GetByKeyHash retrieves a static token by its key hash
func (r *StaticTokenRepository) GetByKeyHash(ctx context.Context, keyHash string) (*domain.StaticToken, error) {
// TODO: Implement actual token retrieval by hash
return nil, nil
query := `
SELECT id, app_id, owner_type, owner_name, owner_owner,
key_hash, type, created_at, updated_at
FROM static_tokens
WHERE key_hash = $1
`
db := r.db.GetDB().(*sql.DB)
row := db.QueryRowContext(ctx, query, keyHash)
token := &domain.StaticToken{}
var ownerType, ownerName, ownerOwner string
err := row.Scan(
&token.ID,
&token.AppID,
&ownerType,
&ownerName,
&ownerOwner,
&token.KeyHash,
&token.Type,
&token.CreatedAt,
&token.UpdatedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("static token with hash not found")
}
return nil, fmt.Errorf("failed to get static token by hash: %w", err)
}
token.Owner = domain.Owner{
Type: domain.OwnerType(ownerType),
Name: ownerName,
Owner: ownerOwner,
}
return token, nil
}
// GetByAppID retrieves all static tokens for an application
func (r *StaticTokenRepository) GetByAppID(ctx context.Context, appID string) ([]*domain.StaticToken, error) {
// TODO: Implement actual token listing
return []*domain.StaticToken{}, nil
query := `
SELECT id, app_id, owner_type, owner_name, owner_owner,
key_hash, type, created_at, updated_at
FROM static_tokens
WHERE app_id = $1
ORDER BY created_at DESC
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, appID)
if err != nil {
return nil, fmt.Errorf("failed to query static tokens: %w", err)
}
defer rows.Close()
var tokens []*domain.StaticToken
for rows.Next() {
token := &domain.StaticToken{}
var ownerType, ownerName, ownerOwner string
err := rows.Scan(
&token.ID,
&token.AppID,
&ownerType,
&ownerName,
&ownerOwner,
&token.KeyHash,
&token.Type,
&token.CreatedAt,
&token.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan static token: %w", err)
}
token.Owner = domain.Owner{
Type: domain.OwnerType(ownerType),
Name: ownerName,
Owner: ownerOwner,
}
tokens = append(tokens, token)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating static tokens: %w", err)
}
return tokens, nil
}
// List retrieves static tokens with pagination
func (r *StaticTokenRepository) List(ctx context.Context, limit, offset int) ([]*domain.StaticToken, error) {
// TODO: Implement actual token listing
return []*domain.StaticToken{}, nil
query := `
SELECT id, app_id, owner_type, owner_name, owner_owner,
key_hash, type, created_at, updated_at
FROM static_tokens
ORDER BY created_at DESC
LIMIT $1 OFFSET $2
`
db := r.db.GetDB().(*sql.DB)
rows, err := db.QueryContext(ctx, query, limit, offset)
if err != nil {
return nil, fmt.Errorf("failed to query static tokens: %w", err)
}
defer rows.Close()
var tokens []*domain.StaticToken
for rows.Next() {
token := &domain.StaticToken{}
var ownerType, ownerName, ownerOwner string
err := rows.Scan(
&token.ID,
&token.AppID,
&ownerType,
&ownerName,
&ownerOwner,
&token.KeyHash,
&token.Type,
&token.CreatedAt,
&token.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan static token: %w", err)
}
token.Owner = domain.Owner{
Type: domain.OwnerType(ownerType),
Name: ownerName,
Owner: ownerOwner,
}
tokens = append(tokens, token)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating static tokens: %w", err)
}
return tokens, nil
}
// Delete deletes a static token

View File

@ -8,17 +8,19 @@ import (
"github.com/google/uuid"
"go.uber.org/zap"
"github.com/kms/api-key-service/internal/crypto"
"github.com/kms/api-key-service/internal/domain"
"github.com/kms/api-key-service/internal/repository"
)
// tokenService implements the TokenService interface
type tokenService struct {
tokenRepo repository.StaticTokenRepository
appRepo repository.ApplicationRepository
permRepo repository.PermissionRepository
grantRepo repository.GrantedPermissionRepository
logger *zap.Logger
tokenRepo repository.StaticTokenRepository
appRepo repository.ApplicationRepository
permRepo repository.PermissionRepository
grantRepo repository.GrantedPermissionRepository
tokenGen *crypto.TokenGenerator
logger *zap.Logger
}
// NewTokenService creates a new token service
@ -27,6 +29,7 @@ func NewTokenService(
appRepo repository.ApplicationRepository,
permRepo repository.PermissionRepository,
grantRepo repository.GrantedPermissionRepository,
hmacKey string,
logger *zap.Logger,
) TokenService {
return &tokenService{
@ -34,6 +37,7 @@ func NewTokenService(
appRepo: appRepo,
permRepo: permRepo,
grantRepo: grantRepo,
tokenGen: crypto.NewTokenGenerator(hmacKey),
logger: logger,
}
}
@ -42,10 +46,33 @@ func NewTokenService(
func (s *tokenService) CreateStaticToken(ctx context.Context, req *domain.CreateStaticTokenRequest, userID string) (*domain.CreateStaticTokenResponse, error) {
s.logger.Info("Creating static token", zap.String("app_id", req.AppID), zap.String("user_id", userID))
// TODO: Validate permissions
// TODO: Validate application exists
// TODO: Generate secure token
// TODO: Grant permissions
// Validate application exists
app, err := s.appRepo.GetByID(ctx, req.AppID)
if err != nil {
s.logger.Error("Failed to get application", zap.Error(err), zap.String("app_id", req.AppID))
return nil, fmt.Errorf("application not found: %w", err)
}
// Validate permissions exist
validPermissions, err := s.permRepo.ValidatePermissionScopes(ctx, req.Permissions)
if err != nil {
s.logger.Error("Failed to validate permissions", zap.Error(err))
return nil, fmt.Errorf("failed to validate permissions: %w", err)
}
if len(validPermissions) != len(req.Permissions) {
s.logger.Warn("Some permissions are invalid",
zap.Strings("requested", req.Permissions),
zap.Strings("valid", validPermissions))
return nil, fmt.Errorf("some requested permissions are invalid")
}
// Generate secure token
tokenInfo, err := s.tokenGen.GenerateTokenWithInfo()
if err != nil {
s.logger.Error("Failed to generate secure token", zap.Error(err))
return nil, fmt.Errorf("failed to generate token: %w", err)
}
tokenID := uuid.New()
now := time.Now()
@ -54,32 +81,62 @@ func (s *tokenService) CreateStaticToken(ctx context.Context, req *domain.Create
token := &domain.StaticToken{
ID: tokenID,
AppID: req.AppID,
Owner: domain.Owner{
Type: domain.OwnerTypeIndividual,
Name: userID,
Owner: userID,
},
KeyHash: "placeholder-hash-" + tokenID.String(),
Type: "hmac",
Owner: req.Owner,
KeyHash: tokenInfo.Hash,
Type: "hmac",
CreatedAt: now,
UpdatedAt: now,
}
// Save the token to the database
err := s.tokenRepo.Create(ctx, token)
err = s.tokenRepo.Create(ctx, token)
if err != nil {
s.logger.Error("Failed to create token in database", zap.Error(err), zap.String("token_id", tokenID.String()))
return nil, fmt.Errorf("failed to create token: %w", err)
}
// Grant permissions to the token
var grants []*domain.GrantedPermission
for _, permScope := range validPermissions {
// Get permission by scope to get the ID
perm, err := s.permRepo.GetAvailablePermissionByScope(ctx, permScope)
if err != nil {
s.logger.Error("Failed to get permission by scope", zap.Error(err), zap.String("scope", permScope))
continue
}
grant := &domain.GrantedPermission{
ID: uuid.New(),
TokenType: domain.TokenTypeStatic,
TokenID: tokenID,
PermissionID: perm.ID,
Scope: permScope,
CreatedBy: userID,
}
grants = append(grants, grant)
}
if len(grants) > 0 {
err = s.grantRepo.GrantPermissions(ctx, grants)
if err != nil {
s.logger.Error("Failed to grant permissions", zap.Error(err))
// Clean up the token if permission granting fails
s.tokenRepo.Delete(ctx, tokenID)
return nil, fmt.Errorf("failed to grant permissions: %w", err)
}
}
response := &domain.CreateStaticTokenResponse{
ID: tokenID,
Token: "static-token-placeholder-" + tokenID.String(),
Permissions: req.Permissions,
Token: tokenInfo.Token, // Return the actual token only once
Permissions: validPermissions,
CreatedAt: now,
}
s.logger.Info("Static token created successfully", zap.String("token_id", tokenID.String()))
s.logger.Info("Static token created successfully",
zap.String("token_id", tokenID.String()),
zap.String("app_id", app.AppID),
zap.Strings("permissions", validPermissions))
return response, nil
}

View File

@ -267,10 +267,12 @@ test_token_endpoints() {
-H "X-User-Email: $USER_EMAIL" \
-H "X-User-ID: $USER_ID" \
-d '{
"name": "Test Static Token for Deletion",
"description": "A test static token for deletion test",
"permissions": ["read"],
"expires_at": "2025-12-31T23:59:59Z"
"owner": {
"type": "individual",
"name": "Test Token Owner",
"owner": "test-token@example.com"
},
"permissions": ["repo.read", "repo.write"]
}' 2>/dev/null || echo -e "\n000")
local token_status_code=$(echo "$token_response" | tail -n1)
@ -282,10 +284,12 @@ test_token_endpoints() {
-H "X-User-Email: $USER_EMAIL" \
-H "X-User-ID: $USER_ID" \
-d '{
"name": "Test Static Token",
"description": "A test static token",
"permissions": ["read"],
"expires_at": "2025-12-31T23:59:59Z"
"owner": {
"type": "individual",
"name": "Test Token Owner",
"owner": "test-token@example.com"
},
"permissions": ["repo.read", "repo.write"]
}'
# Extract token_id from the first response for deletion test

View File

@ -96,7 +96,7 @@ func (suite *IntegrationTestSuite) setupServer() {
// Initialize services
appService := services.NewApplicationService(appRepo, logger)
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, logger)
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, suite.cfg.GetString("INTERNAL_HMAC_KEY"), logger)
authService := services.NewAuthenticationService(suite.cfg, logger)
// Initialize handlers

View File

@ -460,14 +460,14 @@ func (m *MockPermissionRepository) ValidatePermissionScopes(ctx context.Context,
m.mu.RLock()
defer m.mu.RUnlock()
var invalid []string
var valid []string
for _, scope := range scopes {
if _, exists := m.scopeIndex[scope]; !exists {
invalid = append(invalid, scope)
if _, exists := m.scopeIndex[scope]; exists {
valid = append(valid, scope)
}
}
return invalid, nil
return valid, nil
}
func (m *MockPermissionRepository) GetPermissionHierarchy(ctx context.Context, scopes []string) ([]*domain.AvailablePermission, error) {