package postgres

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/lib/pq"
	"github.com/kms/api-key-service/internal/domain"
	"github.com/kms/api-key-service/internal/repository"
)

// applicationSelectColumns lists the columns of a full application row in the
// exact order ApplicationRepository.scanRow expects them. Every query that
// reads whole rows (SELECT and UPDATE ... RETURNING) uses this list so the two
// can never drift apart.
const applicationSelectColumns = `app_id, app_link, type, callback_url, hmac_key, token_prefix,
	token_renewal_duration, max_token_duration, owner_type, owner_name,
	owner_owner, created_at, updated_at`

// ApplicationRepository implements the ApplicationRepository interface for PostgreSQL.
type ApplicationRepository struct {
	db repository.DatabaseProvider
}

// NewApplicationRepository creates a new PostgreSQL application repository.
func NewApplicationRepository(db repository.DatabaseProvider) repository.ApplicationRepository {
	return &ApplicationRepository{db: db}
}

// sqlDB returns the *sql.DB held by the database provider.
//
// GetDB is asserted to *sql.DB here once, with the result checked, so a
// misconfigured provider surfaces as an error rather than a runtime panic.
func (r *ApplicationRepository) sqlDB() (*sql.DB, error) {
	raw := r.db.GetDB()
	db, ok := raw.(*sql.DB)
	if !ok || db == nil {
		return nil, fmt.Errorf("database provider returned %T, want non-nil *sql.DB", raw)
	}
	return db, nil
}

// scanRow decodes one application row, laid out as applicationSelectColumns,
// using scan (the Scan method of a *sql.Row or *sql.Rows). It converts the
// stored representations back into domain types: the text array into
// []domain.ApplicationType, the nanosecond integers into domain.Duration and
// the owner type string into domain.OwnerType.
//
// The scan error is returned unwrapped so callers can test it with
// errors.Is(err, sql.ErrNoRows).
func (r *ApplicationRepository) scanRow(scan func(dest ...interface{}) error) (*domain.Application, error) {
	app := &domain.Application{}
	var typeStrings pq.StringArray
	var tokenRenewalNanos, maxTokenNanos int64
	var ownerType string

	if err := scan(
		&app.AppID,
		&app.AppLink,
		&typeStrings,
		&app.CallbackURL,
		&app.HMACKey,
		&app.TokenPrefix,
		&tokenRenewalNanos,
		&maxTokenNanos,
		&ownerType,
		&app.Owner.Name,
		&app.Owner.Owner,
		&app.CreatedAt,
		&app.UpdatedAt,
	); err != nil {
		return nil, err
	}

	app.Type = make([]domain.ApplicationType, len(typeStrings))
	for i, t := range typeStrings {
		app.Type[i] = domain.ApplicationType(t)
	}
	// Durations are persisted as int64 nanosecond counts.
	app.TokenRenewalDuration = domain.Duration{Duration: time.Duration(tokenRenewalNanos)}
	app.MaxTokenDuration = domain.Duration{Duration: time.Duration(maxTokenNanos)}
	app.Owner.Type = domain.OwnerType(ownerType)

	return app, nil
}

// Create inserts app as a new row in the applications table.
//
// On success app.CreatedAt and app.UpdatedAt are set to the timestamp that was
// stored. If the insert fails with PostgreSQL's unique_violation (SQLSTATE
// 23505), an "already exists" error is returned; any other failure is wrapped.
func (r *ApplicationRepository) Create(ctx context.Context, app *domain.Application) error {
	query := `
		INSERT INTO applications (
			app_id, app_link, type, callback_url, hmac_key, token_prefix,
			token_renewal_duration, max_token_duration, owner_type, owner_name,
			owner_owner, created_at, updated_at
		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
	`

	db, err := r.sqlDB()
	if err != nil {
		return err
	}

	// PostgreSQL timestamps have microsecond resolution. Truncating here keeps
	// the values written back into app identical to what a later read returns.
	now := time.Now().Truncate(time.Microsecond)

	typeStrings := make([]string, len(app.Type))
	for i, t := range app.Type {
		typeStrings[i] = string(t)
	}

	_, err = db.ExecContext(ctx, query,
		app.AppID,
		app.AppLink,
		pq.Array(typeStrings),
		app.CallbackURL,
		app.HMACKey,
		app.TokenPrefix,
		app.TokenRenewalDuration.Duration.Nanoseconds(),
		app.MaxTokenDuration.Duration.Nanoseconds(),
		string(app.Owner.Type),
		app.Owner.Name,
		app.Owner.Owner,
		now,
		now,
	)
	if err != nil {
		if isUniqueViolation(err) {
			return fmt.Errorf("application with ID '%s' already exists", app.AppID)
		}
		return fmt.Errorf("failed to create application: %w", err)
	}

	app.CreatedAt = now
	app.UpdatedAt = now
	return nil
}

// GetByID retrieves the application whose app_id equals appID.
//
// A missing row yields a "not found" error; other failures are wrapped.
func (r *ApplicationRepository) GetByID(ctx context.Context, appID string) (*domain.Application, error) {
	db, err := r.sqlDB()
	if err != nil {
		return nil, err
	}

	query := "SELECT " + applicationSelectColumns + " FROM applications WHERE app_id = $1"

	app, err := r.scanRow(db.QueryRowContext(ctx, query, appID).Scan)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("application with ID '%s' not found", appID)
		}
		return nil, fmt.Errorf("failed to get application: %w", err)
	}
	return app, nil
}

