package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"path/filepath"
	"time"

	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database/postgres"
	_ "github.com/golang-migrate/migrate/v4/source/file" // registers the file:// migration source
	_ "github.com/lib/pq"                                // registers the "postgres" database/sql driver

	"github.com/kms/api-key-service/internal/repository"
)

// errNilDB is returned when a provider method is called on a provider that
// holds no *sql.DB (for example a zero-value PostgresProvider).
var errNilDB = errors.New("database connection is nil")

// PostgresProvider implements the repository.DatabaseProvider interface on top
// of a database/sql connection pool backed by the lib/pq driver.
type PostgresProvider struct {
	db *sql.DB
	// dsn is the connection string the pool was opened with. Migrate needs it
	// to open a dedicated pool. It may contain credentials, so never log it.
	dsn string
}

// NewPostgresProvider opens a PostgreSQL connection pool for dsn, applies the
// pool limits, and verifies connectivity with a ping.
//
// maxLifetime is a time.ParseDuration string (e.g. "5m"). An empty string
// leaves the pool's connection lifetime untouched. An unparsable value is
// reported as an error instead of being silently ignored.
func NewPostgresProvider(dsn string, maxOpenConns, maxIdleConns int, maxLifetime string) (repository.DatabaseProvider, error) {
	// Validate configuration before opening the pool so a bad value cannot leak it.
	var lifetime time.Duration
	if maxLifetime != "" {
		d, err := time.ParseDuration(maxLifetime)
		if err != nil {
			return nil, fmt.Errorf("invalid max connection lifetime %q: %w", maxLifetime, err)
		}
		lifetime = d
	}

	db, err := sql.Open("postgres", dsn)
	if err != nil {
		return nil, fmt.Errorf("failed to open database connection: %w", err)
	}

	// Set connection pool settings.
	db.SetMaxOpenConns(maxOpenConns)
	db.SetMaxIdleConns(maxIdleConns)
	if maxLifetime != "" {
		db.SetConnMaxLifetime(lifetime)
	}

	// sql.Open does not connect; ping so a bad DSN or unreachable server fails here.
	if err := db.Ping(); err != nil {
		_ = db.Close() // best-effort cleanup; the ping error is the one worth reporting
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}

	return &PostgresProvider{db: db, dsn: dsn}, nil
}

// GetDB returns the underlying *sql.DB, typed as interface{} to satisfy the
// repository interface. Callers are expected to type-assert it.
func (p *PostgresProvider) GetDB() interface{} {
	return p.db
}

// Ping checks that the database is reachable, honoring ctx for cancellation
// and deadlines. It returns an error if the provider has no pool, or if the
// pool has been closed or the server does not answer.
func (p *PostgresProvider) Ping(ctx context.Context) error {
	if p.db == nil {
		return errNilDB
	}
	if err := p.db.PingContext(ctx); err != nil {
		return fmt.Errorf("database ping failed: %w", err)
	}
	return nil
}

// Close closes the connection pool. Closing a provider without a pool is a no-op.
func (p *PostgresProvider) Close() error {
	if p.db == nil {
		return nil
	}
	return p.db.Close()
}

// BeginTx starts a transaction with default options, bound to ctx.
func (p *PostgresProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) {
	if p.db == nil {
		return nil, errNilDB
	}
	tx, err := p.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to begin transaction: %w", err)
	}
	return &PostgresTransaction{tx: tx}, nil
}

// Migrate applies all pending "up" migrations found in the migrationPath
// directory. It returns nil when the schema is already up to date.
//
// A dedicated pool is opened from the provider's DSN rather than reusing p.db.
// In golang-migrate v4, closing the postgres driver created by WithInstance
// also closes the *sql.DB handed to it, which would tear down the main pool.
func (p *PostgresProvider) Migrate(ctx context.Context, migrationPath string) error {
	dsn := p.getDSN()
	if dsn == "" {
		return fmt.Errorf("cannot run migrations: provider has no DSN (create it with NewPostgresProvider)")
	}

	migrationDB, err := sql.Open("postgres", dsn)
	if err != nil {
		return fmt.Errorf("failed to open migration database connection: %w", err)
	}
	// m.Close below also closes migrationDB; sql.DB.Close is idempotent, so this
	// defer only matters on the early-return paths.
	defer migrationDB.Close()

	// migrate.Up takes no context, so honor ctx at least while connecting.
	if err := migrationDB.PingContext(ctx); err != nil {
		return fmt.Errorf("failed to ping migration database: %w", err)
	}

	driver, err := postgres.WithInstance(migrationDB, &postgres.Config{})
	if err != nil {
		return fmt.Errorf("failed to create postgres driver: %w", err)
	}

	// The file source expects an absolute, slash-separated path after "file://".
	absPath, err := filepath.Abs(migrationPath)
	if err != nil {
		return fmt.Errorf("failed to get absolute path: %w", err)
	}

	m, err := migrate.NewWithDatabaseInstance("file://"+filepath.ToSlash(absPath), "postgres", driver)
	if err != nil {
		return fmt.Errorf("failed to create migrate instance: %w", err)
	}
	// Close errors are ignored: the migration outcome is already decided and
	// these pools are private to this call.
	defer m.Close()

	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("failed to run migrations: %w", err)
	}
	return nil
}

// getDSN returns the connection string the provider was created with, or ""
// if the provider was not built by NewPostgresProvider.
func (p *PostgresProvider) getDSN() string {
	return p.dsn
}

// PostgresTransaction implements the repository.TransactionProvider interface
// by wrapping a *sql.Tx.
type PostgresTransaction struct {
	tx *sql.Tx
}

// Commit commits the transaction.
func (t *PostgresTransaction) Commit() error {
	return t.tx.Commit()
}

// Rollback aborts the transaction. After a successful Commit, database/sql
// reports sql.ErrTxDone here.
func (t *PostgresTransaction) Rollback() error {
	return t.tx.Rollback()
}

// GetTx returns the underlying *sql.Tx, typed as interface{} to satisfy the
// repository interface.
func (t *PostgresTransaction) GetTx() interface{} {
	return t.tx
}