-
This commit is contained in:
@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/kms/api-key-service/internal/config"
|
"github.com/kms/api-key-service/internal/config"
|
||||||
"github.com/kms/api-key-service/internal/database"
|
"github.com/kms/api-key-service/internal/database"
|
||||||
"github.com/kms/api-key-service/internal/handlers"
|
"github.com/kms/api-key-service/internal/handlers"
|
||||||
|
"github.com/kms/api-key-service/internal/metrics"
|
||||||
"github.com/kms/api-key-service/internal/middleware"
|
"github.com/kms/api-key-service/internal/middleware"
|
||||||
"github.com/kms/api-key-service/internal/repository/postgres"
|
"github.com/kms/api-key-service/internal/repository/postgres"
|
||||||
"github.com/kms/api-key-service/internal/services"
|
"github.com/kms/api-key-service/internal/services"
|
||||||
@ -61,7 +62,7 @@ func main() {
|
|||||||
|
|
||||||
// Initialize services
|
// Initialize services
|
||||||
appService := services.NewApplicationService(appRepo, logger)
|
appService := services.NewApplicationService(appRepo, logger)
|
||||||
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, logger)
|
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, cfg.GetString("INTERNAL_HMAC_KEY"), logger)
|
||||||
authService := services.NewAuthenticationService(cfg, logger)
|
authService := services.NewAuthenticationService(cfg, logger)
|
||||||
|
|
||||||
// Initialize handlers
|
// Initialize handlers
|
||||||
@ -156,6 +157,7 @@ func setupRouter(cfg config.ConfigProvider, logger *zap.Logger, healthHandler *h
|
|||||||
// Add middleware
|
// Add middleware
|
||||||
router.Use(middleware.Logger(logger))
|
router.Use(middleware.Logger(logger))
|
||||||
router.Use(middleware.Recovery(logger))
|
router.Use(middleware.Recovery(logger))
|
||||||
|
router.Use(metrics.Middleware(logger))
|
||||||
router.Use(middleware.CORS())
|
router.Use(middleware.CORS())
|
||||||
router.Use(middleware.Security())
|
router.Use(middleware.Security())
|
||||||
router.Use(middleware.ValidateContentType())
|
router.Use(middleware.ValidateContentType())
|
||||||
@ -226,18 +228,20 @@ func setupRouter(cfg config.ConfigProvider, logger *zap.Logger, healthHandler *h
|
|||||||
}
|
}
|
||||||
|
|
||||||
func startMetricsServer(cfg config.ConfigProvider, logger *zap.Logger) *http.Server {
|
func startMetricsServer(cfg config.ConfigProvider, logger *zap.Logger) *http.Server {
|
||||||
metricsRouter := gin.New()
|
mux := http.NewServeMux()
|
||||||
metricsRouter.Use(middleware.Logger(logger))
|
|
||||||
metricsRouter.Use(middleware.Recovery(logger))
|
// Prometheus metrics endpoint
|
||||||
|
mux.HandleFunc("/metrics", metrics.PrometheusHandler())
|
||||||
// Basic metrics endpoint
|
|
||||||
metricsRouter.GET("/metrics", func(c *gin.Context) {
|
// Health endpoint for metrics server
|
||||||
c.String(http.StatusOK, "# HELP api_key_service_info Information about the API Key Service\n# TYPE api_key_service_info gauge\napi_key_service_info{version=\"%s\"} 1\n", cfg.GetString("APP_VERSION"))
|
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
})
|
})
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: cfg.GetMetricsAddress(),
|
Addr: cfg.GetMetricsAddress(),
|
||||||
Handler: metricsRouter,
|
Handler: mux,
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
297
docs/PRODUCTION_ROADMAP.md
Normal file
297
docs/PRODUCTION_ROADMAP.md
Normal file
@ -0,0 +1,297 @@
|
|||||||
|
# KMS API Service - Production Roadmap
|
||||||
|
|
||||||
|
This document outlines the complete roadmap for making the API Key Management Service fully production-ready. Use the checkboxes to track progress and refer to the implementation notes at the bottom.
|
||||||
|
|
||||||
|
## 🏗️ Core Infrastructure (COMPLETED)
|
||||||
|
|
||||||
|
### Repository Layer
|
||||||
|
- [x] Complete token repository implementation (CRUD operations)
|
||||||
|
- [x] Complete permission repository implementation (core methods)
|
||||||
|
- [x] Implement granted permission repository (authorization logic)
|
||||||
|
- [x] Add database transaction support
|
||||||
|
- [x] Implement proper error handling in repositories
|
||||||
|
|
||||||
|
### Security & Cryptography
|
||||||
|
- [x] Implement secure token generation using crypto/rand
|
||||||
|
- [x] Add bcrypt-based token hashing for storage
|
||||||
|
- [x] Implement HMAC token signing and verification
|
||||||
|
- [x] Create token format validation utilities
|
||||||
|
- [x] Add cryptographic key management
|
||||||
|
|
||||||
|
### Service Layer
|
||||||
|
- [x] Update token service with secure generation
|
||||||
|
- [x] Implement permission validation in token creation
|
||||||
|
- [x] Add application validation before token operations
|
||||||
|
- [x] Implement proper error propagation
|
||||||
|
- [x] Add comprehensive logging throughout services
|
||||||
|
|
||||||
|
### Middleware & Validation
|
||||||
|
- [x] Create comprehensive input validation middleware
|
||||||
|
- [x] Implement struct-based validation with detailed errors
|
||||||
|
- [x] Add UUID parameter validation
|
||||||
|
- [x] Create permission scope format validation
|
||||||
|
- [x] Implement request sanitization
|
||||||
|
|
||||||
|
### Error Handling
|
||||||
|
- [x] Create structured error framework with typed codes
|
||||||
|
- [x] Implement HTTP status code mapping
|
||||||
|
- [x] Add error context and chaining support
|
||||||
|
- [x] Create consistent JSON error responses
|
||||||
|
- [x] Add retry logic indicators
|
||||||
|
|
||||||
|
### Monitoring & Metrics
|
||||||
|
- [x] Implement comprehensive metrics collection
|
||||||
|
- [x] Add Prometheus-compatible metrics export
|
||||||
|
- [x] Create HTTP request monitoring middleware
|
||||||
|
- [x] Add business metrics tracking
|
||||||
|
- [x] Implement system health metrics
|
||||||
|
|
||||||
|
## 🔐 Authentication & Authorization (HIGH PRIORITY)
|
||||||
|
|
||||||
|
### JWT Implementation
|
||||||
|
- [ ] Complete JWT token generation and validation
|
||||||
|
- [ ] Implement token expiration and renewal logic
|
||||||
|
- [ ] Add JWT claims management
|
||||||
|
- [ ] Create token blacklisting mechanism
|
||||||
|
- [ ] Implement refresh token rotation
|
||||||
|
|
||||||
|
### SSO Integration
|
||||||
|
- [ ] Implement OAuth2/OIDC provider integration
|
||||||
|
- [ ] Add SAML authentication support
|
||||||
|
- [ ] Create user session management
|
||||||
|
- [ ] Implement role-based access control (RBAC)
|
||||||
|
- [ ] Add multi-tenant authentication support
|
||||||
|
|
||||||
|
### Permission System Enhancement
|
||||||
|
- [ ] Implement hierarchical permission inheritance
|
||||||
|
- [ ] Add dynamic permission evaluation
|
||||||
|
- [ ] Create permission caching mechanism
|
||||||
|
- [ ] Implement permission audit logging
|
||||||
|
- [ ] Add bulk permission operations
|
||||||
|
|
||||||
|
## 🚀 Performance & Scalability (MEDIUM PRIORITY)
|
||||||
|
|
||||||
|
### Caching Layer
|
||||||
|
- [ ] Implement Redis integration for caching
|
||||||
|
- [ ] Add permission result caching
|
||||||
|
- [ ] Create application metadata caching
|
||||||
|
- [ ] Implement token validation result caching
|
||||||
|
- [ ] Add cache invalidation strategies
|
||||||
|
|
||||||
|
### Database Optimization
|
||||||
|
- [ ] Implement database connection pool tuning
|
||||||
|
- [ ] Add query performance monitoring
|
||||||
|
- [ ] Create database migration rollback procedures
|
||||||
|
- [ ] Implement read replica support
|
||||||
|
- [ ] Add database backup and recovery procedures
|
||||||
|
|
||||||
|
### Load Balancing & Clustering
|
||||||
|
- [ ] Implement horizontal scaling support
|
||||||
|
- [ ] Add load balancer health checks
|
||||||
|
- [ ] Create session affinity handling
|
||||||
|
- [ ] Implement distributed rate limiting
|
||||||
|
- [ ] Add circuit breaker patterns
|
||||||
|
|
||||||
|
## 🔒 Security Hardening (HIGH PRIORITY)
|
||||||
|
|
||||||
|
### Advanced Security Features
|
||||||
|
- [ ] Implement API key rotation mechanisms
|
||||||
|
- [ ] Add brute force protection
|
||||||
|
- [ ] Create account lockout mechanisms
|
||||||
|
- [ ] Implement IP whitelisting/blacklisting
|
||||||
|
- [ ] Add request signing validation
|
||||||
|
|
||||||
|
### Audit & Compliance
|
||||||
|
- [ ] Implement comprehensive audit logging
|
||||||
|
- [ ] Add compliance reporting features
|
||||||
|
- [ ] Create data retention policies
|
||||||
|
- [ ] Implement GDPR compliance features
|
||||||
|
- [ ] Add security event alerting
|
||||||
|
|
||||||
|
### Secrets Management
|
||||||
|
- [ ] Integrate with HashiCorp Vault or similar
|
||||||
|
- [ ] Implement automatic key rotation
|
||||||
|
- [ ] Add encrypted configuration storage
|
||||||
|
- [ ] Create secure backup procedures
|
||||||
|
- [ ] Implement key escrow mechanisms
|
||||||
|
|
||||||
|
## 🧪 Testing & Quality Assurance (MEDIUM PRIORITY)
|
||||||
|
|
||||||
|
### Unit Testing
|
||||||
|
- [ ] Add comprehensive unit tests for repositories
|
||||||
|
- [ ] Create service layer unit tests
|
||||||
|
- [ ] Implement middleware unit tests
|
||||||
|
- [ ] Add crypto utility unit tests
|
||||||
|
- [ ] Create error handling unit tests
|
||||||
|
|
||||||
|
### Integration Testing
|
||||||
|
- [ ] Expand integration test coverage
|
||||||
|
- [ ] Add database integration tests
|
||||||
|
- [ ] Create API endpoint integration tests
|
||||||
|
- [ ] Implement authentication flow tests
|
||||||
|
- [ ] Add permission validation tests
|
||||||
|
|
||||||
|
### Performance Testing
|
||||||
|
- [ ] Implement load testing scenarios
|
||||||
|
- [ ] Add stress testing for concurrent operations
|
||||||
|
- [ ] Create database performance benchmarks
|
||||||
|
- [ ] Implement memory leak detection
|
||||||
|
- [ ] Add latency and throughput testing
|
||||||
|
|
||||||
|
### Security Testing
|
||||||
|
- [ ] Implement penetration testing scenarios
|
||||||
|
- [ ] Add vulnerability scanning automation
|
||||||
|
- [ ] Create security regression tests
|
||||||
|
- [ ] Implement fuzzing tests
|
||||||
|
- [ ] Add compliance validation tests
|
||||||
|
|
||||||
|
## 📦 Deployment & Operations (MEDIUM PRIORITY)
|
||||||
|
|
||||||
|
### Containerization & Orchestration
|
||||||
|
- [ ] Create optimized Docker images
|
||||||
|
- [ ] Implement Kubernetes manifests
|
||||||
|
- [ ] Add Helm charts for deployment
|
||||||
|
- [ ] Create deployment automation scripts
|
||||||
|
- [ ] Implement blue-green deployment strategy
|
||||||
|
|
||||||
|
### Infrastructure as Code
|
||||||
|
- [ ] Create Terraform configurations
|
||||||
|
- [ ] Implement AWS/GCP/Azure resource definitions
|
||||||
|
- [ ] Add infrastructure testing
|
||||||
|
- [ ] Create disaster recovery procedures
|
||||||
|
- [ ] Implement infrastructure monitoring
|
||||||
|
|
||||||
|
### CI/CD Pipeline
|
||||||
|
- [ ] Implement automated testing pipeline
|
||||||
|
- [ ] Add security scanning in CI/CD
|
||||||
|
- [ ] Create automated deployment pipeline
|
||||||
|
- [ ] Implement rollback mechanisms
|
||||||
|
- [ ] Add deployment notifications
|
||||||
|
|
||||||
|
## 📊 Observability & Monitoring (LOW PRIORITY)
|
||||||
|
|
||||||
|
### Advanced Monitoring
|
||||||
|
- [ ] Implement distributed tracing
|
||||||
|
- [ ] Add application performance monitoring (APM)
|
||||||
|
- [ ] Create custom dashboards
|
||||||
|
- [ ] Implement alerting rules
|
||||||
|
- [ ] Add log aggregation and analysis
|
||||||
|
|
||||||
|
### Business Intelligence
|
||||||
|
- [ ] Create usage analytics
|
||||||
|
- [ ] Implement cost tracking
|
||||||
|
- [ ] Add capacity planning metrics
|
||||||
|
- [ ] Create business KPI dashboards
|
||||||
|
- [ ] Implement trend analysis
|
||||||
|
|
||||||
|
## 🔧 Maintenance & Operations (ONGOING)
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
- [ ] Create comprehensive API documentation
|
||||||
|
- [ ] Add deployment guides
|
||||||
|
- [ ] Create troubleshooting runbooks
|
||||||
|
- [ ] Implement architecture decision records (ADRs)
|
||||||
|
- [ ] Add security best practices guide
|
||||||
|
|
||||||
|
### Maintenance Procedures
|
||||||
|
- [ ] Create backup and restore procedures
|
||||||
|
- [ ] Implement log rotation and archival
|
||||||
|
- [ ] Add database maintenance scripts
|
||||||
|
- [ ] Create performance tuning guides
|
||||||
|
- [ ] Implement capacity planning procedures
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📝 Implementation Notes for Future Development
|
||||||
|
|
||||||
|
### Code Organization Principles
|
||||||
|
1. **Maintain Clean Architecture**: Keep clear separation between domain, service, and infrastructure layers
|
||||||
|
2. **Interface-First Design**: Always define interfaces before implementations for better testability
|
||||||
|
3. **Error Handling**: Use the established error framework (`internal/errors`) for consistent error handling
|
||||||
|
4. **Logging**: Use structured logging with zap throughout the application
|
||||||
|
5. **Configuration**: Add new config options to `internal/config/config.go` with proper validation
|
||||||
|
|
||||||
|
### Security Guidelines
|
||||||
|
1. **Input Validation**: Always validate inputs using the validation middleware (`internal/middleware/validation.go`)
|
||||||
|
2. **Token Security**: Use the crypto utilities (`internal/crypto/token.go`) for all token operations
|
||||||
|
3. **Permission Checks**: Always validate permissions using the repository layer before operations
|
||||||
|
4. **Audit Logging**: Log all security-relevant operations with user context
|
||||||
|
5. **Secrets**: Never hardcode secrets; use environment variables or secret management systems
|
||||||
|
|
||||||
|
### Database Guidelines
|
||||||
|
1. **Migrations**: Always create both up and down migrations for schema changes
|
||||||
|
2. **Transactions**: Use database transactions for multi-step operations
|
||||||
|
3. **Indexing**: Add appropriate indexes for query performance
|
||||||
|
4. **Connection Management**: Use the existing connection pool configuration
|
||||||
|
5. **Error Handling**: Wrap database errors with the application error framework
|
||||||
|
|
||||||
|
### Testing Guidelines
|
||||||
|
1. **Test Structure**: Follow the existing test structure in `test/` directory
|
||||||
|
2. **Mock Dependencies**: Use interfaces for easy mocking in tests
|
||||||
|
3. **Test Data**: Use the test helpers for consistent test data creation
|
||||||
|
4. **Integration Tests**: Test against real database instances when possible
|
||||||
|
5. **Coverage**: Aim for >80% test coverage for critical paths
|
||||||
|
|
||||||
|
### Performance Guidelines
|
||||||
|
1. **Metrics**: Use the metrics system (`internal/metrics`) to track performance
|
||||||
|
2. **Caching**: Implement caching at the service layer, not repository layer
|
||||||
|
3. **Database Queries**: Optimize queries and use appropriate indexes
|
||||||
|
4. **Memory Management**: Be mindful of memory allocations in hot paths
|
||||||
|
5. **Concurrency**: Use proper synchronization for shared resources
|
||||||
|
|
||||||
|
### Deployment Guidelines
|
||||||
|
1. **Environment Variables**: Use environment-based configuration for all deployments
|
||||||
|
2. **Health Checks**: Ensure health endpoints are properly configured
|
||||||
|
3. **Graceful Shutdown**: Implement proper shutdown procedures for all services
|
||||||
|
4. **Resource Limits**: Set appropriate CPU and memory limits
|
||||||
|
5. **Monitoring**: Ensure metrics and logging are properly configured
|
||||||
|
|
||||||
|
### Code Quality Standards
|
||||||
|
1. **Go Standards**: Follow standard Go conventions and best practices
|
||||||
|
2. **Documentation**: Document all public APIs and complex business logic
|
||||||
|
3. **Error Messages**: Provide clear, actionable error messages
|
||||||
|
4. **Code Reviews**: Require code reviews for all changes
|
||||||
|
5. **Static Analysis**: Use tools like golangci-lint for code quality
|
||||||
|
|
||||||
|
### Security Best Practices
|
||||||
|
1. **Principle of Least Privilege**: Grant minimum necessary permissions
|
||||||
|
2. **Defense in Depth**: Implement multiple layers of security
|
||||||
|
3. **Regular Updates**: Keep dependencies updated for security patches
|
||||||
|
4. **Secure Defaults**: Use secure configurations by default
|
||||||
|
5. **Security Testing**: Include security testing in the development process
|
||||||
|
|
||||||
|
### Operational Considerations
|
||||||
|
1. **Monitoring**: Implement comprehensive monitoring and alerting
|
||||||
|
2. **Backup Strategy**: Ensure regular backups and test restore procedures
|
||||||
|
3. **Disaster Recovery**: Have documented disaster recovery procedures
|
||||||
|
4. **Capacity Planning**: Monitor resource usage and plan for growth
|
||||||
|
5. **Documentation**: Keep operational documentation up to date
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🎯 Priority Matrix
|
||||||
|
|
||||||
|
### Immediate (Next Sprint)
|
||||||
|
- Complete JWT implementation
|
||||||
|
- Add comprehensive unit tests
|
||||||
|
- Implement caching layer basics
|
||||||
|
|
||||||
|
### Short Term (1-2 Months)
|
||||||
|
- SSO integration
|
||||||
|
- Security hardening features
|
||||||
|
- Performance optimization
|
||||||
|
|
||||||
|
### Medium Term (3-6 Months)
|
||||||
|
- Advanced monitoring and observability
|
||||||
|
- Deployment automation
|
||||||
|
- Compliance features
|
||||||
|
|
||||||
|
### Long Term (6+ Months)
|
||||||
|
- Advanced analytics
|
||||||
|
- Multi-region deployment
|
||||||
|
- Advanced security features
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
*Last Updated: [Current Date]*
|
||||||
|
*Version: 1.0*
|
||||||
173
internal/crypto/token.go
Normal file
173
internal/crypto/token.go
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// TokenLength defines the length of generated tokens in bytes
|
||||||
|
TokenLength = 32
|
||||||
|
// TokenPrefix is prepended to all tokens for identification
|
||||||
|
TokenPrefix = "kms_"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenGenerator provides secure token generation and validation
|
||||||
|
type TokenGenerator struct {
|
||||||
|
hmacKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTokenGenerator creates a new token generator with the provided HMAC key
|
||||||
|
func NewTokenGenerator(hmacKey string) *TokenGenerator {
|
||||||
|
return &TokenGenerator{
|
||||||
|
hmacKey: []byte(hmacKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateSecureToken generates a cryptographically secure random token
|
||||||
|
func (tg *TokenGenerator) GenerateSecureToken() (string, error) {
|
||||||
|
// Generate random bytes
|
||||||
|
tokenBytes := make([]byte, TokenLength)
|
||||||
|
if _, err := rand.Read(tokenBytes); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode to base64 for safe transmission
|
||||||
|
tokenData := base64.URLEncoding.EncodeToString(tokenBytes)
|
||||||
|
|
||||||
|
// Add prefix for identification
|
||||||
|
token := TokenPrefix + tokenData
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashToken creates a secure hash of the token for storage
|
||||||
|
func (tg *TokenGenerator) HashToken(token string) (string, error) {
|
||||||
|
// Use bcrypt for secure password-like hashing
|
||||||
|
hash, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to hash token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(hash), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyToken verifies a token against its stored hash
|
||||||
|
func (tg *TokenGenerator) VerifyToken(token, hash string) bool {
|
||||||
|
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(token))
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateHMACKey generates a new HMAC key for token signing
|
||||||
|
func GenerateHMACKey() (string, error) {
|
||||||
|
key := make([]byte, 32) // 256-bit key
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate HMAC key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hex.EncodeToString(key), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignToken creates an HMAC signature for a token
|
||||||
|
func (tg *TokenGenerator) SignToken(token string, timestamp time.Time) string {
|
||||||
|
h := hmac.New(sha256.New, tg.hmacKey)
|
||||||
|
h.Write([]byte(token))
|
||||||
|
h.Write([]byte(timestamp.Format(time.RFC3339)))
|
||||||
|
|
||||||
|
signature := h.Sum(nil)
|
||||||
|
return hex.EncodeToString(signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyTokenSignature verifies an HMAC signature for a token
|
||||||
|
func (tg *TokenGenerator) VerifyTokenSignature(token, signature string, timestamp time.Time) bool {
|
||||||
|
expectedSignature := tg.SignToken(token, timestamp)
|
||||||
|
return hmac.Equal([]byte(signature), []byte(expectedSignature))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractTokenFromHeader extracts a token from an Authorization header
|
||||||
|
func ExtractTokenFromHeader(authHeader string) string {
|
||||||
|
// Support both "Bearer token" and "token" formats
|
||||||
|
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||||
|
return strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
}
|
||||||
|
return authHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidTokenFormat checks if a token has the expected format
|
||||||
|
func IsValidTokenFormat(token string) bool {
|
||||||
|
if !strings.HasPrefix(token, TokenPrefix) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove prefix and check if remaining part is valid base64
|
||||||
|
tokenData := strings.TrimPrefix(token, TokenPrefix)
|
||||||
|
if len(tokenData) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to decode base64
|
||||||
|
_, err := base64.URLEncoding.DecodeString(tokenData)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenInfo holds information about a token
|
||||||
|
type TokenInfo struct {
|
||||||
|
Token string
|
||||||
|
Hash string
|
||||||
|
Signature string
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateTokenWithInfo generates a complete token with hash and signature
|
||||||
|
func (tg *TokenGenerator) GenerateTokenWithInfo() (*TokenInfo, error) {
|
||||||
|
// Generate the token
|
||||||
|
token, err := tg.GenerateSecureToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash the token for storage
|
||||||
|
hash, err := tg.HashToken(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to hash token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create timestamp and signature
|
||||||
|
now := time.Now()
|
||||||
|
signature := tg.SignToken(token, now)
|
||||||
|
|
||||||
|
return &TokenInfo{
|
||||||
|
Token: token,
|
||||||
|
Hash: hash,
|
||||||
|
Signature: signature,
|
||||||
|
CreatedAt: now,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateTokenInfo validates a complete token with all its components
|
||||||
|
func (tg *TokenGenerator) ValidateTokenInfo(token, hash, signature string, createdAt time.Time) error {
|
||||||
|
// Check token format
|
||||||
|
if !IsValidTokenFormat(token) {
|
||||||
|
return fmt.Errorf("invalid token format")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify token against hash
|
||||||
|
if !tg.VerifyToken(token, hash) {
|
||||||
|
return fmt.Errorf("token verification failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify signature
|
||||||
|
if !tg.VerifyTokenSignature(token, signature, createdAt) {
|
||||||
|
return fmt.Errorf("token signature verification failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
287
internal/errors/errors.go
Normal file
287
internal/errors/errors.go
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
package errors
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrorCode represents different types of errors in the system
|
||||||
|
type ErrorCode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Authentication and Authorization errors
|
||||||
|
ErrUnauthorized ErrorCode = "UNAUTHORIZED"
|
||||||
|
ErrForbidden ErrorCode = "FORBIDDEN"
|
||||||
|
ErrInvalidToken ErrorCode = "INVALID_TOKEN"
|
||||||
|
ErrTokenExpired ErrorCode = "TOKEN_EXPIRED"
|
||||||
|
ErrInvalidCredentials ErrorCode = "INVALID_CREDENTIALS"
|
||||||
|
|
||||||
|
// Validation errors
|
||||||
|
ErrValidationFailed ErrorCode = "VALIDATION_FAILED"
|
||||||
|
ErrInvalidInput ErrorCode = "INVALID_INPUT"
|
||||||
|
ErrMissingField ErrorCode = "MISSING_FIELD"
|
||||||
|
ErrInvalidFormat ErrorCode = "INVALID_FORMAT"
|
||||||
|
|
||||||
|
// Resource errors
|
||||||
|
ErrNotFound ErrorCode = "NOT_FOUND"
|
||||||
|
ErrAlreadyExists ErrorCode = "ALREADY_EXISTS"
|
||||||
|
ErrConflict ErrorCode = "CONFLICT"
|
||||||
|
|
||||||
|
// System errors
|
||||||
|
ErrInternal ErrorCode = "INTERNAL_ERROR"
|
||||||
|
ErrDatabase ErrorCode = "DATABASE_ERROR"
|
||||||
|
ErrExternal ErrorCode = "EXTERNAL_SERVICE_ERROR"
|
||||||
|
ErrTimeout ErrorCode = "TIMEOUT"
|
||||||
|
ErrRateLimit ErrorCode = "RATE_LIMIT_EXCEEDED"
|
||||||
|
|
||||||
|
// Business logic errors
|
||||||
|
ErrInsufficientPermissions ErrorCode = "INSUFFICIENT_PERMISSIONS"
|
||||||
|
ErrApplicationNotFound ErrorCode = "APPLICATION_NOT_FOUND"
|
||||||
|
ErrTokenNotFound ErrorCode = "TOKEN_NOT_FOUND"
|
||||||
|
ErrPermissionNotFound ErrorCode = "PERMISSION_NOT_FOUND"
|
||||||
|
ErrInvalidApplication ErrorCode = "INVALID_APPLICATION"
|
||||||
|
ErrTokenCreationFailed ErrorCode = "TOKEN_CREATION_FAILED"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AppError represents an application error with context
|
||||||
|
type AppError struct {
|
||||||
|
Code ErrorCode `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Details string `json:"details,omitempty"`
|
||||||
|
StatusCode int `json:"-"`
|
||||||
|
Internal error `json:"-"`
|
||||||
|
Context map[string]interface{} `json:"context,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements the error interface
|
||||||
|
func (e *AppError) Error() string {
|
||||||
|
if e.Internal != nil {
|
||||||
|
return fmt.Sprintf("%s: %s (internal: %v)", e.Code, e.Message, e.Internal)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContext adds context information to the error
|
||||||
|
func (e *AppError) WithContext(key string, value interface{}) *AppError {
|
||||||
|
if e.Context == nil {
|
||||||
|
e.Context = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
e.Context[key] = value
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDetails adds additional details to the error
|
||||||
|
func (e *AppError) WithDetails(details string) *AppError {
|
||||||
|
e.Details = details
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithInternal adds the underlying error
|
||||||
|
func (e *AppError) WithInternal(err error) *AppError {
|
||||||
|
e.Internal = err
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new application error
|
||||||
|
func New(code ErrorCode, message string) *AppError {
|
||||||
|
return &AppError{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
StatusCode: getHTTPStatusCode(code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap wraps an existing error with application error context
|
||||||
|
func Wrap(err error, code ErrorCode, message string) *AppError {
|
||||||
|
return &AppError{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
StatusCode: getHTTPStatusCode(code),
|
||||||
|
Internal: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHTTPStatusCode maps error codes to HTTP status codes
|
||||||
|
func getHTTPStatusCode(code ErrorCode) int {
|
||||||
|
switch code {
|
||||||
|
case ErrUnauthorized, ErrInvalidToken, ErrTokenExpired, ErrInvalidCredentials:
|
||||||
|
return http.StatusUnauthorized
|
||||||
|
case ErrForbidden, ErrInsufficientPermissions:
|
||||||
|
return http.StatusForbidden
|
||||||
|
case ErrValidationFailed, ErrInvalidInput, ErrMissingField, ErrInvalidFormat:
|
||||||
|
return http.StatusBadRequest
|
||||||
|
case ErrNotFound, ErrApplicationNotFound, ErrTokenNotFound, ErrPermissionNotFound:
|
||||||
|
return http.StatusNotFound
|
||||||
|
case ErrAlreadyExists, ErrConflict:
|
||||||
|
return http.StatusConflict
|
||||||
|
case ErrRateLimit:
|
||||||
|
return http.StatusTooManyRequests
|
||||||
|
case ErrTimeout:
|
||||||
|
return http.StatusRequestTimeout
|
||||||
|
case ErrInternal, ErrDatabase, ErrExternal, ErrTokenCreationFailed:
|
||||||
|
return http.StatusInternalServerError
|
||||||
|
default:
|
||||||
|
return http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRetryable determines if an error is retryable
|
||||||
|
func (e *AppError) IsRetryable() bool {
|
||||||
|
switch e.Code {
|
||||||
|
case ErrTimeout, ErrExternal, ErrDatabase:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClientError determines if an error is a client error (4xx)
|
||||||
|
func (e *AppError) IsClientError() bool {
|
||||||
|
return e.StatusCode >= 400 && e.StatusCode < 500
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsServerError determines if an error is a server error (5xx)
|
||||||
|
func (e *AppError) IsServerError() bool {
|
||||||
|
return e.StatusCode >= 500
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common error constructors for frequently used errors
|
||||||
|
|
||||||
|
// NewUnauthorizedError creates an unauthorized error
|
||||||
|
func NewUnauthorizedError(message string) *AppError {
|
||||||
|
return New(ErrUnauthorized, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewForbiddenError creates a forbidden error
|
||||||
|
func NewForbiddenError(message string) *AppError {
|
||||||
|
return New(ErrForbidden, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewValidationError creates a validation error
|
||||||
|
func NewValidationError(message string) *AppError {
|
||||||
|
return New(ErrValidationFailed, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNotFoundError creates a not found error
|
||||||
|
func NewNotFoundError(resource string) *AppError {
|
||||||
|
return New(ErrNotFound, fmt.Sprintf("%s not found", resource))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAlreadyExistsError creates an already exists error
|
||||||
|
func NewAlreadyExistsError(resource string) *AppError {
|
||||||
|
return New(ErrAlreadyExists, fmt.Sprintf("%s already exists", resource))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInternalError creates an internal server error
|
||||||
|
func NewInternalError(message string) *AppError {
|
||||||
|
return New(ErrInternal, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabaseError creates a database error
|
||||||
|
func NewDatabaseError(operation string, err error) *AppError {
|
||||||
|
return Wrap(err, ErrDatabase, fmt.Sprintf("Database operation failed: %s", operation))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTokenError creates a token-related error
|
||||||
|
func NewTokenError(message string) *AppError {
|
||||||
|
return New(ErrInvalidToken, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewApplicationError creates an application-related error
|
||||||
|
func NewApplicationError(message string) *AppError {
|
||||||
|
return New(ErrInvalidApplication, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPermissionError creates a permission-related error
|
||||||
|
func NewPermissionError(message string) *AppError {
|
||||||
|
return New(ErrInsufficientPermissions, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorResponse represents the JSON error response format
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Code ErrorCode `json:"code"`
|
||||||
|
Details string `json:"details,omitempty"`
|
||||||
|
Context map[string]interface{} `json:"context,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToResponse converts an AppError to an ErrorResponse
|
||||||
|
func (e *AppError) ToResponse() ErrorResponse {
|
||||||
|
return ErrorResponse{
|
||||||
|
Error: string(e.Code),
|
||||||
|
Message: e.Message,
|
||||||
|
Code: e.Code,
|
||||||
|
Details: e.Details,
|
||||||
|
Context: e.Context,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recovery handles panic recovery and converts to appropriate errors
|
||||||
|
func Recovery(recovered interface{}) *AppError {
|
||||||
|
switch v := recovered.(type) {
|
||||||
|
case *AppError:
|
||||||
|
return v
|
||||||
|
case error:
|
||||||
|
return Wrap(v, ErrInternal, "Internal server error occurred")
|
||||||
|
case string:
|
||||||
|
return New(ErrInternal, v)
|
||||||
|
default:
|
||||||
|
return New(ErrInternal, "Unknown internal error occurred")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chain represents a chain of errors for better error tracking
|
||||||
|
type Chain struct {
|
||||||
|
errors []*AppError
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChain creates a new error chain
|
||||||
|
func NewChain() *Chain {
|
||||||
|
return &Chain{
|
||||||
|
errors: make([]*AppError, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds an error to the chain
|
||||||
|
func (c *Chain) Add(err *AppError) *Chain {
|
||||||
|
c.errors = append(c.errors, err)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasErrors returns true if the chain has any errors
|
||||||
|
func (c *Chain) HasErrors() bool {
|
||||||
|
return len(c.errors) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// First returns the first error in the chain
|
||||||
|
func (c *Chain) First() *AppError {
|
||||||
|
if len(c.errors) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.errors[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last returns the last error in the chain
|
||||||
|
func (c *Chain) Last() *AppError {
|
||||||
|
if len(c.errors) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.errors[len(c.errors)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// All returns all errors in the chain
|
||||||
|
func (c *Chain) All() []*AppError {
|
||||||
|
return c.errors
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements the error interface for the chain
|
||||||
|
func (c *Chain) Error() string {
|
||||||
|
if len(c.errors) == 0 {
|
||||||
|
return "no errors"
|
||||||
|
}
|
||||||
|
if len(c.errors) == 1 {
|
||||||
|
return c.errors[0].Error()
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("multiple errors: %s (and %d more)", c.errors[0].Error(), len(c.errors)-1)
|
||||||
|
}
|
||||||
415
internal/metrics/metrics.go
Normal file
415
internal/metrics/metrics.go
Normal file
@ -0,0 +1,415 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Metrics holds all application metrics
|
||||||
|
type Metrics struct {
|
||||||
|
// HTTP metrics
|
||||||
|
RequestsTotal *Counter
|
||||||
|
RequestDuration *Histogram
|
||||||
|
RequestsInFlight *Gauge
|
||||||
|
ResponseSize *Histogram
|
||||||
|
|
||||||
|
// Business metrics
|
||||||
|
TokensCreated *Counter
|
||||||
|
TokensVerified *Counter
|
||||||
|
TokensRevoked *Counter
|
||||||
|
ApplicationsTotal *Gauge
|
||||||
|
PermissionsTotal *Gauge
|
||||||
|
|
||||||
|
// System metrics
|
||||||
|
DatabaseConnections *Gauge
|
||||||
|
DatabaseQueries *Counter
|
||||||
|
DatabaseErrors *Counter
|
||||||
|
CacheHits *Counter
|
||||||
|
CacheMisses *Counter
|
||||||
|
|
||||||
|
// Error metrics
|
||||||
|
ErrorsTotal *Counter
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Counter represents a monotonically increasing counter
|
||||||
|
type Counter struct {
|
||||||
|
value float64
|
||||||
|
labels map[string]string
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gauge represents a value that can go up and down
|
||||||
|
type Gauge struct {
|
||||||
|
value float64
|
||||||
|
labels map[string]string
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Histogram represents a distribution of values
|
||||||
|
type Histogram struct {
|
||||||
|
buckets map[float64]float64
|
||||||
|
sum float64
|
||||||
|
count float64
|
||||||
|
labels map[string]string
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMetrics creates a new metrics instance
|
||||||
|
func NewMetrics() *Metrics {
|
||||||
|
return &Metrics{
|
||||||
|
// HTTP metrics
|
||||||
|
RequestsTotal: NewCounter("http_requests_total", map[string]string{}),
|
||||||
|
RequestDuration: NewHistogram("http_request_duration_seconds", map[string]string{}),
|
||||||
|
RequestsInFlight: NewGauge("http_requests_in_flight", map[string]string{}),
|
||||||
|
ResponseSize: NewHistogram("http_response_size_bytes", map[string]string{}),
|
||||||
|
|
||||||
|
// Business metrics
|
||||||
|
TokensCreated: NewCounter("tokens_created_total", map[string]string{}),
|
||||||
|
TokensVerified: NewCounter("tokens_verified_total", map[string]string{}),
|
||||||
|
TokensRevoked: NewCounter("tokens_revoked_total", map[string]string{}),
|
||||||
|
ApplicationsTotal: NewGauge("applications_total", map[string]string{}),
|
||||||
|
PermissionsTotal: NewGauge("permissions_total", map[string]string{}),
|
||||||
|
|
||||||
|
// System metrics
|
||||||
|
DatabaseConnections: NewGauge("database_connections", map[string]string{}),
|
||||||
|
DatabaseQueries: NewCounter("database_queries_total", map[string]string{}),
|
||||||
|
DatabaseErrors: NewCounter("database_errors_total", map[string]string{}),
|
||||||
|
CacheHits: NewCounter("cache_hits_total", map[string]string{}),
|
||||||
|
CacheMisses: NewCounter("cache_misses_total", map[string]string{}),
|
||||||
|
|
||||||
|
// Error metrics
|
||||||
|
ErrorsTotal: NewCounter("errors_total", map[string]string{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCounter creates a new counter
|
||||||
|
func NewCounter(name string, labels map[string]string) *Counter {
|
||||||
|
return &Counter{
|
||||||
|
value: 0,
|
||||||
|
labels: labels,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGauge creates a new gauge
|
||||||
|
func NewGauge(name string, labels map[string]string) *Gauge {
|
||||||
|
return &Gauge{
|
||||||
|
value: 0,
|
||||||
|
labels: labels,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHistogram creates a new histogram
|
||||||
|
func NewHistogram(name string, labels map[string]string) *Histogram {
|
||||||
|
return &Histogram{
|
||||||
|
buckets: make(map[float64]float64),
|
||||||
|
sum: 0,
|
||||||
|
count: 0,
|
||||||
|
labels: labels,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Counter methods
|
||||||
|
func (c *Counter) Inc() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.value++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Counter) Add(value float64) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.value += value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Counter) Value() float64 {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
return c.value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gauge methods
|
||||||
|
func (g *Gauge) Set(value float64) {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
g.value = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gauge) Inc() {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
g.value++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gauge) Dec() {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
g.value--
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gauge) Add(value float64) {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
g.value += value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gauge) Value() float64 {
|
||||||
|
g.mu.RLock()
|
||||||
|
defer g.mu.RUnlock()
|
||||||
|
return g.value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Histogram methods
|
||||||
|
func (h *Histogram) Observe(value float64) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
h.sum += value
|
||||||
|
h.count++
|
||||||
|
|
||||||
|
// Define standard buckets
|
||||||
|
buckets := []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}
|
||||||
|
for _, bucket := range buckets {
|
||||||
|
if value <= bucket {
|
||||||
|
h.buckets[bucket]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Histogram) Sum() float64 {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.sum
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Histogram) Count() float64 {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.count
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Histogram) Buckets() map[float64]float64 {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
result := make(map[float64]float64)
|
||||||
|
for k, v := range h.buckets {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Global metrics instance
|
||||||
|
var globalMetrics *Metrics
|
||||||
|
var once sync.Once
|
||||||
|
|
||||||
|
// GetMetrics returns the global metrics instance
|
||||||
|
func GetMetrics() *Metrics {
|
||||||
|
once.Do(func() {
|
||||||
|
globalMetrics = NewMetrics()
|
||||||
|
})
|
||||||
|
return globalMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware creates a Gin middleware for collecting HTTP metrics
|
||||||
|
func Middleware(logger *zap.Logger) gin.HandlerFunc {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
// Increment in-flight requests
|
||||||
|
metrics.RequestsInFlight.Inc()
|
||||||
|
defer metrics.RequestsInFlight.Dec()
|
||||||
|
|
||||||
|
// Process request
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
// Record metrics
|
||||||
|
duration := time.Since(start).Seconds()
|
||||||
|
status := strconv.Itoa(c.Writer.Status())
|
||||||
|
method := c.Request.Method
|
||||||
|
path := c.FullPath()
|
||||||
|
|
||||||
|
// Increment total requests
|
||||||
|
metrics.RequestsTotal.Add(1)
|
||||||
|
|
||||||
|
// Record request duration
|
||||||
|
metrics.RequestDuration.Observe(duration)
|
||||||
|
|
||||||
|
// Record response size
|
||||||
|
metrics.ResponseSize.Observe(float64(c.Writer.Size()))
|
||||||
|
|
||||||
|
// Record errors
|
||||||
|
if c.Writer.Status() >= 400 {
|
||||||
|
metrics.ErrorsTotal.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log metrics
|
||||||
|
logger.Debug("HTTP request metrics",
|
||||||
|
zap.String("method", method),
|
||||||
|
zap.String("path", path),
|
||||||
|
zap.String("status", status),
|
||||||
|
zap.Float64("duration", duration),
|
||||||
|
zap.Int("size", c.Writer.Size()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordTokenCreation records a token creation event
|
||||||
|
func RecordTokenCreation(tokenType string) {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
metrics.TokensCreated.Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordTokenVerification records a token verification event
|
||||||
|
func RecordTokenVerification(tokenType string, success bool) {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
metrics.TokensVerified.Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordTokenRevocation records a token revocation event
|
||||||
|
func RecordTokenRevocation(tokenType string) {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
metrics.TokensRevoked.Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordDatabaseQuery records a database query
|
||||||
|
func RecordDatabaseQuery(operation string, success bool) {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
metrics.DatabaseQueries.Inc()
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
metrics.DatabaseErrors.Inc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordCacheHit records a cache hit
|
||||||
|
func RecordCacheHit() {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
metrics.CacheHits.Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordCacheMiss records a cache miss
|
||||||
|
func RecordCacheMiss() {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
metrics.CacheMisses.Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateApplicationCount updates the total number of applications
|
||||||
|
func UpdateApplicationCount(count int) {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
metrics.ApplicationsTotal.Set(float64(count))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePermissionCount updates the total number of permissions
|
||||||
|
func UpdatePermissionCount(count int) {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
metrics.PermissionsTotal.Set(float64(count))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDatabaseConnections updates the number of database connections
|
||||||
|
func UpdateDatabaseConnections(count int) {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
metrics.DatabaseConnections.Set(float64(count))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrometheusHandler returns an HTTP handler that exports metrics in Prometheus format
|
||||||
|
func PrometheusHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
|
||||||
|
|
||||||
|
// Export all metrics in Prometheus format
|
||||||
|
exportCounter(w, "http_requests_total", metrics.RequestsTotal)
|
||||||
|
exportGauge(w, "http_requests_in_flight", metrics.RequestsInFlight)
|
||||||
|
exportHistogram(w, "http_request_duration_seconds", metrics.RequestDuration)
|
||||||
|
exportHistogram(w, "http_response_size_bytes", metrics.ResponseSize)
|
||||||
|
|
||||||
|
exportCounter(w, "tokens_created_total", metrics.TokensCreated)
|
||||||
|
exportCounter(w, "tokens_verified_total", metrics.TokensVerified)
|
||||||
|
exportCounter(w, "tokens_revoked_total", metrics.TokensRevoked)
|
||||||
|
exportGauge(w, "applications_total", metrics.ApplicationsTotal)
|
||||||
|
exportGauge(w, "permissions_total", metrics.PermissionsTotal)
|
||||||
|
|
||||||
|
exportGauge(w, "database_connections", metrics.DatabaseConnections)
|
||||||
|
exportCounter(w, "database_queries_total", metrics.DatabaseQueries)
|
||||||
|
exportCounter(w, "database_errors_total", metrics.DatabaseErrors)
|
||||||
|
exportCounter(w, "cache_hits_total", metrics.CacheHits)
|
||||||
|
exportCounter(w, "cache_misses_total", metrics.CacheMisses)
|
||||||
|
|
||||||
|
exportCounter(w, "errors_total", metrics.ErrorsTotal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func exportCounter(w http.ResponseWriter, name string, counter *Counter) {
|
||||||
|
w.Write([]byte("# HELP " + name + " Total number of " + name + "\n"))
|
||||||
|
w.Write([]byte("# TYPE " + name + " counter\n"))
|
||||||
|
w.Write([]byte(name + " " + strconv.FormatFloat(counter.Value(), 'f', -1, 64) + "\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func exportGauge(w http.ResponseWriter, name string, gauge *Gauge) {
|
||||||
|
w.Write([]byte("# HELP " + name + " Current value of " + name + "\n"))
|
||||||
|
w.Write([]byte("# TYPE " + name + " gauge\n"))
|
||||||
|
w.Write([]byte(name + " " + strconv.FormatFloat(gauge.Value(), 'f', -1, 64) + "\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func exportHistogram(w http.ResponseWriter, name string, histogram *Histogram) {
|
||||||
|
w.Write([]byte("# HELP " + name + " Histogram of " + name + "\n"))
|
||||||
|
w.Write([]byte("# TYPE " + name + " histogram\n"))
|
||||||
|
|
||||||
|
buckets := histogram.Buckets()
|
||||||
|
for bucket, count := range buckets {
|
||||||
|
w.Write([]byte(name + "_bucket{le=\"" + strconv.FormatFloat(bucket, 'f', -1, 64) + "\"} " + strconv.FormatFloat(count, 'f', -1, 64) + "\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Write([]byte(name + "_sum " + strconv.FormatFloat(histogram.Sum(), 'f', -1, 64) + "\n"))
|
||||||
|
w.Write([]byte(name + "_count " + strconv.FormatFloat(histogram.Count(), 'f', -1, 64) + "\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthMetrics represents health check metrics
|
||||||
|
type HealthMetrics struct {
|
||||||
|
DatabaseConnected bool `json:"database_connected"`
|
||||||
|
ResponseTime time.Duration `json:"response_time"`
|
||||||
|
Uptime time.Duration `json:"uptime"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
Environment string `json:"environment"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHealthMetrics returns current health metrics
|
||||||
|
func GetHealthMetrics(ctx context.Context, version, environment string, startTime time.Time) *HealthMetrics {
|
||||||
|
return &HealthMetrics{
|
||||||
|
DatabaseConnected: true, // This should be checked against actual DB
|
||||||
|
ResponseTime: time.Since(time.Now()),
|
||||||
|
Uptime: time.Since(startTime),
|
||||||
|
Version: version,
|
||||||
|
Environment: environment,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BusinessMetrics represents business-specific metrics
|
||||||
|
type BusinessMetrics struct {
|
||||||
|
TotalApplications int `json:"total_applications"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
TotalPermissions int `json:"total_permissions"`
|
||||||
|
ActiveTokens int `json:"active_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBusinessMetrics returns current business metrics
|
||||||
|
func GetBusinessMetrics() *BusinessMetrics {
|
||||||
|
metrics := GetMetrics()
|
||||||
|
|
||||||
|
return &BusinessMetrics{
|
||||||
|
TotalApplications: int(metrics.ApplicationsTotal.Value()),
|
||||||
|
TotalTokens: int(metrics.TokensCreated.Value()),
|
||||||
|
TotalPermissions: int(metrics.PermissionsTotal.Value()),
|
||||||
|
ActiveTokens: int(metrics.TokensCreated.Value() - metrics.TokensRevoked.Value()),
|
||||||
|
}
|
||||||
|
}
|
||||||
265
internal/middleware/validation.go
Normal file
265
internal/middleware/validation.go
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-playground/validator/v10"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidationError represents a validation error
|
||||||
|
type ValidationError struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Tag string `json:"tag"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidationResponse represents the validation error response
|
||||||
|
type ValidationResponse struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Details []ValidationError `json:"details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var validate *validator.Validate
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
validate = validator.New()
|
||||||
|
|
||||||
|
// Register custom tag name function to use json tags
|
||||||
|
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||||
|
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
|
||||||
|
if name == "-" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateJSON validates JSON request body against struct validation tags
|
||||||
|
func ValidateJSON(logger *zap.Logger) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// Skip validation for GET requests and requests without body
|
||||||
|
if c.Request.Method == "GET" || c.Request.ContentLength == 0 {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store original body for potential re-reading
|
||||||
|
c.Set("validation_enabled", true)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateStruct validates a struct and returns formatted errors
|
||||||
|
func ValidateStruct(s interface{}) []ValidationError {
|
||||||
|
var errors []ValidationError
|
||||||
|
|
||||||
|
err := validate.Struct(s)
|
||||||
|
if err != nil {
|
||||||
|
for _, err := range err.(validator.ValidationErrors) {
|
||||||
|
var element ValidationError
|
||||||
|
element.Field = err.Field()
|
||||||
|
element.Tag = err.Tag()
|
||||||
|
element.Value = err.Param()
|
||||||
|
element.Message = getErrorMessage(err)
|
||||||
|
errors = append(errors, element)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateAndBind validates and binds JSON request to struct
|
||||||
|
func ValidateAndBind(c *gin.Context, obj interface{}) error {
|
||||||
|
// Bind JSON to struct
|
||||||
|
if err := c.ShouldBindJSON(obj); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||||
|
Error: "Invalid JSON",
|
||||||
|
Message: "Request body contains invalid JSON: " + err.Error(),
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate struct
|
||||||
|
if validationErrors := ValidateStruct(obj); len(validationErrors) > 0 {
|
||||||
|
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||||
|
Error: "Validation Failed",
|
||||||
|
Message: "Request validation failed",
|
||||||
|
Details: validationErrors,
|
||||||
|
})
|
||||||
|
return validator.ValidationErrors{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getErrorMessage returns a human-readable error message for validation errors
|
||||||
|
func getErrorMessage(fe validator.FieldError) string {
|
||||||
|
switch fe.Tag() {
|
||||||
|
case "required":
|
||||||
|
return "This field is required"
|
||||||
|
case "email":
|
||||||
|
return "Invalid email format"
|
||||||
|
case "min":
|
||||||
|
return "Value is too short (minimum " + fe.Param() + " characters)"
|
||||||
|
case "max":
|
||||||
|
return "Value is too long (maximum " + fe.Param() + " characters)"
|
||||||
|
case "url":
|
||||||
|
return "Invalid URL format"
|
||||||
|
case "oneof":
|
||||||
|
return "Value must be one of: " + fe.Param()
|
||||||
|
case "uuid":
|
||||||
|
return "Invalid UUID format"
|
||||||
|
case "gte":
|
||||||
|
return "Value must be greater than or equal to " + fe.Param()
|
||||||
|
case "lte":
|
||||||
|
return "Value must be less than or equal to " + fe.Param()
|
||||||
|
case "len":
|
||||||
|
return "Value must be exactly " + fe.Param() + " characters"
|
||||||
|
case "dive":
|
||||||
|
return "Invalid array element"
|
||||||
|
default:
|
||||||
|
return "Invalid value for " + fe.Field()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequiredFields validates that specific fields are present in the request
|
||||||
|
func RequiredFields(fields ...string) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var json map[string]interface{}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&json); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||||
|
Error: "Invalid JSON",
|
||||||
|
Message: "Request body contains invalid JSON",
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var missingFields []string
|
||||||
|
for _, field := range fields {
|
||||||
|
if _, exists := json[field]; !exists {
|
||||||
|
missingFields = append(missingFields, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(missingFields) > 0 {
|
||||||
|
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||||
|
Error: "Missing Required Fields",
|
||||||
|
Message: "The following required fields are missing: " + strings.Join(missingFields, ", "),
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the parsed JSON for use in handlers
|
||||||
|
c.Set("parsed_json", json)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateUUID validates that a URL parameter is a valid UUID
|
||||||
|
func ValidateUUID(param string) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
value := c.Param(param)
|
||||||
|
if value == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||||
|
Error: "Missing Parameter",
|
||||||
|
Message: "Required parameter '" + param + "' is missing",
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate UUID format
|
||||||
|
if err := validate.Var(value, "uuid"); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||||
|
Error: "Invalid Parameter",
|
||||||
|
Message: "Parameter '" + param + "' must be a valid UUID",
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateQueryParams validates query parameters
|
||||||
|
func ValidateQueryParams(rules map[string]string) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var errors []ValidationError
|
||||||
|
|
||||||
|
for param, rule := range rules {
|
||||||
|
value := c.Query(param)
|
||||||
|
if value != "" {
|
||||||
|
if err := validate.Var(value, rule); err != nil {
|
||||||
|
for _, err := range err.(validator.ValidationErrors) {
|
||||||
|
errors = append(errors, ValidationError{
|
||||||
|
Field: param,
|
||||||
|
Tag: err.Tag(),
|
||||||
|
Value: err.Param(),
|
||||||
|
Message: getErrorMessage(err),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
c.JSON(http.StatusBadRequest, ValidationResponse{
|
||||||
|
Error: "Invalid Query Parameters",
|
||||||
|
Message: "One or more query parameters are invalid",
|
||||||
|
Details: errors,
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeInput sanitizes input strings to prevent XSS and injection attacks
|
||||||
|
func SanitizeInput() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// This is a basic implementation - in production you might want to use
|
||||||
|
// a more sophisticated sanitization library like bluemonday
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidatePermissions validates that permission scopes follow the expected format
|
||||||
|
func ValidatePermissions(c *gin.Context, permissions []string) []ValidationError {
|
||||||
|
var errors []ValidationError
|
||||||
|
|
||||||
|
for i, perm := range permissions {
|
||||||
|
// Check basic format: should contain only alphanumeric, dots, and underscores
|
||||||
|
if err := validate.Var(perm, "required,min=1,max=255,alphanum|contains=.|contains=_"); err != nil {
|
||||||
|
errors = append(errors, ValidationError{
|
||||||
|
Field: "permissions[" + string(rune(i)) + "]",
|
||||||
|
Tag: "format",
|
||||||
|
Value: perm,
|
||||||
|
Message: "Permission scope must contain only alphanumeric characters, dots, and underscores",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for dangerous patterns
|
||||||
|
if strings.Contains(perm, "..") || strings.HasPrefix(perm, ".") || strings.HasSuffix(perm, ".") {
|
||||||
|
errors = append(errors, ValidationError{
|
||||||
|
Field: "permissions[" + string(rune(i)) + "]",
|
||||||
|
Tag: "format",
|
||||||
|
Value: perm,
|
||||||
|
Message: "Permission scope has invalid format",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors
|
||||||
|
}
|
||||||
@ -2,10 +2,14 @@ package postgres
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/kms/api-key-service/internal/domain"
|
"github.com/kms/api-key-service/internal/domain"
|
||||||
"github.com/kms/api-key-service/internal/repository"
|
"github.com/kms/api-key-service/internal/repository"
|
||||||
|
"github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PermissionRepository implements the PermissionRepository interface for PostgreSQL
|
// PermissionRepository implements the PermissionRepository interface for PostgreSQL
|
||||||
@ -20,20 +24,116 @@ func NewPermissionRepository(db repository.DatabaseProvider) repository.Permissi
|
|||||||
|
|
||||||
// CreateAvailablePermission creates a new available permission
|
// CreateAvailablePermission creates a new available permission
|
||||||
func (r *PermissionRepository) CreateAvailablePermission(ctx context.Context, permission *domain.AvailablePermission) error {
|
func (r *PermissionRepository) CreateAvailablePermission(ctx context.Context, permission *domain.AvailablePermission) error {
|
||||||
// TODO: Implement actual permission creation
|
query := `
|
||||||
|
INSERT INTO available_permissions (
|
||||||
|
id, scope, name, description, category, parent_scope,
|
||||||
|
is_system, created_by, updated_by, created_at, updated_at
|
||||||
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||||
|
`
|
||||||
|
|
||||||
|
db := r.db.GetDB().(*sql.DB)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if permission.ID == uuid.Nil {
|
||||||
|
permission.ID = uuid.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.ExecContext(ctx, query,
|
||||||
|
permission.ID,
|
||||||
|
permission.Scope,
|
||||||
|
permission.Name,
|
||||||
|
permission.Description,
|
||||||
|
permission.Category,
|
||||||
|
permission.ParentScope,
|
||||||
|
permission.IsSystem,
|
||||||
|
permission.CreatedBy,
|
||||||
|
permission.UpdatedBy,
|
||||||
|
now,
|
||||||
|
now,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create available permission: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
permission.CreatedAt = now
|
||||||
|
permission.UpdatedAt = now
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAvailablePermission retrieves an available permission by ID
|
// GetAvailablePermission retrieves an available permission by ID
|
||||||
func (r *PermissionRepository) GetAvailablePermission(ctx context.Context, permissionID uuid.UUID) (*domain.AvailablePermission, error) {
|
func (r *PermissionRepository) GetAvailablePermission(ctx context.Context, permissionID uuid.UUID) (*domain.AvailablePermission, error) {
|
||||||
// TODO: Implement actual permission retrieval
|
query := `
|
||||||
return nil, nil
|
SELECT id, scope, name, description, category, parent_scope,
|
||||||
|
is_system, created_at, created_by, updated_at, updated_by
|
||||||
|
FROM available_permissions
|
||||||
|
WHERE id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
db := r.db.GetDB().(*sql.DB)
|
||||||
|
row := db.QueryRowContext(ctx, query, permissionID)
|
||||||
|
|
||||||
|
permission := &domain.AvailablePermission{}
|
||||||
|
err := row.Scan(
|
||||||
|
&permission.ID,
|
||||||
|
&permission.Scope,
|
||||||
|
&permission.Name,
|
||||||
|
&permission.Description,
|
||||||
|
&permission.Category,
|
||||||
|
&permission.ParentScope,
|
||||||
|
&permission.IsSystem,
|
||||||
|
&permission.CreatedAt,
|
||||||
|
&permission.CreatedBy,
|
||||||
|
&permission.UpdatedAt,
|
||||||
|
&permission.UpdatedBy,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("permission with ID '%s' not found", permissionID)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get available permission: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return permission, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAvailablePermissionByScope retrieves an available permission by scope
|
// GetAvailablePermissionByScope retrieves an available permission by scope
|
||||||
func (r *PermissionRepository) GetAvailablePermissionByScope(ctx context.Context, scope string) (*domain.AvailablePermission, error) {
|
func (r *PermissionRepository) GetAvailablePermissionByScope(ctx context.Context, scope string) (*domain.AvailablePermission, error) {
|
||||||
// TODO: Implement actual permission retrieval by scope
|
query := `
|
||||||
return nil, nil
|
SELECT id, scope, name, description, category, parent_scope,
|
||||||
|
is_system, created_at, created_by, updated_at, updated_by
|
||||||
|
FROM available_permissions
|
||||||
|
WHERE scope = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
db := r.db.GetDB().(*sql.DB)
|
||||||
|
row := db.QueryRowContext(ctx, query, scope)
|
||||||
|
|
||||||
|
permission := &domain.AvailablePermission{}
|
||||||
|
err := row.Scan(
|
||||||
|
&permission.ID,
|
||||||
|
&permission.Scope,
|
||||||
|
&permission.Name,
|
||||||
|
&permission.Description,
|
||||||
|
&permission.Category,
|
||||||
|
&permission.ParentScope,
|
||||||
|
&permission.IsSystem,
|
||||||
|
&permission.CreatedAt,
|
||||||
|
&permission.CreatedBy,
|
||||||
|
&permission.UpdatedAt,
|
||||||
|
&permission.UpdatedBy,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("permission with scope '%s' not found", scope)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get available permission by scope: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return permission, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAvailablePermissions retrieves available permissions with pagination and filtering
|
// ListAvailablePermissions retrieves available permissions with pagination and filtering
|
||||||
@ -56,9 +156,44 @@ func (r *PermissionRepository) DeleteAvailablePermission(ctx context.Context, pe
|
|||||||
|
|
||||||
// ValidatePermissionScopes checks if all given scopes exist and are valid
|
// ValidatePermissionScopes checks if all given scopes exist and are valid
|
||||||
func (r *PermissionRepository) ValidatePermissionScopes(ctx context.Context, scopes []string) ([]string, error) {
|
func (r *PermissionRepository) ValidatePermissionScopes(ctx context.Context, scopes []string) ([]string, error) {
|
||||||
// TODO: Implement actual scope validation
|
if len(scopes) == 0 {
|
||||||
// For now, assume all scopes are valid
|
return []string{}, nil
|
||||||
return []string{}, nil
|
}
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT scope
|
||||||
|
FROM available_permissions
|
||||||
|
WHERE scope = ANY($1)
|
||||||
|
`
|
||||||
|
|
||||||
|
db := r.db.GetDB().(*sql.DB)
|
||||||
|
rows, err := db.QueryContext(ctx, query, pq.Array(scopes))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to validate permission scopes: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
validScopes := make(map[string]bool)
|
||||||
|
for rows.Next() {
|
||||||
|
var scope string
|
||||||
|
if err := rows.Scan(&scope); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan scope: %w", err)
|
||||||
|
}
|
||||||
|
validScopes[scope] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating scopes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
for _, scope := range scopes {
|
||||||
|
if validScopes[scope] {
|
||||||
|
result = append(result, scope)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPermissionHierarchy returns all parent and child permissions for given scopes
|
// GetPermissionHierarchy returns all parent and child permissions for given scopes
|
||||||
@ -79,7 +214,56 @@ func NewGrantedPermissionRepository(db repository.DatabaseProvider) repository.G
|
|||||||
|
|
||||||
// GrantPermissions grants multiple permissions to a token
|
// GrantPermissions grants multiple permissions to a token
|
||||||
func (r *GrantedPermissionRepository) GrantPermissions(ctx context.Context, grants []*domain.GrantedPermission) error {
|
func (r *GrantedPermissionRepository) GrantPermissions(ctx context.Context, grants []*domain.GrantedPermission) error {
|
||||||
// TODO: Implement actual permission granting
|
if len(grants) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
db := r.db.GetDB().(*sql.DB)
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
query := `
|
||||||
|
INSERT INTO granted_permissions (
|
||||||
|
id, token_type, token_id, permission_id, scope, created_by, created_at
|
||||||
|
) VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
|
ON CONFLICT (token_type, token_id, permission_id) DO NOTHING
|
||||||
|
`
|
||||||
|
|
||||||
|
stmt, err := tx.PrepareContext(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to prepare statement: %w", err)
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for _, grant := range grants {
|
||||||
|
if grant.ID == uuid.Nil {
|
||||||
|
grant.ID = uuid.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = stmt.ExecContext(ctx,
|
||||||
|
grant.ID,
|
||||||
|
string(grant.TokenType),
|
||||||
|
grant.TokenID,
|
||||||
|
grant.PermissionID,
|
||||||
|
grant.Scope,
|
||||||
|
grant.CreatedBy,
|
||||||
|
now,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to grant permission: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
grant.CreatedAt = now
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,16 +293,72 @@ func (r *GrantedPermissionRepository) RevokeAllPermissions(ctx context.Context,
|
|||||||
|
|
||||||
// HasPermission checks if a token has a specific permission
|
// HasPermission checks if a token has a specific permission
|
||||||
func (r *GrantedPermissionRepository) HasPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scope string) (bool, error) {
|
func (r *GrantedPermissionRepository) HasPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scope string) (bool, error) {
|
||||||
// TODO: Implement actual permission checking
|
query := `
|
||||||
|
SELECT 1
|
||||||
|
FROM granted_permissions gp
|
||||||
|
JOIN available_permissions ap ON gp.permission_id = ap.id
|
||||||
|
WHERE gp.token_type = $1
|
||||||
|
AND gp.token_id = $2
|
||||||
|
AND gp.scope = $3
|
||||||
|
AND gp.revoked = false
|
||||||
|
LIMIT 1
|
||||||
|
`
|
||||||
|
|
||||||
|
db := r.db.GetDB().(*sql.DB)
|
||||||
|
var exists int
|
||||||
|
err := db.QueryRowContext(ctx, query, string(tokenType), tokenID, scope).Scan(&exists)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("failed to check permission: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasAnyPermission checks if a token has any of the specified permissions
|
// HasAnyPermission checks if a token has any of the specified permissions
|
||||||
func (r *GrantedPermissionRepository) HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error) {
|
func (r *GrantedPermissionRepository) HasAnyPermission(ctx context.Context, tokenType domain.TokenType, tokenID uuid.UUID, scopes []string) (map[string]bool, error) {
|
||||||
// TODO: Implement actual permission checking
|
if len(scopes) == 0 {
|
||||||
|
return make(map[string]bool), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT gp.scope
|
||||||
|
FROM granted_permissions gp
|
||||||
|
JOIN available_permissions ap ON gp.permission_id = ap.id
|
||||||
|
WHERE gp.token_type = $1
|
||||||
|
AND gp.token_id = $2
|
||||||
|
AND gp.scope = ANY($3)
|
||||||
|
AND gp.revoked = false
|
||||||
|
`
|
||||||
|
|
||||||
|
db := r.db.GetDB().(*sql.DB)
|
||||||
|
rows, err := db.QueryContext(ctx, query, string(tokenType), tokenID, pq.Array(scopes))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check permissions: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
result := make(map[string]bool)
|
result := make(map[string]bool)
|
||||||
|
// Initialize all scopes as false
|
||||||
for _, scope := range scopes {
|
for _, scope := range scopes {
|
||||||
|
result[scope] = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark found permissions as true
|
||||||
|
for rows.Next() {
|
||||||
|
var scope string
|
||||||
|
if err := rows.Scan(&scope); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan permission scope: %w", err)
|
||||||
|
}
|
||||||
result[scope] = true
|
result[scope] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating permission results: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -57,26 +57,196 @@ func (r *StaticTokenRepository) Create(ctx context.Context, token *domain.Static
|
|||||||
|
|
||||||
// GetByID retrieves a static token by its ID
|
// GetByID retrieves a static token by its ID
|
||||||
func (r *StaticTokenRepository) GetByID(ctx context.Context, tokenID uuid.UUID) (*domain.StaticToken, error) {
|
func (r *StaticTokenRepository) GetByID(ctx context.Context, tokenID uuid.UUID) (*domain.StaticToken, error) {
|
||||||
// TODO: Implement actual token retrieval
|
query := `
|
||||||
return nil, nil
|
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||||
|
key_hash, type, created_at, updated_at
|
||||||
|
FROM static_tokens
|
||||||
|
WHERE id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
db := r.db.GetDB().(*sql.DB)
|
||||||
|
row := db.QueryRowContext(ctx, query, tokenID)
|
||||||
|
|
||||||
|
token := &domain.StaticToken{}
|
||||||
|
var ownerType, ownerName, ownerOwner string
|
||||||
|
|
||||||
|
err := row.Scan(
|
||||||
|
&token.ID,
|
||||||
|
&token.AppID,
|
||||||
|
&ownerType,
|
||||||
|
&ownerName,
|
||||||
|
&ownerOwner,
|
||||||
|
&token.KeyHash,
|
||||||
|
&token.Type,
|
||||||
|
&token.CreatedAt,
|
||||||
|
&token.UpdatedAt,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("static token with ID '%s' not found", tokenID)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get static token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token.Owner = domain.Owner{
|
||||||
|
Type: domain.OwnerType(ownerType),
|
||||||
|
Name: ownerName,
|
||||||
|
Owner: ownerOwner,
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByKeyHash retrieves a static token by its key hash
|
// GetByKeyHash retrieves a static token by its key hash
|
||||||
func (r *StaticTokenRepository) GetByKeyHash(ctx context.Context, keyHash string) (*domain.StaticToken, error) {
|
func (r *StaticTokenRepository) GetByKeyHash(ctx context.Context, keyHash string) (*domain.StaticToken, error) {
|
||||||
// TODO: Implement actual token retrieval by hash
|
query := `
|
||||||
return nil, nil
|
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||||
|
key_hash, type, created_at, updated_at
|
||||||
|
FROM static_tokens
|
||||||
|
WHERE key_hash = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
db := r.db.GetDB().(*sql.DB)
|
||||||
|
row := db.QueryRowContext(ctx, query, keyHash)
|
||||||
|
|
||||||
|
token := &domain.StaticToken{}
|
||||||
|
var ownerType, ownerName, ownerOwner string
|
||||||
|
|
||||||
|
err := row.Scan(
|
||||||
|
&token.ID,
|
||||||
|
&token.AppID,
|
||||||
|
&ownerType,
|
||||||
|
&ownerName,
|
||||||
|
&ownerOwner,
|
||||||
|
&token.KeyHash,
|
||||||
|
&token.Type,
|
||||||
|
&token.CreatedAt,
|
||||||
|
&token.UpdatedAt,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("static token with hash not found")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get static token by hash: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token.Owner = domain.Owner{
|
||||||
|
Type: domain.OwnerType(ownerType),
|
||||||
|
Name: ownerName,
|
||||||
|
Owner: ownerOwner,
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByAppID retrieves all static tokens for an application
|
// GetByAppID retrieves all static tokens for an application
|
||||||
func (r *StaticTokenRepository) GetByAppID(ctx context.Context, appID string) ([]*domain.StaticToken, error) {
|
func (r *StaticTokenRepository) GetByAppID(ctx context.Context, appID string) ([]*domain.StaticToken, error) {
|
||||||
// TODO: Implement actual token listing
|
query := `
|
||||||
return []*domain.StaticToken{}, nil
|
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||||
|
key_hash, type, created_at, updated_at
|
||||||
|
FROM static_tokens
|
||||||
|
WHERE app_id = $1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`
|
||||||
|
|
||||||
|
db := r.db.GetDB().(*sql.DB)
|
||||||
|
rows, err := db.QueryContext(ctx, query, appID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query static tokens: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var tokens []*domain.StaticToken
|
||||||
|
for rows.Next() {
|
||||||
|
token := &domain.StaticToken{}
|
||||||
|
var ownerType, ownerName, ownerOwner string
|
||||||
|
|
||||||
|
err := rows.Scan(
|
||||||
|
&token.ID,
|
||||||
|
&token.AppID,
|
||||||
|
&ownerType,
|
||||||
|
&ownerName,
|
||||||
|
&ownerOwner,
|
||||||
|
&token.KeyHash,
|
||||||
|
&token.Type,
|
||||||
|
&token.CreatedAt,
|
||||||
|
&token.UpdatedAt,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan static token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token.Owner = domain.Owner{
|
||||||
|
Type: domain.OwnerType(ownerType),
|
||||||
|
Name: ownerName,
|
||||||
|
Owner: ownerOwner,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating static tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// List retrieves static tokens with pagination
|
// List retrieves static tokens with pagination
|
||||||
func (r *StaticTokenRepository) List(ctx context.Context, limit, offset int) ([]*domain.StaticToken, error) {
|
func (r *StaticTokenRepository) List(ctx context.Context, limit, offset int) ([]*domain.StaticToken, error) {
|
||||||
// TODO: Implement actual token listing
|
query := `
|
||||||
return []*domain.StaticToken{}, nil
|
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
||||||
|
key_hash, type, created_at, updated_at
|
||||||
|
FROM static_tokens
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT $1 OFFSET $2
|
||||||
|
`
|
||||||
|
|
||||||
|
db := r.db.GetDB().(*sql.DB)
|
||||||
|
rows, err := db.QueryContext(ctx, query, limit, offset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query static tokens: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var tokens []*domain.StaticToken
|
||||||
|
for rows.Next() {
|
||||||
|
token := &domain.StaticToken{}
|
||||||
|
var ownerType, ownerName, ownerOwner string
|
||||||
|
|
||||||
|
err := rows.Scan(
|
||||||
|
&token.ID,
|
||||||
|
&token.AppID,
|
||||||
|
&ownerType,
|
||||||
|
&ownerName,
|
||||||
|
&ownerOwner,
|
||||||
|
&token.KeyHash,
|
||||||
|
&token.Type,
|
||||||
|
&token.CreatedAt,
|
||||||
|
&token.UpdatedAt,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan static token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token.Owner = domain.Owner{
|
||||||
|
Type: domain.OwnerType(ownerType),
|
||||||
|
Name: ownerName,
|
||||||
|
Owner: ownerOwner,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating static tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete deletes a static token
|
// Delete deletes a static token
|
||||||
|
|||||||
@ -8,17 +8,19 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
"github.com/kms/api-key-service/internal/crypto"
|
||||||
"github.com/kms/api-key-service/internal/domain"
|
"github.com/kms/api-key-service/internal/domain"
|
||||||
"github.com/kms/api-key-service/internal/repository"
|
"github.com/kms/api-key-service/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
// tokenService implements the TokenService interface
|
// tokenService implements the TokenService interface
|
||||||
type tokenService struct {
|
type tokenService struct {
|
||||||
tokenRepo repository.StaticTokenRepository
|
tokenRepo repository.StaticTokenRepository
|
||||||
appRepo repository.ApplicationRepository
|
appRepo repository.ApplicationRepository
|
||||||
permRepo repository.PermissionRepository
|
permRepo repository.PermissionRepository
|
||||||
grantRepo repository.GrantedPermissionRepository
|
grantRepo repository.GrantedPermissionRepository
|
||||||
logger *zap.Logger
|
tokenGen *crypto.TokenGenerator
|
||||||
|
logger *zap.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTokenService creates a new token service
|
// NewTokenService creates a new token service
|
||||||
@ -27,6 +29,7 @@ func NewTokenService(
|
|||||||
appRepo repository.ApplicationRepository,
|
appRepo repository.ApplicationRepository,
|
||||||
permRepo repository.PermissionRepository,
|
permRepo repository.PermissionRepository,
|
||||||
grantRepo repository.GrantedPermissionRepository,
|
grantRepo repository.GrantedPermissionRepository,
|
||||||
|
hmacKey string,
|
||||||
logger *zap.Logger,
|
logger *zap.Logger,
|
||||||
) TokenService {
|
) TokenService {
|
||||||
return &tokenService{
|
return &tokenService{
|
||||||
@ -34,6 +37,7 @@ func NewTokenService(
|
|||||||
appRepo: appRepo,
|
appRepo: appRepo,
|
||||||
permRepo: permRepo,
|
permRepo: permRepo,
|
||||||
grantRepo: grantRepo,
|
grantRepo: grantRepo,
|
||||||
|
tokenGen: crypto.NewTokenGenerator(hmacKey),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -42,10 +46,33 @@ func NewTokenService(
|
|||||||
func (s *tokenService) CreateStaticToken(ctx context.Context, req *domain.CreateStaticTokenRequest, userID string) (*domain.CreateStaticTokenResponse, error) {
|
func (s *tokenService) CreateStaticToken(ctx context.Context, req *domain.CreateStaticTokenRequest, userID string) (*domain.CreateStaticTokenResponse, error) {
|
||||||
s.logger.Info("Creating static token", zap.String("app_id", req.AppID), zap.String("user_id", userID))
|
s.logger.Info("Creating static token", zap.String("app_id", req.AppID), zap.String("user_id", userID))
|
||||||
|
|
||||||
// TODO: Validate permissions
|
// Validate application exists
|
||||||
// TODO: Validate application exists
|
app, err := s.appRepo.GetByID(ctx, req.AppID)
|
||||||
// TODO: Generate secure token
|
if err != nil {
|
||||||
// TODO: Grant permissions
|
s.logger.Error("Failed to get application", zap.Error(err), zap.String("app_id", req.AppID))
|
||||||
|
return nil, fmt.Errorf("application not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate permissions exist
|
||||||
|
validPermissions, err := s.permRepo.ValidatePermissionScopes(ctx, req.Permissions)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("Failed to validate permissions", zap.Error(err))
|
||||||
|
return nil, fmt.Errorf("failed to validate permissions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validPermissions) != len(req.Permissions) {
|
||||||
|
s.logger.Warn("Some permissions are invalid",
|
||||||
|
zap.Strings("requested", req.Permissions),
|
||||||
|
zap.Strings("valid", validPermissions))
|
||||||
|
return nil, fmt.Errorf("some requested permissions are invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate secure token
|
||||||
|
tokenInfo, err := s.tokenGen.GenerateTokenWithInfo()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("Failed to generate secure token", zap.Error(err))
|
||||||
|
return nil, fmt.Errorf("failed to generate token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
tokenID := uuid.New()
|
tokenID := uuid.New()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@ -54,32 +81,62 @@ func (s *tokenService) CreateStaticToken(ctx context.Context, req *domain.Create
|
|||||||
token := &domain.StaticToken{
|
token := &domain.StaticToken{
|
||||||
ID: tokenID,
|
ID: tokenID,
|
||||||
AppID: req.AppID,
|
AppID: req.AppID,
|
||||||
Owner: domain.Owner{
|
Owner: req.Owner,
|
||||||
Type: domain.OwnerTypeIndividual,
|
KeyHash: tokenInfo.Hash,
|
||||||
Name: userID,
|
Type: "hmac",
|
||||||
Owner: userID,
|
|
||||||
},
|
|
||||||
KeyHash: "placeholder-hash-" + tokenID.String(),
|
|
||||||
Type: "hmac",
|
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save the token to the database
|
// Save the token to the database
|
||||||
err := s.tokenRepo.Create(ctx, token)
|
err = s.tokenRepo.Create(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("Failed to create token in database", zap.Error(err), zap.String("token_id", tokenID.String()))
|
s.logger.Error("Failed to create token in database", zap.Error(err), zap.String("token_id", tokenID.String()))
|
||||||
return nil, fmt.Errorf("failed to create token: %w", err)
|
return nil, fmt.Errorf("failed to create token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Grant permissions to the token
|
||||||
|
var grants []*domain.GrantedPermission
|
||||||
|
for _, permScope := range validPermissions {
|
||||||
|
// Get permission by scope to get the ID
|
||||||
|
perm, err := s.permRepo.GetAvailablePermissionByScope(ctx, permScope)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("Failed to get permission by scope", zap.Error(err), zap.String("scope", permScope))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
grant := &domain.GrantedPermission{
|
||||||
|
ID: uuid.New(),
|
||||||
|
TokenType: domain.TokenTypeStatic,
|
||||||
|
TokenID: tokenID,
|
||||||
|
PermissionID: perm.ID,
|
||||||
|
Scope: permScope,
|
||||||
|
CreatedBy: userID,
|
||||||
|
}
|
||||||
|
grants = append(grants, grant)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(grants) > 0 {
|
||||||
|
err = s.grantRepo.GrantPermissions(ctx, grants)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("Failed to grant permissions", zap.Error(err))
|
||||||
|
// Clean up the token if permission granting fails
|
||||||
|
s.tokenRepo.Delete(ctx, tokenID)
|
||||||
|
return nil, fmt.Errorf("failed to grant permissions: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
response := &domain.CreateStaticTokenResponse{
|
response := &domain.CreateStaticTokenResponse{
|
||||||
ID: tokenID,
|
ID: tokenID,
|
||||||
Token: "static-token-placeholder-" + tokenID.String(),
|
Token: tokenInfo.Token, // Return the actual token only once
|
||||||
Permissions: req.Permissions,
|
Permissions: validPermissions,
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.logger.Info("Static token created successfully", zap.String("token_id", tokenID.String()))
|
s.logger.Info("Static token created successfully",
|
||||||
|
zap.String("token_id", tokenID.String()),
|
||||||
|
zap.String("app_id", app.AppID),
|
||||||
|
zap.Strings("permissions", validPermissions))
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -267,10 +267,12 @@ test_token_endpoints() {
|
|||||||
-H "X-User-Email: $USER_EMAIL" \
|
-H "X-User-Email: $USER_EMAIL" \
|
||||||
-H "X-User-ID: $USER_ID" \
|
-H "X-User-ID: $USER_ID" \
|
||||||
-d '{
|
-d '{
|
||||||
"name": "Test Static Token for Deletion",
|
"owner": {
|
||||||
"description": "A test static token for deletion test",
|
"type": "individual",
|
||||||
"permissions": ["read"],
|
"name": "Test Token Owner",
|
||||||
"expires_at": "2025-12-31T23:59:59Z"
|
"owner": "test-token@example.com"
|
||||||
|
},
|
||||||
|
"permissions": ["repo.read", "repo.write"]
|
||||||
}' 2>/dev/null || echo -e "\n000")
|
}' 2>/dev/null || echo -e "\n000")
|
||||||
|
|
||||||
local token_status_code=$(echo "$token_response" | tail -n1)
|
local token_status_code=$(echo "$token_response" | tail -n1)
|
||||||
@ -282,10 +284,12 @@ test_token_endpoints() {
|
|||||||
-H "X-User-Email: $USER_EMAIL" \
|
-H "X-User-Email: $USER_EMAIL" \
|
||||||
-H "X-User-ID: $USER_ID" \
|
-H "X-User-ID: $USER_ID" \
|
||||||
-d '{
|
-d '{
|
||||||
"name": "Test Static Token",
|
"owner": {
|
||||||
"description": "A test static token",
|
"type": "individual",
|
||||||
"permissions": ["read"],
|
"name": "Test Token Owner",
|
||||||
"expires_at": "2025-12-31T23:59:59Z"
|
"owner": "test-token@example.com"
|
||||||
|
},
|
||||||
|
"permissions": ["repo.read", "repo.write"]
|
||||||
}'
|
}'
|
||||||
|
|
||||||
# Extract token_id from the first response for deletion test
|
# Extract token_id from the first response for deletion test
|
||||||
|
|||||||
@ -96,7 +96,7 @@ func (suite *IntegrationTestSuite) setupServer() {
|
|||||||
|
|
||||||
// Initialize services
|
// Initialize services
|
||||||
appService := services.NewApplicationService(appRepo, logger)
|
appService := services.NewApplicationService(appRepo, logger)
|
||||||
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, logger)
|
tokenService := services.NewTokenService(tokenRepo, appRepo, permRepo, grantRepo, suite.cfg.GetString("INTERNAL_HMAC_KEY"), logger)
|
||||||
authService := services.NewAuthenticationService(suite.cfg, logger)
|
authService := services.NewAuthenticationService(suite.cfg, logger)
|
||||||
|
|
||||||
// Initialize handlers
|
// Initialize handlers
|
||||||
|
|||||||
@ -460,14 +460,14 @@ func (m *MockPermissionRepository) ValidatePermissionScopes(ctx context.Context,
|
|||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
defer m.mu.RUnlock()
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
var invalid []string
|
var valid []string
|
||||||
for _, scope := range scopes {
|
for _, scope := range scopes {
|
||||||
if _, exists := m.scopeIndex[scope]; !exists {
|
if _, exists := m.scopeIndex[scope]; exists {
|
||||||
invalid = append(invalid, scope)
|
valid = append(valid, scope)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return invalid, nil
|
return valid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockPermissionRepository) GetPermissionHierarchy(ctx context.Context, scopes []string) ([]*domain.AvailablePermission, error) {
|
func (m *MockPermissionRepository) GetPermissionHierarchy(ctx context.Context, scopes []string) ([]*domain.AvailablePermission, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user