-
This commit is contained in:
@ -4,12 +4,8 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"github.com/kms/api-key-service/internal/repository"
|
||||
@ -17,7 +13,8 @@ import (
|
||||
|
||||
// PostgresProvider implements the DatabaseProvider interface
|
||||
type PostgresProvider struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
dsn string
|
||||
}
|
||||
|
||||
// NewPostgresProvider creates a new PostgreSQL database provider
|
||||
@ -44,7 +41,7 @@ func NewPostgresProvider(dsn string, maxOpenConns, maxIdleConns int, maxLifetime
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return &PostgresProvider{db: db}, nil
|
||||
return &PostgresProvider{db: db, dsn: dsn}, nil
|
||||
}
|
||||
|
||||
// GetDB returns the underlying database connection
|
||||
@ -81,51 +78,7 @@ func (p *PostgresProvider) BeginTx(ctx context.Context) (repository.TransactionP
|
||||
return &PostgresTransaction{tx: tx}, nil
|
||||
}
|
||||
|
||||
// Migrate runs database migrations
|
||||
func (p *PostgresProvider) Migrate(ctx context.Context, migrationPath string) error {
|
||||
// Create a separate connection for migrations to avoid interfering with the main connection
|
||||
migrationDB, err := sql.Open("postgres", p.getDSN())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open migration database connection: %w", err)
|
||||
}
|
||||
defer migrationDB.Close()
|
||||
|
||||
driver, err := postgres.WithInstance(migrationDB, &postgres.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create postgres driver: %w", err)
|
||||
}
|
||||
|
||||
// Convert relative path to file URL
|
||||
absPath, err := filepath.Abs(migrationPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
|
||||
m, err := migrate.NewWithDatabaseInstance(
|
||||
fmt.Sprintf("file://%s", absPath),
|
||||
"postgres",
|
||||
driver,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrate instance: %w", err)
|
||||
}
|
||||
defer m.Close()
|
||||
|
||||
// Run migrations
|
||||
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
|
||||
return fmt.Errorf("failed to run migrations: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDSN reconstructs the DSN from the current connection
|
||||
// This is a workaround since we don't store the original DSN
|
||||
func (p *PostgresProvider) getDSN() string {
|
||||
// For now, we'll use the default values from config
|
||||
// In a production system, we'd store the original DSN
|
||||
return "host=localhost port=5432 user=postgres password=postgres dbname=kms sslmode=disable"
|
||||
}
|
||||
|
||||
// PostgresTransaction implements the TransactionProvider interface
|
||||
type PostgresTransaction struct {
|
||||
|
||||
Reference in New Issue
Block a user