// List returns up to limit applications, skipping the first offset, newest
// first.
//
// app_id breaks ties between equal created_at values so that LIMIT/OFFSET
// pages are deterministic and consecutive pages neither repeat nor skip rows.
// The result is nil when no rows match.
func (r *ApplicationRepository) List(ctx context.Context, limit, offset int) ([]*domain.Application, error) {
	db, err := r.sqlDB()
	if err != nil {
		return nil, err
	}

	query := "SELECT " + applicationSelectColumns + `
		FROM applications
		ORDER BY created_at DESC, app_id
		LIMIT $1 OFFSET $2`

	rows, err := db.QueryContext(ctx, query, limit, offset)
	if err != nil {
		return nil, fmt.Errorf("failed to list applications: %w", err)
	}
	defer rows.Close()

	var applications []*domain.Application
	for rows.Next() {
		app, err := r.scanRow(rows.Scan)
		if err != nil {
			return nil, fmt.Errorf("failed to scan application: %w", err)
		}
		applications = append(applications, app)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate applications: %w", err)
	}

	return applications, nil
}

// Update applies the non-nil fields of updates to the application identified
// by appID and returns the resulting row.
//
// When updates is nil or sets no fields, the current state is returned
// unchanged. Otherwise updated_at is bumped to the current time. The row is
// read back in the same statement via RETURNING, so the result reflects
// exactly this write even under concurrent modification. A missing row yields
// a "not found" error.
func (r *ApplicationRepository) Update(ctx context.Context, appID string, updates *domain.UpdateApplicationRequest) (*domain.Application, error) {
	if updates == nil {
		return r.GetByID(ctx, appID)
	}

	var setParts []string
	var args []interface{}

	// set records "column = $n" with value bound as parameter n. Every column
	// passed to set below is a string literal in this function, never caller
	// input, so formatting it into the statement cannot inject SQL; all values
	// travel as bind parameters.
	set := func(column string, value interface{}) {
		args = append(args, value)
		setParts = append(setParts, fmt.Sprintf("%s = $%d", column, len(args)))
	}

	if updates.AppLink != nil {
		set("app_link", *updates.AppLink)
	}
	if updates.Type != nil {
		typeStrings := make([]string, len(*updates.Type))
		for i, t := range *updates.Type {
			typeStrings[i] = string(t)
		}
		set("type", pq.Array(typeStrings))
	}
	if updates.CallbackURL != nil {
		set("callback_url", *updates.CallbackURL)
	}
	if updates.HMACKey != nil {
		set("hmac_key", *updates.HMACKey)
	}
	if updates.TokenPrefix != nil {
		set("token_prefix", *updates.TokenPrefix)
	}
	if updates.TokenRenewalDuration != nil {
		set("token_renewal_duration", updates.TokenRenewalDuration.Duration.Nanoseconds())
	}
	if updates.MaxTokenDuration != nil {
		set("max_token_duration", updates.MaxTokenDuration.Duration.Nanoseconds())
	}
	if updates.Owner != nil {
		// The owner is replaced as a whole.
		set("owner_type", string(updates.Owner.Type))
		set("owner_name", updates.Owner.Name)
		set("owner_owner", updates.Owner.Owner)
	}

	if len(setParts) == 0 {
		return r.GetByID(ctx, appID) // No updates, return current state
	}

	set("updated_at", time.Now())
	args = append(args, appID) // bound to the WHERE placeholder, $len(args)

	db, err := r.sqlDB()
	if err != nil {
		return nil, err
	}

	query := fmt.Sprintf(
		"UPDATE applications SET %s WHERE app_id = $%d RETURNING %s",
		strings.Join(setParts, ", "), len(args), applicationSelectColumns,
	)

	app, err := r.scanRow(db.QueryRowContext(ctx, query, args...).Scan)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("application with ID '%s' not found", appID)
		}
		return nil, fmt.Errorf("failed to update application: %w", err)
	}
	return app, nil
}

// Delete removes the application identified by appID.
//
// It returns a "not found" error when no row was deleted.
func (r *ApplicationRepository) Delete(ctx context.Context, appID string) error {
	query := `DELETE FROM applications WHERE app_id = $1`

	db, err := r.sqlDB()
	if err != nil {
		return err
	}

	result, err := db.ExecContext(ctx, query, appID)
	if err != nil {
		return fmt.Errorf("failed to delete application: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	if rowsAffected == 0 {
		return fmt.Errorf("application with ID '%s' not found", appID)
	}

	return nil
}

// Exists reports whether an application with the given appID is stored.
//
// A missing row is (false, nil); only query failures return an error.
func (r *ApplicationRepository) Exists(ctx context.Context, appID string) (bool, error) {
	query := `SELECT 1 FROM applications WHERE app_id = $1`

	db, err := r.sqlDB()
	if err != nil {
		return false, err
	}

	var exists int
	if err := db.QueryRowContext(ctx, query, appID).Scan(&exists); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return false, nil
		}
		return false, fmt.Errorf("failed to check application existence: %w", err)
	}

	return true, nil
}

// isUniqueViolation reports whether err is, or wraps, a PostgreSQL
// unique_violation error (SQLSTATE 23505) reported by lib/pq.
func isUniqueViolation(err error) bool {
	var pqErr *pq.Error
	return errors.As(err, &pqErr) && pqErr.Code == "23505"
}