319 lines
10 KiB
Go
319 lines
10 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/RyanCopley/skybridge/faas/internal/domain"
|
|
"github.com/RyanCopley/skybridge/faas/internal/repository"
|
|
)
|
|
|
|
type executionRepository struct {
|
|
db *sql.DB
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// durationToInterval renders d as a textual interval of the form
// "<seconds> seconds" with nine fractional digits, suitable for binding to an
// interval query parameter.
//
// A zero duration is returned as nil, so it is bound as SQL NULL rather than
// as a zero-length interval.
func durationToInterval(d time.Duration) interface{} {
	if d == 0 {
		return nil
	}

	// time.Duration counts nanoseconds; dividing by one second's worth gives
	// fractional seconds.
	const nanosPerSecond = float64(time.Second)
	return fmt.Sprintf("%.9f seconds", float64(d)/nanosPerSecond)
}
|
|
|
|
// intervalToDuration converts an interval value read from the database into a
// time.Duration.
//
// Accepted inputs:
//   - nil, an empty string, or an empty []byte: reported as a zero duration;
//   - text in Go duration syntax (anything time.ParseDuration accepts, e.g. "1.5s");
//   - text in clock form "HH:MM:SS[.fraction]", optionally prefixed with "-"
//     for a negative interval, e.g. "00:00:05.123456" or "-01:30:00".
//
// A []byte is interpreted as text. Any other dynamic type, and any text that
// matches neither syntax (including intervals carrying day/month/year parts
// such as "1 day 02:00:00"), yields an error.
func intervalToDuration(interval interface{}) (time.Duration, error) {
	switch v := interval.(type) {
	case nil:
		return 0, nil

	case []byte:
		return intervalToDuration(string(v))

	case string:
		if v == "" {
			return 0, nil
		}

		// Go duration syntax is tried first; it carries its own sign handling.
		if d, err := time.ParseDuration(v); err == nil {
			return d, nil
		}

		// In clock form a leading "-" negates the whole interval. It is peeled
		// off before scanning: letting %d consume it would drop the sign when
		// the hour field is zero ("-00:00:05") and would negate only the hours
		// otherwise ("-01:30:00").
		clock, negative := v, false
		if clock[0] == '-' {
			clock, negative = clock[1:], true
		}

		var hours, minutes int
		var seconds float64
		n, err := fmt.Sscanf(clock, "%d:%d:%f", &hours, &minutes, &seconds)
		if n != 3 || err != nil || hours < 0 || minutes < 0 || seconds < 0 {
			return 0, fmt.Errorf("unable to parse interval: %s", v)
		}

		// seconds is non-negative here, so adding 0.5 before the truncating
		// conversion rounds to the nearest nanosecond instead of losing one to
		// binary floating-point error.
		d := time.Duration(hours)*time.Hour +
			time.Duration(minutes)*time.Minute +
			time.Duration(seconds*float64(time.Second)+0.5)
		if negative {
			d = -d
		}
		return d, nil

	default:
		return 0, fmt.Errorf("unexpected interval type: %T", interval)
	}
}
|
|
|
|
// jsonField prepares a raw JSON value for use as a query argument.
//
// The payload is handed to the driver as a string rather than as a
// json.RawMessage. A nil or zero-length payload is replaced by the empty JSON
// object "{}".
func jsonField(data json.RawMessage) interface{} {
	// len of a nil slice is 0, so this one test covers both nil and empty.
	if len(data) > 0 {
		return string(data)
	}
	return "{}"
}
|
|
|
|
func NewExecutionRepository(db *sql.DB, logger *zap.Logger) repository.ExecutionRepository {
|
|
return &executionRepository{
|
|
db: db,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (r *executionRepository) Create(ctx context.Context, execution *domain.FunctionExecution) (*domain.FunctionExecution, error) {
|
|
query := `
|
|
INSERT INTO executions (id, function_id, status, input, executor_id, created_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6)
|
|
RETURNING created_at`
|
|
|
|
err := r.db.QueryRowContext(ctx, query,
|
|
execution.ID, execution.FunctionID, execution.Status, jsonField(execution.Input),
|
|
execution.ExecutorID, execution.CreatedAt,
|
|
).Scan(&execution.CreatedAt)
|
|
|
|
if err != nil {
|
|
r.logger.Error("Failed to create execution", zap.Error(err))
|
|
return nil, fmt.Errorf("failed to create execution: %w", err)
|
|
}
|
|
|
|
return execution, nil
|
|
}
|
|
|
|
func (r *executionRepository) GetByID(ctx context.Context, id uuid.UUID) (*domain.FunctionExecution, error) {
|
|
query := `
|
|
SELECT id, function_id, status, input, output, error, duration, memory_used,
|
|
container_id, executor_id, created_at, started_at, completed_at
|
|
FROM executions WHERE id = $1`
|
|
|
|
execution := &domain.FunctionExecution{}
|
|
var durationInterval sql.NullString
|
|
|
|
err := r.db.QueryRowContext(ctx, query, id).Scan(
|
|
&execution.ID, &execution.FunctionID, &execution.Status, &execution.Input,
|
|
&execution.Output, &execution.Error, &durationInterval, &execution.MemoryUsed,
|
|
&execution.ContainerID, &execution.ExecutorID, &execution.CreatedAt,
|
|
&execution.StartedAt, &execution.CompletedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("execution not found")
|
|
}
|
|
r.logger.Error("Failed to get execution by ID", zap.String("id", id.String()), zap.Error(err))
|
|
return nil, fmt.Errorf("failed to get execution: %w", err)
|
|
}
|
|
|
|
// Convert duration from PostgreSQL interval
|
|
if durationInterval.Valid {
|
|
duration, err := intervalToDuration(durationInterval.String)
|
|
if err != nil {
|
|
r.logger.Warn("Failed to parse duration interval", zap.String("interval", durationInterval.String), zap.Error(err))
|
|
} else {
|
|
execution.Duration = duration
|
|
}
|
|
}
|
|
|
|
return execution, nil
|
|
}
|
|
|
|
func (r *executionRepository) Update(ctx context.Context, id uuid.UUID, execution *domain.FunctionExecution) (*domain.FunctionExecution, error) {
|
|
query := `
|
|
UPDATE executions
|
|
SET status = $2, output = $3, error = $4, duration = $5, memory_used = $6,
|
|
container_id = $7, started_at = $8, completed_at = $9
|
|
WHERE id = $1`
|
|
|
|
_, err := r.db.ExecContext(ctx, query,
|
|
id, execution.Status, jsonField(execution.Output), execution.Error,
|
|
durationToInterval(execution.Duration), execution.MemoryUsed, execution.ContainerID,
|
|
execution.StartedAt, execution.CompletedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
r.logger.Error("Failed to update execution", zap.String("id", id.String()), zap.Error(err))
|
|
return nil, fmt.Errorf("failed to update execution: %w", err)
|
|
}
|
|
|
|
// Return updated execution
|
|
return r.GetByID(ctx, id)
|
|
}
|
|
|
|
func (r *executionRepository) Delete(ctx context.Context, id uuid.UUID) error {
|
|
query := `DELETE FROM executions WHERE id = $1`
|
|
|
|
result, err := r.db.ExecContext(ctx, query, id)
|
|
if err != nil {
|
|
r.logger.Error("Failed to delete execution", zap.String("id", id.String()), zap.Error(err))
|
|
return fmt.Errorf("failed to delete execution: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get affected rows: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return fmt.Errorf("execution not found")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *executionRepository) List(ctx context.Context, functionID *uuid.UUID, limit, offset int) ([]*domain.FunctionExecution, error) {
|
|
var query string
|
|
var args []interface{}
|
|
|
|
if functionID != nil {
|
|
query = `
|
|
SELECT id, function_id, status, input, output, error, duration, memory_used,
|
|
container_id, executor_id, created_at, started_at, completed_at
|
|
FROM executions WHERE function_id = $1
|
|
ORDER BY created_at DESC LIMIT $2 OFFSET $3`
|
|
args = []interface{}{*functionID, limit, offset}
|
|
} else {
|
|
query = `
|
|
SELECT id, function_id, status, input, output, error, duration, memory_used,
|
|
container_id, executor_id, created_at, started_at, completed_at
|
|
FROM executions
|
|
ORDER BY created_at DESC LIMIT $1 OFFSET $2`
|
|
args = []interface{}{limit, offset}
|
|
}
|
|
|
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
r.logger.Error("Failed to list executions", zap.Error(err))
|
|
return nil, fmt.Errorf("failed to list executions: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var executions []*domain.FunctionExecution
|
|
for rows.Next() {
|
|
execution := &domain.FunctionExecution{}
|
|
var durationInterval sql.NullString
|
|
|
|
err := rows.Scan(
|
|
&execution.ID, &execution.FunctionID, &execution.Status, &execution.Input,
|
|
&execution.Output, &execution.Error, &durationInterval, &execution.MemoryUsed,
|
|
&execution.ContainerID, &execution.ExecutorID, &execution.CreatedAt,
|
|
&execution.StartedAt, &execution.CompletedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
r.logger.Error("Failed to scan execution", zap.Error(err))
|
|
return nil, fmt.Errorf("failed to scan execution: %w", err)
|
|
}
|
|
|
|
// Convert duration from PostgreSQL interval
|
|
if durationInterval.Valid {
|
|
duration, err := intervalToDuration(durationInterval.String)
|
|
if err != nil {
|
|
r.logger.Warn("Failed to parse duration interval", zap.String("interval", durationInterval.String), zap.Error(err))
|
|
} else {
|
|
execution.Duration = duration
|
|
}
|
|
}
|
|
|
|
executions = append(executions, execution)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("failed to iterate executions: %w", err)
|
|
}
|
|
|
|
return executions, nil
|
|
}
|
|
|
|
func (r *executionRepository) GetByFunctionID(ctx context.Context, functionID uuid.UUID, limit, offset int) ([]*domain.FunctionExecution, error) {
|
|
return r.List(ctx, &functionID, limit, offset)
|
|
}
|
|
|
|
func (r *executionRepository) GetByStatus(ctx context.Context, status domain.ExecutionStatus, limit, offset int) ([]*domain.FunctionExecution, error) {
|
|
query := `
|
|
SELECT id, function_id, status, input, output, error, duration, memory_used,
|
|
container_id, executor_id, created_at, started_at, completed_at
|
|
FROM executions WHERE status = $1
|
|
ORDER BY created_at DESC LIMIT $2 OFFSET $3`
|
|
|
|
rows, err := r.db.QueryContext(ctx, query, status, limit, offset)
|
|
if err != nil {
|
|
r.logger.Error("Failed to get executions by status", zap.String("status", string(status)), zap.Error(err))
|
|
return nil, fmt.Errorf("failed to get executions by status: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var executions []*domain.FunctionExecution
|
|
for rows.Next() {
|
|
execution := &domain.FunctionExecution{}
|
|
var durationInterval sql.NullString
|
|
|
|
err := rows.Scan(
|
|
&execution.ID, &execution.FunctionID, &execution.Status, &execution.Input,
|
|
&execution.Output, &execution.Error, &durationInterval, &execution.MemoryUsed,
|
|
&execution.ContainerID, &execution.ExecutorID, &execution.CreatedAt,
|
|
&execution.StartedAt, &execution.CompletedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
r.logger.Error("Failed to scan execution", zap.Error(err))
|
|
return nil, fmt.Errorf("failed to scan execution: %w", err)
|
|
}
|
|
|
|
// Convert duration from PostgreSQL interval
|
|
if durationInterval.Valid {
|
|
duration, err := intervalToDuration(durationInterval.String)
|
|
if err != nil {
|
|
r.logger.Warn("Failed to parse duration interval", zap.String("interval", durationInterval.String), zap.Error(err))
|
|
} else {
|
|
execution.Duration = duration
|
|
}
|
|
}
|
|
|
|
executions = append(executions, execution)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("failed to iterate executions: %w", err)
|
|
}
|
|
|
|
return executions, nil
|
|
}
|
|
|
|
func (r *executionRepository) UpdateStatus(ctx context.Context, id uuid.UUID, status domain.ExecutionStatus) error {
|
|
query := `UPDATE executions SET status = $2 WHERE id = $1`
|
|
|
|
result, err := r.db.ExecContext(ctx, query, id, status)
|
|
if err != nil {
|
|
r.logger.Error("Failed to update execution status", zap.String("id", id.String()), zap.Error(err))
|
|
return fmt.Errorf("failed to update execution status: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get affected rows: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return fmt.Errorf("execution not found")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *executionRepository) GetRunningExecutions(ctx context.Context) ([]*domain.FunctionExecution, error) {
|
|
return r.GetByStatus(ctx, domain.StatusRunning, 1000, 0)
|
|
} |