388 lines
10 KiB
Go
388 lines
10 KiB
Go
package postgres
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/lib/pq"

	"github.com/kms/api-key-service/internal/domain"
	"github.com/kms/api-key-service/internal/repository"
)
|
|
|
|
// ApplicationRepository implements the ApplicationRepository interface for PostgreSQL
|
|
type ApplicationRepository struct {
|
|
db repository.DatabaseProvider
|
|
}
|
|
|
|
// NewApplicationRepository creates a new PostgreSQL application repository
|
|
func NewApplicationRepository(db repository.DatabaseProvider) repository.ApplicationRepository {
|
|
return &ApplicationRepository{db: db}
|
|
}
|
|
|
|
// Create creates a new application
|
|
func (r *ApplicationRepository) Create(ctx context.Context, app *domain.Application) error {
|
|
query := `
|
|
INSERT INTO applications (
|
|
app_id, app_link, type, callback_url, hmac_key, token_prefix,
|
|
token_renewal_duration, max_token_duration,
|
|
owner_type, owner_name, owner_owner,
|
|
created_at, updated_at
|
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
|
`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
now := time.Now()
|
|
|
|
// Convert application types to string array
|
|
typeStrings := make([]string, len(app.Type))
|
|
for i, t := range app.Type {
|
|
typeStrings[i] = string(t)
|
|
}
|
|
|
|
_, err := db.ExecContext(ctx, query,
|
|
app.AppID,
|
|
app.AppLink,
|
|
pq.Array(typeStrings),
|
|
app.CallbackURL,
|
|
app.HMACKey,
|
|
app.TokenPrefix,
|
|
app.TokenRenewalDuration.Duration.Nanoseconds(),
|
|
app.MaxTokenDuration.Duration.Nanoseconds(),
|
|
string(app.Owner.Type),
|
|
app.Owner.Name,
|
|
app.Owner.Owner,
|
|
now,
|
|
now,
|
|
)
|
|
|
|
if err != nil {
|
|
if isUniqueViolation(err) {
|
|
return fmt.Errorf("application with ID '%s' already exists", app.AppID)
|
|
}
|
|
return fmt.Errorf("failed to create application: %w", err)
|
|
}
|
|
|
|
app.CreatedAt = now
|
|
app.UpdatedAt = now
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetByID retrieves an application by its ID
|
|
func (r *ApplicationRepository) GetByID(ctx context.Context, appID string) (*domain.Application, error) {
|
|
query := `
|
|
SELECT app_id, app_link, type, callback_url, hmac_key, token_prefix,
|
|
token_renewal_duration, max_token_duration,
|
|
owner_type, owner_name, owner_owner,
|
|
created_at, updated_at
|
|
FROM applications
|
|
WHERE app_id = $1
|
|
`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
row := db.QueryRowContext(ctx, query, appID)
|
|
|
|
app := &domain.Application{}
|
|
var typeStrings pq.StringArray
|
|
var tokenRenewalNanos, maxTokenNanos int64
|
|
var ownerType string
|
|
|
|
err := row.Scan(
|
|
&app.AppID,
|
|
&app.AppLink,
|
|
&typeStrings,
|
|
&app.CallbackURL,
|
|
&app.HMACKey,
|
|
&app.TokenPrefix,
|
|
&tokenRenewalNanos,
|
|
&maxTokenNanos,
|
|
&ownerType,
|
|
&app.Owner.Name,
|
|
&app.Owner.Owner,
|
|
&app.CreatedAt,
|
|
&app.UpdatedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("application with ID '%s' not found", appID)
|
|
}
|
|
return nil, fmt.Errorf("failed to get application: %w", err)
|
|
}
|
|
|
|
// Convert string array to application types
|
|
app.Type = make([]domain.ApplicationType, len(typeStrings))
|
|
for i, t := range typeStrings {
|
|
app.Type[i] = domain.ApplicationType(t)
|
|
}
|
|
|
|
// Convert nanoseconds to duration
|
|
app.TokenRenewalDuration = domain.Duration{Duration: time.Duration(tokenRenewalNanos)}
|
|
app.MaxTokenDuration = domain.Duration{Duration: time.Duration(maxTokenNanos)}
|
|
|
|
// Convert owner type
|
|
app.Owner.Type = domain.OwnerType(ownerType)
|
|
|
|
return app, nil
|
|
}
|
|
|
|
// List retrieves applications with pagination
|
|
func (r *ApplicationRepository) List(ctx context.Context, limit, offset int) ([]*domain.Application, error) {
|
|
query := `
|
|
SELECT app_id, app_link, type, callback_url, hmac_key, token_prefix,
|
|
token_renewal_duration, max_token_duration,
|
|
owner_type, owner_name, owner_owner,
|
|
created_at, updated_at
|
|
FROM applications
|
|
ORDER BY created_at DESC
|
|
LIMIT $1 OFFSET $2
|
|
`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
rows, err := db.QueryContext(ctx, query, limit, offset)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list applications: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var applications []*domain.Application
|
|
|
|
for rows.Next() {
|
|
app := &domain.Application{}
|
|
var typeStrings pq.StringArray
|
|
var tokenRenewalNanos, maxTokenNanos int64
|
|
var ownerType string
|
|
|
|
err := rows.Scan(
|
|
&app.AppID,
|
|
&app.AppLink,
|
|
&typeStrings,
|
|
&app.CallbackURL,
|
|
&app.HMACKey,
|
|
&app.TokenPrefix,
|
|
&tokenRenewalNanos,
|
|
&maxTokenNanos,
|
|
&ownerType,
|
|
&app.Owner.Name,
|
|
&app.Owner.Owner,
|
|
&app.CreatedAt,
|
|
&app.UpdatedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan application: %w", err)
|
|
}
|
|
|
|
// Convert string array to application types
|
|
app.Type = make([]domain.ApplicationType, len(typeStrings))
|
|
for i, t := range typeStrings {
|
|
app.Type[i] = domain.ApplicationType(t)
|
|
}
|
|
|
|
// Convert nanoseconds to duration
|
|
app.TokenRenewalDuration = domain.Duration{Duration: time.Duration(tokenRenewalNanos)}
|
|
app.MaxTokenDuration = domain.Duration{Duration: time.Duration(maxTokenNanos)}
|
|
|
|
// Convert owner type
|
|
app.Owner.Type = domain.OwnerType(ownerType)
|
|
|
|
applications = append(applications, app)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("failed to iterate applications: %w", err)
|
|
}
|
|
|
|
return applications, nil
|
|
}
|
|
|
|
// Update updates an existing application
|
|
func (r *ApplicationRepository) Update(ctx context.Context, appID string, updates *domain.UpdateApplicationRequest) (*domain.Application, error) {
|
|
// Build secure dynamic update query using a whitelist approach
|
|
var setParts []string
|
|
var args []interface{}
|
|
argIndex := 1
|
|
|
|
// Whitelist of allowed fields to prevent SQL injection
|
|
allowedFields := map[string]string{
|
|
"app_link": "app_link",
|
|
"type": "type",
|
|
"callback_url": "callback_url",
|
|
"hmac_key": "hmac_key",
|
|
"token_prefix": "token_prefix",
|
|
"token_renewal_duration": "token_renewal_duration",
|
|
"max_token_duration": "max_token_duration",
|
|
"owner_type": "owner_type",
|
|
"owner_name": "owner_name",
|
|
"owner_owner": "owner_owner",
|
|
}
|
|
|
|
if updates.AppLink != nil {
|
|
if field, ok := allowedFields["app_link"]; ok {
|
|
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
|
args = append(args, *updates.AppLink)
|
|
argIndex++
|
|
}
|
|
}
|
|
|
|
if updates.Type != nil {
|
|
if field, ok := allowedFields["type"]; ok {
|
|
typeStrings := make([]string, len(*updates.Type))
|
|
for i, t := range *updates.Type {
|
|
typeStrings[i] = string(t)
|
|
}
|
|
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
|
args = append(args, pq.Array(typeStrings))
|
|
argIndex++
|
|
}
|
|
}
|
|
|
|
if updates.CallbackURL != nil {
|
|
if field, ok := allowedFields["callback_url"]; ok {
|
|
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
|
args = append(args, *updates.CallbackURL)
|
|
argIndex++
|
|
}
|
|
}
|
|
|
|
if updates.HMACKey != nil {
|
|
if field, ok := allowedFields["hmac_key"]; ok {
|
|
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
|
args = append(args, *updates.HMACKey)
|
|
argIndex++
|
|
}
|
|
}
|
|
|
|
if updates.TokenPrefix != nil {
|
|
if field, ok := allowedFields["token_prefix"]; ok {
|
|
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
|
args = append(args, *updates.TokenPrefix)
|
|
argIndex++
|
|
}
|
|
}
|
|
|
|
if updates.TokenRenewalDuration != nil {
|
|
if field, ok := allowedFields["token_renewal_duration"]; ok {
|
|
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
|
args = append(args, updates.TokenRenewalDuration.Duration.Nanoseconds())
|
|
argIndex++
|
|
}
|
|
}
|
|
|
|
if updates.MaxTokenDuration != nil {
|
|
if field, ok := allowedFields["max_token_duration"]; ok {
|
|
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
|
args = append(args, updates.MaxTokenDuration.Duration.Nanoseconds())
|
|
argIndex++
|
|
}
|
|
}
|
|
|
|
if updates.Owner != nil {
|
|
if field, ok := allowedFields["owner_type"]; ok {
|
|
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
|
args = append(args, string(updates.Owner.Type))
|
|
argIndex++
|
|
}
|
|
|
|
if field, ok := allowedFields["owner_name"]; ok {
|
|
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
|
args = append(args, updates.Owner.Name)
|
|
argIndex++
|
|
}
|
|
|
|
if field, ok := allowedFields["owner_owner"]; ok {
|
|
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
|
args = append(args, updates.Owner.Owner)
|
|
argIndex++
|
|
}
|
|
}
|
|
|
|
if len(setParts) == 0 {
|
|
return r.GetByID(ctx, appID) // No updates, return current state
|
|
}
|
|
|
|
// Always update the updated_at field - using literal field name for security
|
|
setParts = append(setParts, fmt.Sprintf("updated_at = $%d", argIndex))
|
|
args = append(args, time.Now())
|
|
argIndex++
|
|
|
|
// Add WHERE clause parameter
|
|
args = append(args, appID)
|
|
|
|
// Build the final query with properly parameterized placeholders
|
|
query := fmt.Sprintf(`
|
|
UPDATE applications
|
|
SET %s
|
|
WHERE app_id = $%d
|
|
`, strings.Join(setParts, ", "), argIndex)
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
result, err := db.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to update application: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return nil, fmt.Errorf("application with ID '%s' not found", appID)
|
|
}
|
|
|
|
// Return updated application
|
|
return r.GetByID(ctx, appID)
|
|
}
|
|
|
|
// Delete deletes an application
|
|
func (r *ApplicationRepository) Delete(ctx context.Context, appID string) error {
|
|
query := `DELETE FROM applications WHERE app_id = $1`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
result, err := db.ExecContext(ctx, query, appID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete application: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return fmt.Errorf("application with ID '%s' not found", appID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Exists checks if an application exists
|
|
func (r *ApplicationRepository) Exists(ctx context.Context, appID string) (bool, error) {
|
|
query := `SELECT 1 FROM applications WHERE app_id = $1`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
var exists int
|
|
err := db.QueryRowContext(ctx, query, appID).Scan(&exists)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return false, nil
|
|
}
|
|
return false, fmt.Errorf("failed to check application existence: %w", err)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// isUniqueViolation checks if the error is a unique constraint violation
|
|
func isUniqueViolation(err error) bool {
|
|
if pqErr, ok := err.(*pq.Error); ok {
|
|
return pqErr.Code == "23505" // unique_violation
|
|
}
|
|
return false
|
|
}
|