291 lines
6.6 KiB
Go
291 lines
6.6 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/kms/api-key-service/internal/domain"
|
|
"github.com/kms/api-key-service/internal/repository"
|
|
)
|
|
|
|
// StaticTokenRepository implements the StaticTokenRepository interface for PostgreSQL
|
|
type StaticTokenRepository struct {
|
|
db repository.DatabaseProvider
|
|
}
|
|
|
|
// NewStaticTokenRepository creates a new PostgreSQL static token repository
|
|
func NewStaticTokenRepository(db repository.DatabaseProvider) repository.StaticTokenRepository {
|
|
return &StaticTokenRepository{db: db}
|
|
}
|
|
|
|
// Create creates a new static token
|
|
func (r *StaticTokenRepository) Create(ctx context.Context, token *domain.StaticToken) error {
|
|
query := `
|
|
INSERT INTO static_tokens (
|
|
id, app_id, owner_type, owner_name, owner_owner,
|
|
key_hash, type, created_at, updated_at
|
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
|
`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
now := time.Now()
|
|
|
|
_, err := db.ExecContext(ctx, query,
|
|
token.ID,
|
|
token.AppID,
|
|
string(token.Owner.Type),
|
|
token.Owner.Name,
|
|
token.Owner.Owner,
|
|
token.KeyHash,
|
|
string(token.Type),
|
|
now,
|
|
now,
|
|
)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create static token: %w", err)
|
|
}
|
|
|
|
token.CreatedAt = now
|
|
token.UpdatedAt = now
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetByID retrieves a static token by its ID
|
|
func (r *StaticTokenRepository) GetByID(ctx context.Context, tokenID uuid.UUID) (*domain.StaticToken, error) {
|
|
query := `
|
|
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
|
key_hash, type, created_at, updated_at
|
|
FROM static_tokens
|
|
WHERE id = $1
|
|
`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
row := db.QueryRowContext(ctx, query, tokenID)
|
|
|
|
token := &domain.StaticToken{}
|
|
var ownerType, ownerName, ownerOwner string
|
|
|
|
err := row.Scan(
|
|
&token.ID,
|
|
&token.AppID,
|
|
&ownerType,
|
|
&ownerName,
|
|
&ownerOwner,
|
|
&token.KeyHash,
|
|
&token.Type,
|
|
&token.CreatedAt,
|
|
&token.UpdatedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("static token with ID '%s' not found", tokenID)
|
|
}
|
|
return nil, fmt.Errorf("failed to get static token: %w", err)
|
|
}
|
|
|
|
token.Owner = domain.Owner{
|
|
Type: domain.OwnerType(ownerType),
|
|
Name: ownerName,
|
|
Owner: ownerOwner,
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
// GetByKeyHash retrieves a static token by its key hash
|
|
func (r *StaticTokenRepository) GetByKeyHash(ctx context.Context, keyHash string) (*domain.StaticToken, error) {
|
|
query := `
|
|
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
|
key_hash, type, created_at, updated_at
|
|
FROM static_tokens
|
|
WHERE key_hash = $1
|
|
`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
row := db.QueryRowContext(ctx, query, keyHash)
|
|
|
|
token := &domain.StaticToken{}
|
|
var ownerType, ownerName, ownerOwner string
|
|
|
|
err := row.Scan(
|
|
&token.ID,
|
|
&token.AppID,
|
|
&ownerType,
|
|
&ownerName,
|
|
&ownerOwner,
|
|
&token.KeyHash,
|
|
&token.Type,
|
|
&token.CreatedAt,
|
|
&token.UpdatedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("static token with hash not found")
|
|
}
|
|
return nil, fmt.Errorf("failed to get static token by hash: %w", err)
|
|
}
|
|
|
|
token.Owner = domain.Owner{
|
|
Type: domain.OwnerType(ownerType),
|
|
Name: ownerName,
|
|
Owner: ownerOwner,
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
// GetByAppID retrieves all static tokens for an application
|
|
func (r *StaticTokenRepository) GetByAppID(ctx context.Context, appID string) ([]*domain.StaticToken, error) {
|
|
query := `
|
|
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
|
key_hash, type, created_at, updated_at
|
|
FROM static_tokens
|
|
WHERE app_id = $1
|
|
ORDER BY created_at DESC
|
|
`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
rows, err := db.QueryContext(ctx, query, appID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query static tokens: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tokens []*domain.StaticToken
|
|
for rows.Next() {
|
|
token := &domain.StaticToken{}
|
|
var ownerType, ownerName, ownerOwner string
|
|
|
|
err := rows.Scan(
|
|
&token.ID,
|
|
&token.AppID,
|
|
&ownerType,
|
|
&ownerName,
|
|
&ownerOwner,
|
|
&token.KeyHash,
|
|
&token.Type,
|
|
&token.CreatedAt,
|
|
&token.UpdatedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan static token: %w", err)
|
|
}
|
|
|
|
token.Owner = domain.Owner{
|
|
Type: domain.OwnerType(ownerType),
|
|
Name: ownerName,
|
|
Owner: ownerOwner,
|
|
}
|
|
|
|
tokens = append(tokens, token)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating static tokens: %w", err)
|
|
}
|
|
|
|
return tokens, nil
|
|
}
|
|
|
|
// List retrieves static tokens with pagination
|
|
func (r *StaticTokenRepository) List(ctx context.Context, limit, offset int) ([]*domain.StaticToken, error) {
|
|
query := `
|
|
SELECT id, app_id, owner_type, owner_name, owner_owner,
|
|
key_hash, type, created_at, updated_at
|
|
FROM static_tokens
|
|
ORDER BY created_at DESC
|
|
LIMIT $1 OFFSET $2
|
|
`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
rows, err := db.QueryContext(ctx, query, limit, offset)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query static tokens: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tokens []*domain.StaticToken
|
|
for rows.Next() {
|
|
token := &domain.StaticToken{}
|
|
var ownerType, ownerName, ownerOwner string
|
|
|
|
err := rows.Scan(
|
|
&token.ID,
|
|
&token.AppID,
|
|
&ownerType,
|
|
&ownerName,
|
|
&ownerOwner,
|
|
&token.KeyHash,
|
|
&token.Type,
|
|
&token.CreatedAt,
|
|
&token.UpdatedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan static token: %w", err)
|
|
}
|
|
|
|
token.Owner = domain.Owner{
|
|
Type: domain.OwnerType(ownerType),
|
|
Name: ownerName,
|
|
Owner: ownerOwner,
|
|
}
|
|
|
|
tokens = append(tokens, token)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating static tokens: %w", err)
|
|
}
|
|
|
|
return tokens, nil
|
|
}
|
|
|
|
// Delete deletes a static token
|
|
func (r *StaticTokenRepository) Delete(ctx context.Context, tokenID uuid.UUID) error {
|
|
query := `DELETE FROM static_tokens WHERE id = $1`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
result, err := db.ExecContext(ctx, query, tokenID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete static token: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return fmt.Errorf("static token with ID '%s' not found", tokenID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Exists checks if a static token exists
|
|
func (r *StaticTokenRepository) Exists(ctx context.Context, tokenID uuid.UUID) (bool, error) {
|
|
query := `SELECT 1 FROM static_tokens WHERE id = $1`
|
|
|
|
db := r.db.GetDB().(*sql.DB)
|
|
var exists int
|
|
err := db.QueryRowContext(ctx, query, tokenID).Scan(&exists)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return false, nil
|
|
}
|
|
return false, fmt.Errorf("failed to check static token existence: %w", err)
|
|
}
|
|
|
|
return true, nil
|
|
}
|