149 lines
3.9 KiB
Go
149 lines
3.9 KiB
Go
package database
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"path/filepath"
	"time"

	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database/postgres"
	_ "github.com/golang-migrate/migrate/v4/source/file"
	_ "github.com/lib/pq"

	"github.com/kms/api-key-service/internal/repository"
)
|
|
|
|
// PostgresProvider implements the DatabaseProvider interface
|
|
type PostgresProvider struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
// NewPostgresProvider creates a new PostgreSQL database provider
|
|
func NewPostgresProvider(dsn string, maxOpenConns, maxIdleConns int, maxLifetime string) (repository.DatabaseProvider, error) {
|
|
db, err := sql.Open("postgres", dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open database connection: %w", err)
|
|
}
|
|
|
|
// Set connection pool settings
|
|
db.SetMaxOpenConns(maxOpenConns)
|
|
db.SetMaxIdleConns(maxIdleConns)
|
|
|
|
// Parse and set max lifetime if provided
|
|
if maxLifetime != "" {
|
|
if lifetime, err := time.ParseDuration(maxLifetime); err == nil {
|
|
db.SetConnMaxLifetime(lifetime)
|
|
}
|
|
}
|
|
|
|
// Test the connection
|
|
if err := db.Ping(); err != nil {
|
|
db.Close()
|
|
return nil, fmt.Errorf("failed to ping database: %w", err)
|
|
}
|
|
|
|
return &PostgresProvider{db: db}, nil
|
|
}
|
|
|
|
// GetDB returns the underlying database connection
|
|
func (p *PostgresProvider) GetDB() interface{} {
|
|
return p.db
|
|
}
|
|
|
|
// Ping checks the database connection
|
|
func (p *PostgresProvider) Ping(ctx context.Context) error {
|
|
if p.db == nil {
|
|
return fmt.Errorf("database connection is nil")
|
|
}
|
|
|
|
// Check if database is closed
|
|
if err := p.db.PingContext(ctx); err != nil {
|
|
return fmt.Errorf("database ping failed: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close closes all database connections
|
|
func (p *PostgresProvider) Close() error {
|
|
return p.db.Close()
|
|
}
|
|
|
|
// BeginTx starts a database transaction
|
|
func (p *PostgresProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) {
|
|
tx, err := p.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
|
|
return &PostgresTransaction{tx: tx}, nil
|
|
}
|
|
|
|
// Migrate runs database migrations
|
|
func (p *PostgresProvider) Migrate(ctx context.Context, migrationPath string) error {
|
|
// Create a separate connection for migrations to avoid interfering with the main connection
|
|
migrationDB, err := sql.Open("postgres", p.getDSN())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open migration database connection: %w", err)
|
|
}
|
|
defer migrationDB.Close()
|
|
|
|
driver, err := postgres.WithInstance(migrationDB, &postgres.Config{})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create postgres driver: %w", err)
|
|
}
|
|
|
|
// Convert relative path to file URL
|
|
absPath, err := filepath.Abs(migrationPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get absolute path: %w", err)
|
|
}
|
|
|
|
m, err := migrate.NewWithDatabaseInstance(
|
|
fmt.Sprintf("file://%s", absPath),
|
|
"postgres",
|
|
driver,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create migrate instance: %w", err)
|
|
}
|
|
defer m.Close()
|
|
|
|
// Run migrations
|
|
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
|
|
return fmt.Errorf("failed to run migrations: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getDSN reconstructs the DSN from the current connection
|
|
// This is a workaround since we don't store the original DSN
|
|
func (p *PostgresProvider) getDSN() string {
|
|
// For now, we'll use the default values from config
|
|
// In a production system, we'd store the original DSN
|
|
return "host=localhost port=5432 user=postgres password=postgres dbname=kms sslmode=disable"
|
|
}
|
|
|
|
// PostgresTransaction adapts a *sql.Tx to the TransactionProvider interface
// returned by PostgresProvider.BeginTx.
type PostgresTransaction struct {
	tx *sql.Tx
}

// Commit commits the wrapped transaction and returns the error from
// (*sql.Tx).Commit unchanged.
func (pt *PostgresTransaction) Commit() error {
	underlying := pt.tx
	return underlying.Commit()
}

// Rollback aborts the wrapped transaction and returns the error from
// (*sql.Tx).Rollback unchanged.
func (pt *PostgresTransaction) Rollback() error {
	underlying := pt.tx
	return underlying.Rollback()
}

// GetTx exposes the wrapped *sql.Tx as an interface{}; callers can
// type-assert it back to *sql.Tx.
func (pt *PostgresTransaction) GetTx() interface{} {
	var wrapped interface{} = pt.tx
	return wrapped
}
|