Skip to content

Commit

Permalink
sqldb: extract migration into method
Browse files Browse the repository at this point in the history
  • Loading branch information
bhandras committed Jun 28, 2024
1 parent 7a1fb8f commit cd643cf
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 66 deletions.
74 changes: 72 additions & 2 deletions sqldb/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,71 @@ import (
"net/http"
"strings"

"github.com/btcsuite/btclog"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/source/httpfs"
)

// MigrationTarget is a functional option that can be passed to applyMigrations
// to specify a target version to migrate to.
type MigrationTarget func(mig *migrate.Migrate) error

var (
// TargetLatest is a MigrationTarget that migrates to the latest
// version available.
TargetLatest = func(mig *migrate.Migrate) error {
return mig.Up()
}

// TargetVersion is a MigrationTarget that migrates to the given
// version.
TargetVersion = func(version uint) MigrationTarget {
return func(mig *migrate.Migrate) error {
return mig.Migrate(version)
}
}
)

// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
// used to log migrations.
type migrationLogger struct {
log btclog.Logger
}

// Printf is like fmt.Printf. We map this to the target logger based on the
// current log level.
func (m *migrationLogger) Printf(format string, v ...interface{}) {
// Trim trailing newlines from the format.
format = strings.TrimRight(format, "\n")

switch m.log.Level() {
case btclog.LevelTrace:
m.log.Tracef(format, v...)
case btclog.LevelDebug:
m.log.Debugf(format, v...)
case btclog.LevelInfo:
m.log.Infof(format, v...)
case btclog.LevelWarn:
m.log.Warnf(format, v...)
case btclog.LevelError:
m.log.Errorf(format, v...)
case btclog.LevelCritical:
m.log.Criticalf(format, v...)
case btclog.LevelOff:
}
}

// Verbose should return true when verbose logging output is wanted
func (m *migrationLogger) Verbose() bool {
return m.log.Level() <= btclog.LevelDebug
}

// applyMigrations executes all database migration files found in the given file
// system under the given path, using the passed database driver and database
// name.
func applyMigrations(fs fs.FS, driver database.Driver, path,
dbName string) error {
dbName string, targetVersion MigrationTarget) error {

// With the migrate instance open, we'll create a new migration source
// using the embedded file system stored in sqlSchemas. The library
Expand All @@ -37,7 +92,22 @@ func applyMigrations(fs fs.FS, driver database.Driver, path,
if err != nil {
return err
}
err = sqlMigrate.Up()

migrationVersion, _, err := sqlMigrate.Version()
if err != nil && !errors.Is(err, migrate.ErrNilVersion) {
log.Errorf("Unable to determine current migration version: %v",
err)

return err
}

log.Infof("Applying migrations from version=%v", migrationVersion)

// Apply our local logger to the migration instance.
sqlMigrate.Log = &migrationLogger{log}

// Execute the migration based on the target given.
err = targetVersion(sqlMigrate)
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
return err
}
Expand Down
76 changes: 45 additions & 31 deletions sqldb/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sqldb

import (
"database/sql"
"fmt"
"net/url"
"path"
"strings"
Expand All @@ -19,6 +20,17 @@ var (
// fully executed yet. So this time needs to be chosen correctly to be
// longer than the longest expected individual test run time.
DefaultPostgresFixtureLifetime = 10 * time.Minute

// postgresSchemaReplacements is a map of schema strings that need to be
// replaced for postgres. This is needed because we write the schemas to
// work with sqlite primarily but in sqlc's own dialect, and postgres
// has some differences.
postgresSchemaReplacements = map[string]string{
"BLOB": "BYTEA",
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
}
)

// replacePasswordInDSN takes a DSN string and returns it with the password
Expand Down Expand Up @@ -98,42 +110,44 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
rawDB.SetMaxIdleConns(maxConns)
rawDB.SetConnMaxLifetime(connIdleLifetime)

if !cfg.SkipMigrations {
// Now that the database is open, populate the database with
// our set of schemas based on our embedded in-memory file
// system.
//
// First, we'll need to open up a new migration instance for
// our current target database: Postgres.
driver, err := postgres_migrate.WithInstance(
rawDB, &postgres_migrate.Config{},
)
if err != nil {
return nil, err
}

postgresFS := newReplacerFS(sqlSchemas, map[string]string{
"BLOB": "BYTEA",
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
})

err = applyMigrations(
postgresFS, driver, "sqlc/migrations", dbName,
)
if err != nil {
return nil, err
}
}

queries := sqlc.New(rawDB)

return &PostgresStore{
s := &PostgresStore{
cfg: cfg,
BaseDB: &BaseDB{
DB: rawDB,
Queries: queries,
},
}, nil
}

// Execute migrations unless configured to skip them.
if !cfg.SkipMigrations {
err := s.ExecuteMigrations(dbName, TargetLatest)
if err != nil {
return nil, fmt.Errorf("error executing migrations: %w",
err)
}
}

return s, nil
}

// ExecuteMigrations runs migrations for the Postgres database, depending on the
// target given, either all migrations or up to a given version.
func (s *PostgresStore) ExecuteMigrations(dbName string,
target MigrationTarget) error {

driver, err := postgres_migrate.WithInstance(
s.DB, &postgres_migrate.Config{},
)
if err != nil {
return fmt.Errorf("error creating postgres migration: %w", err)
}

// Populate the database with our set of schemas based on our embedded
// in-memory file system.
postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements)
return applyMigrations(
postgresFS, driver, "sqlc/migrations", dbName, target,
)
}
74 changes: 41 additions & 33 deletions sqldb/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ const (
sqliteTxLockImmediate = "_txlock=immediate"
)

