package postgres

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/lib/pq"

	"github.com/kms/api-key-service/internal/domain"
	"github.com/kms/api-key-service/internal/repository"
)

// applicationColumns is the column list used by every query that reads a
// full application row. Its order must match the Scan targets in
// (*ApplicationRepository).scanApplication.
const applicationColumns = `app_id, app_link, type, callback_url, hmac_key,
	token_renewal_duration, max_token_duration,
	owner_type, owner_name, owner_owner,
	created_at, updated_at`

// ApplicationRepository implements the ApplicationRepository interface for PostgreSQL.
type ApplicationRepository struct {
	db repository.DatabaseProvider
}

// NewApplicationRepository creates a new PostgreSQL application repository.
func NewApplicationRepository(db repository.DatabaseProvider) repository.ApplicationRepository {
	return &ApplicationRepository{db: db}
}

// sqlDB returns the *sql.DB behind the repository's DatabaseProvider.
// It reports an error, rather than panicking, when the provider returns a
// value that is not a non-nil *sql.DB.
func (r *ApplicationRepository) sqlDB() (*sql.DB, error) {
	raw := r.db.GetDB()
	db, ok := raw.(*sql.DB)
	if !ok || db == nil {
		return nil, fmt.Errorf("application repository: database provider returned %T, want non-nil *sql.DB", raw)
	}
	return db, nil
}

// scanApplication reads one row produced by a query that selects
// applicationColumns. It converts the stored representations back into
// domain values:
//   - the text array becomes []domain.ApplicationType
//   - the nanosecond integers become time.Duration
//   - the owner type string becomes domain.OwnerType
//
// The row argument may be a *sql.Row or *sql.Rows. Scan errors, including
// sql.ErrNoRows, are returned unwrapped so callers can classify them.
func (r *ApplicationRepository) scanApplication(row interface{ Scan(dest ...interface{}) error }) (*domain.Application, error) {
	app := &domain.Application{}
	var typeStrings pq.StringArray
	var tokenRenewalNanos, maxTokenNanos int64
	var ownerType string

	if err := row.Scan(
		&app.AppID,
		&app.AppLink,
		&typeStrings,
		&app.CallbackURL,
		&app.HMACKey,
		&tokenRenewalNanos,
		&maxTokenNanos,
		&ownerType,
		&app.Owner.Name,
		&app.Owner.Owner,
		&app.CreatedAt,
		&app.UpdatedAt,
	); err != nil {
		return nil, err
	}

	app.Type = make([]domain.ApplicationType, len(typeStrings))
	for i, t := range typeStrings {
		app.Type[i] = domain.ApplicationType(t)
	}
	app.TokenRenewalDuration = time.Duration(tokenRenewalNanos)
	app.MaxTokenDuration = time.Duration(maxTokenNanos)
	app.Owner.Type = domain.OwnerType(ownerType)
	return app, nil
}

// Create inserts a new application.
//
// Durations are stored as nanosecond integers and application types as a
// text array. Both created_at and updated_at are set to the current time.
// On success the same timestamp is written back to app.CreatedAt and
// app.UpdatedAt. A unique-constraint violation is reported as an
// "already exists" error.
func (r *ApplicationRepository) Create(ctx context.Context, app *domain.Application) error {
	const query = `
		INSERT INTO applications (
			app_id, app_link, type, callback_url, hmac_key,
			token_renewal_duration, max_token_duration,
			owner_type, owner_name, owner_owner,
			created_at, updated_at
		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
	`

	db, err := r.sqlDB()
	if err != nil {
		return err
	}

	typeStrings := make([]string, len(app.Type))
	for i, t := range app.Type {
		typeStrings[i] = string(t)
	}

	now := time.Now()
	_, err = db.ExecContext(ctx, query,
		app.AppID,
		app.AppLink,
		pq.Array(typeStrings),
		app.CallbackURL,
		app.HMACKey,
		app.TokenRenewalDuration.Nanoseconds(),
		app.MaxTokenDuration.Nanoseconds(),
		string(app.Owner.Type),
		app.Owner.Name,
		app.Owner.Owner,
		now,
		now,
	)
	if err != nil {
		if isUniqueViolation(err) {
			return fmt.Errorf("application with ID '%s' already exists", app.AppID)
		}
		return fmt.Errorf("failed to create application: %w", err)
	}

	app.CreatedAt = now
	app.UpdatedAt = now
	return nil
}

// GetByID retrieves an application by its ID.
// It returns a "not found" error when no row matches.
func (r *ApplicationRepository) GetByID(ctx context.Context, appID string) (*domain.Application, error) {
	query := `SELECT ` + applicationColumns + ` FROM applications WHERE app_id = $1`

	db, err := r.sqlDB()
	if err != nil {
		return nil, err
	}

	app, err := r.scanApplication(db.QueryRowContext(ctx, query, appID))
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("application with ID '%s' not found", appID)
		}
		return nil, fmt.Errorf("failed to get application: %w", err)
	}
	return app, nil
}

