Files
skybridge/kms/internal/database/postgres.go
2025-08-26 19:29:41 -04:00

102 lines
2.3 KiB
Go

package database
import (
"context"
"database/sql"
"fmt"
"time"
_ "github.com/lib/pq"
"github.com/RyanCopley/skybridge/kms/internal/repository"
)
// PostgresProvider is the PostgreSQL-backed implementation of the
// repository.DatabaseProvider interface. Build one with NewPostgresProvider;
// a zero-value PostgresProvider has no connection pool attached.
type PostgresProvider struct {
	db  *sql.DB // connection pool backing Ping, BeginTx, Close and GetDB
	dsn string  // data source name the pool was opened with
}
// NewPostgresProvider opens a PostgreSQL connection pool for dsn, applies the
// pool limits, and verifies connectivity with a ping before returning.
//
// maxOpenConns and maxIdleConns are passed directly to sql.DB.SetMaxOpenConns
// and sql.DB.SetMaxIdleConns. maxLifetime is optional: when non-empty it must
// be a time.ParseDuration string (for example "30m") and is applied with
// sql.DB.SetConnMaxLifetime. When it is empty, the database/sql default is kept.
// An unparseable maxLifetime is reported as an error instead of being silently
// ignored.
//
// On any error a nil provider is returned and no pool is left open.
func NewPostgresProvider(dsn string, maxOpenConns, maxIdleConns int, maxLifetime string) (repository.DatabaseProvider, error) {
	// Validate configuration before opening anything so a bad value can
	// never leak an opened pool.
	var lifetime time.Duration
	if maxLifetime != "" {
		d, err := time.ParseDuration(maxLifetime)
		if err != nil {
			return nil, fmt.Errorf("invalid max connection lifetime %q: %w", maxLifetime, err)
		}
		lifetime = d
	}

	db, err := sql.Open("postgres", dsn)
	if err != nil {
		return nil, fmt.Errorf("failed to open database connection: %w", err)
	}

	// Set connection pool settings
	db.SetMaxOpenConns(maxOpenConns)
	db.SetMaxIdleConns(maxIdleConns)
	if maxLifetime != "" {
		db.SetConnMaxLifetime(lifetime)
	}

	// sql.Open does not necessarily connect; Ping forces a real connection so
	// connectivity problems surface here rather than on first use.
	if err := db.Ping(); err != nil {
		// The ping error is the one worth reporting; a secondary Close
		// failure on an unusable pool is deliberately discarded.
		_ = db.Close()
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}
	return &PostgresProvider{db: db, dsn: dsn}, nil
}
// GetDB hands back the provider's *sql.DB boxed in an empty interface, so
// code holding only a repository.DatabaseProvider can type-assert it back to
// *sql.DB.
func (p *PostgresProvider) GetDB() interface{} {
	var handle interface{} = p.db
	return handle
}
// Ping confirms the database is reachable, honoring ctx for cancellation and
// deadlines. It fails if the provider has no connection pool, or if
// (*sql.DB).PingContext reports an error. The latter includes the case where
// the pool has already been closed.
func (p *PostgresProvider) Ping(ctx context.Context) error {
	pool := p.db
	if pool == nil {
		return fmt.Errorf("database connection is nil")
	}

	pingErr := pool.PingContext(ctx)
	if pingErr == nil {
		return nil
	}
	return fmt.Errorf("database ping failed: %w", pingErr)
}
// Close closes the underlying connection pool, releasing its connections.
//
// Like Ping, it tolerates a provider whose pool is nil (for example a
// zero-value PostgresProvider). In that case there is nothing to release,
// so it returns nil instead of panicking on a nil *sql.DB.
func (p *PostgresProvider) Close() error {
	if p.db == nil {
		return nil
	}
	return p.db.Close()
}
// BeginTx starts a transaction with default options (nil *sql.TxOptions,
// i.e. the driver's default isolation level, read-write). The transaction is
// tied to ctx: database/sql rolls it back if ctx is canceled before it is
// committed.
//
// It returns the same "database connection is nil" error as Ping when the
// provider has no pool, rather than panicking on a nil *sql.DB.
func (p *PostgresProvider) BeginTx(ctx context.Context) (repository.TransactionProvider, error) {
	if p.db == nil {
		return nil, fmt.Errorf("database connection is nil")
	}
	tx, err := p.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to begin transaction: %w", err)
	}
	return &PostgresTransaction{tx: tx}, nil
}
// PostgresTransaction is the repository.TransactionProvider handed out by
// PostgresProvider.BeginTx; it is a thin wrapper around one *sql.Tx.
type PostgresTransaction struct {
	tx *sql.Tx // wrapped transaction, finished via Commit or Rollback
}
// Commit finalizes the wrapped transaction and returns whatever error
// (*sql.Tx).Commit reports, unchanged.
func (t *PostgresTransaction) Commit() error {
	err := t.tx.Commit()
	return err
}
// Rollback aborts the wrapped transaction and returns whatever error
// (*sql.Tx).Rollback reports, unchanged.
func (t *PostgresTransaction) Rollback() error {
	err := t.tx.Rollback()
	return err
}
// GetTx exposes the wrapped *sql.Tx boxed in an empty interface, so holders
// of a repository.TransactionProvider can type-assert it back to *sql.Tx.
func (t *PostgresTransaction) GetTx() interface{} {
	var handle interface{} = t.tx
	return handle
}