var (
// sqliteSchemaReplacements is a map of schema strings that need to be
// replaced for sqlite. This is needed because sqlite doesn't directly
// support the BIGINT type for primary keys, so we need to replace it
// with INTEGER.
sqliteSchemaReplacements = map[string]string{
"BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY",
}
)

// SqliteStore is a database store implementation that uses a sqlite backend.
type SqliteStore struct {
cfg *SqliteConfig
Expand Down Expand Up @@ -95,46 +105,44 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
db.SetMaxOpenConns(defaultMaxConns)
db.SetMaxIdleConns(defaultMaxConns)
db.SetConnMaxLifetime(connIdleLifetime)

if !cfg.SkipMigrations {
// Now that the database is open, populate the database with
// our set of schemas based on our embedded in-memory file
// system.
//
// First, we'll need to open up a new migration instance for
// our current target database: sqlite.
driver, err := sqlite_migrate.WithInstance(
db, &sqlite_migrate.Config{},
)
if err != nil {
return nil, err
}

// We use INTEGER PRIMARY KEY for sqlite, because it acts as a
// ROWID alias which is 8 bytes big and also autoincrements.
// It's important to use the ROWID as a primary key because the
// key look ups are almost twice as fast
sqliteFS := newReplacerFS(sqlSchemas, map[string]string{
"BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY",
})

err = applyMigrations(
sqliteFS, driver, "sqlc/migrations", "sqlc",
)
if err != nil {
return nil, err
}
}

queries := sqlc.New(db)

return &SqliteStore{
s := &SqliteStore{
cfg: cfg,
BaseDB: &BaseDB{
DB: db,
Queries: queries,
},
}, nil
}

// Execute migrations unless configured to skip them.
if !cfg.SkipMigrations {
if err := s.ExecuteMigrations(TargetLatest); err != nil {
return nil, fmt.Errorf("error executing migrations: "+
"%w", err)

}
}

return s, nil
}

// ExecuteMigrations runs migrations for the sqlite database, depending on the
// target given, either all migrations or up to a given version.
func (s *SqliteStore) ExecuteMigrations(target MigrationTarget) error {
driver, err := sqlite_migrate.WithInstance(
s.DB, &sqlite_migrate.Config{},
)
if err != nil {
return fmt.Errorf("error creating sqlite migration: %w", err)
}

// Populate the database with our set of schemas based on our embedded
// in-memory file system.
sqliteFS := newReplacerFS(sqlSchemas, sqliteSchemaReplacements)
return applyMigrations(
sqliteFS, driver, "sqlc/migrations", "sqlite", target,
)
}

// NewTestSqliteDB is a helper function that creates an SQLite database for
Expand Down

0 comments on commit cd643cf

Please sign in to comment.