-
This commit is contained in:
@ -168,9 +168,6 @@ type DatabaseProvider interface {
|
||||
|
||||
// BeginTx starts a database transaction
|
||||
BeginTx(ctx context.Context) (TransactionProvider, error)
|
||||
|
||||
// Migrate runs database migrations
|
||||
Migrate(ctx context.Context, migrationPath string) error
|
||||
}
|
||||
|
||||
// TransactionProvider defines the interface for database transaction operations
|
||||
|
||||
@ -201,83 +201,118 @@ func (r *ApplicationRepository) List(ctx context.Context, limit, offset int) ([]
|
||||
|
||||
// Update updates an existing application
|
||||
func (r *ApplicationRepository) Update(ctx context.Context, appID string, updates *domain.UpdateApplicationRequest) (*domain.Application, error) {
|
||||
// Build dynamic update query
|
||||
// Build secure dynamic update query using a whitelist approach
|
||||
var setParts []string
|
||||
var args []interface{}
|
||||
argIndex := 1
|
||||
|
||||
// Whitelist of allowed fields to prevent SQL injection
|
||||
allowedFields := map[string]string{
|
||||
"app_link": "app_link",
|
||||
"type": "type",
|
||||
"callback_url": "callback_url",
|
||||
"hmac_key": "hmac_key",
|
||||
"token_prefix": "token_prefix",
|
||||
"token_renewal_duration": "token_renewal_duration",
|
||||
"max_token_duration": "max_token_duration",
|
||||
"owner_type": "owner_type",
|
||||
"owner_name": "owner_name",
|
||||
"owner_owner": "owner_owner",
|
||||
}
|
||||
|
||||
if updates.AppLink != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("app_link = $%d", argIndex))
|
||||
args = append(args, *updates.AppLink)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["app_link"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.AppLink)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.Type != nil {
|
||||
typeStrings := make([]string, len(*updates.Type))
|
||||
for i, t := range *updates.Type {
|
||||
typeStrings[i] = string(t)
|
||||
if field, ok := allowedFields["type"]; ok {
|
||||
typeStrings := make([]string, len(*updates.Type))
|
||||
for i, t := range *updates.Type {
|
||||
typeStrings[i] = string(t)
|
||||
}
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, pq.Array(typeStrings))
|
||||
argIndex++
|
||||
}
|
||||
setParts = append(setParts, fmt.Sprintf("type = $%d", argIndex))
|
||||
args = append(args, pq.Array(typeStrings))
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if updates.CallbackURL != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("callback_url = $%d", argIndex))
|
||||
args = append(args, *updates.CallbackURL)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["callback_url"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.CallbackURL)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.HMACKey != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("hmac_key = $%d", argIndex))
|
||||
args = append(args, *updates.HMACKey)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["hmac_key"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.HMACKey)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.TokenPrefix != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("token_prefix = $%d", argIndex))
|
||||
args = append(args, *updates.TokenPrefix)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["token_prefix"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, *updates.TokenPrefix)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.TokenRenewalDuration != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("token_renewal_duration = $%d", argIndex))
|
||||
args = append(args, updates.TokenRenewalDuration.Duration.Nanoseconds())
|
||||
argIndex++
|
||||
if field, ok := allowedFields["token_renewal_duration"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.TokenRenewalDuration.Duration.Nanoseconds())
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.MaxTokenDuration != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("max_token_duration = $%d", argIndex))
|
||||
args = append(args, updates.MaxTokenDuration.Duration.Nanoseconds())
|
||||
argIndex++
|
||||
if field, ok := allowedFields["max_token_duration"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.MaxTokenDuration.Duration.Nanoseconds())
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if updates.Owner != nil {
|
||||
setParts = append(setParts, fmt.Sprintf("owner_type = $%d", argIndex))
|
||||
args = append(args, string(updates.Owner.Type))
|
||||
argIndex++
|
||||
if field, ok := allowedFields["owner_type"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, string(updates.Owner.Type))
|
||||
argIndex++
|
||||
}
|
||||
|
||||
setParts = append(setParts, fmt.Sprintf("owner_name = $%d", argIndex))
|
||||
args = append(args, updates.Owner.Name)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["owner_name"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.Owner.Name)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
setParts = append(setParts, fmt.Sprintf("owner_owner = $%d", argIndex))
|
||||
args = append(args, updates.Owner.Owner)
|
||||
argIndex++
|
||||
if field, ok := allowedFields["owner_owner"]; ok {
|
||||
setParts = append(setParts, fmt.Sprintf("%s = $%d", field, argIndex))
|
||||
args = append(args, updates.Owner.Owner)
|
||||
argIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if len(setParts) == 0 {
|
||||
return r.GetByID(ctx, appID) // No updates, return current state
|
||||
}
|
||||
|
||||
// Always update the updated_at field
|
||||
// Always update the updated_at field - using literal field name for security
|
||||
setParts = append(setParts, fmt.Sprintf("updated_at = $%d", argIndex))
|
||||
args = append(args, time.Now())
|
||||
argIndex++
|
||||
|
||||
// Add WHERE clause
|
||||
// Add WHERE clause parameter
|
||||
args = append(args, appID)
|
||||
|
||||
// Build the final query with properly parameterized placeholders
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE applications
|
||||
SET %s
|
||||
|
||||
Reference in New Issue
Block a user