// List retrieves a page of applications, newest first.
//
// app_id is used as a tie-breaker after created_at. Rows that share a
// creation timestamp therefore have a stable order, so OFFSET pagination
// never repeats or skips them. An empty page is returned as a nil slice.
func (r *ApplicationRepository) List(ctx context.Context, limit, offset int) ([]*domain.Application, error) {
	query := `SELECT ` + applicationColumns + `
		FROM applications
		ORDER BY created_at DESC, app_id DESC
		LIMIT $1 OFFSET $2`

	db, err := r.sqlDB()
	if err != nil {
		return nil, err
	}

	rows, err := db.QueryContext(ctx, query, limit, offset)
	if err != nil {
		return nil, fmt.Errorf("failed to list applications: %w", err)
	}
	defer rows.Close()

	var applications []*domain.Application
	for rows.Next() {
		app, err := r.scanApplication(rows)
		if err != nil {
			return nil, fmt.Errorf("failed to scan application: %w", err)
		}
		applications = append(applications, app)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate applications: %w", err)
	}
	return applications, nil
}

// Update applies the non-nil fields of updates to the application and
// returns its new state.
//
// updated_at is always bumped when at least one field changes. The
// statement uses RETURNING, so the returned row is exactly the one this
// UPDATE wrote: there is no second round trip for a concurrent writer or
// deleter to race against. When updates is nil or sets no fields, the
// current state is returned unchanged.
func (r *ApplicationRepository) Update(ctx context.Context, appID string, updates *domain.UpdateApplicationRequest) (*domain.Application, error) {
	if updates == nil {
		return r.GetByID(ctx, appID) // No updates, return current state
	}

	var setParts []string
	var args []interface{}
	// set records "column = $N", where N is value's 1-based position in args,
	// so placeholders always line up with their arguments. Column names are
	// fixed literals below, never caller input.
	set := func(column string, value interface{}) {
		args = append(args, value)
		setParts = append(setParts, fmt.Sprintf("%s = $%d", column, len(args)))
	}

	if updates.AppLink != nil {
		set("app_link", *updates.AppLink)
	}
	if updates.Type != nil {
		typeStrings := make([]string, len(*updates.Type))
		for i, t := range *updates.Type {
			typeStrings[i] = string(t)
		}
		set("type", pq.Array(typeStrings))
	}
	if updates.CallbackURL != nil {
		set("callback_url", *updates.CallbackURL)
	}
	if updates.HMACKey != nil {
		set("hmac_key", *updates.HMACKey)
	}
	if updates.TokenRenewalDuration != nil {
		set("token_renewal_duration", updates.TokenRenewalDuration.Nanoseconds())
	}
	if updates.MaxTokenDuration != nil {
		set("max_token_duration", updates.MaxTokenDuration.Nanoseconds())
	}
	if updates.Owner != nil {
		set("owner_type", string(updates.Owner.Type))
		set("owner_name", updates.Owner.Name)
		set("owner_owner", updates.Owner.Owner)
	}

	if len(setParts) == 0 {
		return r.GetByID(ctx, appID) // No updates, return current state
	}

	// Always update the updated_at field.
	set("updated_at", time.Now())

	args = append(args, appID)
	query := fmt.Sprintf(`
		UPDATE applications
		SET %s
		WHERE app_id = $%d
		RETURNING %s
	`, strings.Join(setParts, ", "), len(args), applicationColumns)

	db, err := r.sqlDB()
	if err != nil {
		return nil, err
	}

	app, err := r.scanApplication(db.QueryRowContext(ctx, query, args...))
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("application with ID '%s' not found", appID)
		}
		return nil, fmt.Errorf("failed to update application: %w", err)
	}
	return app, nil
}

// Delete removes an application.
// It returns a "not found" error when no row matched appID.
func (r *ApplicationRepository) Delete(ctx context.Context, appID string) error {
	const query = `DELETE FROM applications WHERE app_id = $1`

	db, err := r.sqlDB()
	if err != nil {
		return err
	}

	result, err := db.ExecContext(ctx, query, appID)
	if err != nil {
		return fmt.Errorf("failed to delete application: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	if rowsAffected == 0 {
		return fmt.Errorf("application with ID '%s' not found", appID)
	}
	return nil
}

// Exists reports whether an application with appID exists.
// A missing row yields (false, nil); any other query failure is returned
// as an error.
func (r *ApplicationRepository) Exists(ctx context.Context, appID string) (bool, error) {
	const query = `SELECT 1 FROM applications WHERE app_id = $1`

	db, err := r.sqlDB()
	if err != nil {
		return false, err
	}

	var exists int
	if err := db.QueryRowContext(ctx, query, appID).Scan(&exists); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return false, nil
		}
		return false, fmt.Errorf("failed to check application existence: %w", err)
	}
	return true, nil
}

// isUniqueViolation reports whether err is, or wraps, a PostgreSQL
// unique_violation (SQLSTATE 23505) raised through lib/pq.
func isUniqueViolation(err error) bool {
	var pqErr *pq.Error
	return errors.As(err, &pqErr) && pqErr.Code == "23505" // unique_violation
}