diff --git a/create.go b/create.go index 1a8bb74c0..fa10993e4 100644 --- a/create.go +++ b/create.go @@ -10,25 +10,38 @@ import ( ) type tmplVars struct { - Version string - CamelName string + Version string + CamelName string + PackageName string + ProviderVar string } -var ( - sequential = false -) - // SetSequential set whether to use sequential versioning instead of timestamp based versioning func SetSequential(s bool) { - sequential = s + defaultProvider.SetSequential(s) } -// Create writes a new blank migration file. +// SetSequential set's whether to use sequential versioning instead of timestamp based versioning +func (p *Provider) SetSequential(s bool) { p.sequential = s } + +// CreateWithTemplate writes a new blank migration file. func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, migrationType string) error { - var version string - if sequential { + return defaultProvider.CreateWithTemplate(db, dir, tmpl, name, migrationType) +} + +// CreateWithTemplate writes a new blank migration file. +func (p *Provider) CreateWithTemplate(_ *sql.DB, dir string, tmpl *template.Template, name, migrationType string) error { + timefn := p.timeFn + if p.timeFn == nil { + timefn = time.Now + } + version := timefn().Format(p.timestampFormat) + if p.baseDir != "" && (dir == "" || dir == ".") { + dir = p.baseDir + } + if p.sequential { // always use DirFS here because it's modifying operation - migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion) + migrations, err := p.collectMigrationsFS(osFS{}, dir, minVersion, maxVersion) if err != nil { return err } @@ -43,8 +56,6 @@ func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, m } else { version = fmt.Sprintf(seqVersionTemplate, int64(1)) } - } else { - version = time.Now().Format(timestampFormat) } filename := fmt.Sprintf("%v_%v.%v", version, snakeCase(name), migrationType) @@ -69,20 +80,27 @@ func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, m defer f.Close() vars := tmplVars{ - Version: version, - CamelName: camelCase(name), + PackageName: p.packageName, + ProviderVar: p.providerVarName, + Version: version, + CamelName: camelCase(name), } if err := tmpl.Execute(f, vars); err != nil { return fmt.Errorf("failed to execute tmpl: %w", err) } - log.Printf("Created new file: %s\n", f.Name()) + p.log.Printf("Created new file: %s\n", f.Name()) return nil } // Create writes a new blank migration file. func Create(db *sql.DB, dir, name, migrationType string) error { - return CreateWithTemplate(db, dir, nil, name, migrationType) + return defaultProvider.Create(db, dir, name, migrationType) +} + +// Create writes a new blank migration file. +func (p *Provider) Create(db *sql.DB, dir, name, migrationType string) error { + return p.CreateWithTemplate(db, dir, nil, name, migrationType) } var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(`-- +goose Up @@ -96,15 +114,19 @@ SELECT 'down SQL query'; -- +goose StatementEnd `)) -var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package migrations +var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package {{.PackageName}} import ( "database/sql" - "github.com/pressly/goose/v3" +{{if eq .ProviderVar ""}} "github.com/pressly/goose/v3" {{end}} ) func init() { +{{- if eq .ProviderVar "" }} goose.AddMigration(up{{.CamelName}}, down{{.CamelName}}) +{{- else }} + {{.ProviderVar}}.AddMigration(up{{.CamelName}}, down{{.CamelName}}) +{{end }} } func up{{.CamelName}}(tx *sql.Tx) error { diff --git a/db.go b/db.go index de9dd5e23..3670c0dec 100644 --- a/db.go +++ b/db.go @@ -8,7 +8,13 @@ import ( // OpenDBWithDriver creates a connection to a database, and modifies goose // internals to be compatible with the supplied driver by calling SetDialect. func OpenDBWithDriver(driver string, dbstring string) (*sql.DB, error) { - if err := SetDialect(driver); err != nil { + return defaultProvider.OpenDBWithDriver(driver, dbstring) +} + +// OpenDBWithDriver creates a connection to a database, and modifies goose +// internals to be compatible with the supplied driver by calling SetDialect. +func (p *Provider) OpenDBWithDriver(driver string, dbstring string) (*sql.DB, error) { + if err := p.SetDialect(driver); err != nil { return nil, err } diff --git a/dialect.go b/dialect.go index b9b6245e5..37ae11b53 100644 --- a/dialect.go +++ b/dialect.go @@ -5,9 +5,20 @@ import ( "fmt" ) +const ( + DialectPostgres = "postgres" + DialectSQLite3 = "sqlite3" + DialectMySQL = "mysql" + DialectMSSQL = "mssql" + DialectRedShit = "redshift" + DialectTiDB = "tidb" + DialectClickHouse = "clickhouse" +) + // SQLDialect abstracts the details of specific SQL dialects // for goose's few SQL specific statements type SQLDialect interface { + SetTableName(name string) // set table name to use for SQL generation createVersionTableSQL() string // sql string to create the db version table insertVersionSQL() string // sql string to insert the initial version table row deleteVersionSQL() string // sql string to delete version @@ -15,60 +26,85 @@ type SQLDialect interface { dbVersionQuery(db *sql.DB) (*sql.Rows, error) } -var dialect SQLDialect = &PostgresDialect{} - // GetDialect gets the SQLDialect func GetDialect() SQLDialect { - return dialect + return defaultProvider.dialect } -// SetDialect sets the SQLDialect -func SetDialect(d string) error { +func SelectDialect(tableName, d string) (SQLDialect, error) { + + base := BaseDialect{TableName: tableName} + switch d { - case "postgres", "pgx": - dialect = &PostgresDialect{} - case "mysql": - dialect = &MySQLDialect{} - case "sqlite3", "sqlite": - dialect = &Sqlite3Dialect{} - case "mssql": - dialect = &SqlServerDialect{} - case "redshift": - dialect = &RedshiftDialect{} - case "tidb": - dialect = &TiDBDialect{} - case "clickhouse": - dialect = &ClickHouseDialect{} + case DialectPostgres, "pgx": + return &PostgresDialect{base}, nil + case DialectMySQL: + return &MySQLDialect{base}, nil + case DialectSQLite3, "sqlite": + return &Sqlite3Dialect{base}, nil + case DialectMSSQL: + return &SqlServerDialect{base}, nil + case DialectRedShit: + return &RedshiftDialect{base}, nil + case DialectTiDB: + return &TiDBDialect{base}, nil + case DialectClickHouse: + return &ClickHouseDialect{base}, nil default: - return fmt.Errorf("%q: unknown dialect", d) + return nil, fmt.Errorf("%q: unknown dialect", d) } +} + +// Dialect returns the SQLDialect of the provider +func (p *Provider) Dialect() SQLDialect { return p.dialect } + +// SetDialect sets the SQLDialect +func SetDialect(d string) error { + return defaultProvider.SetDialect(d) +} +// SetDialect sets the SQLDialect +func (p *Provider) SetDialect(d string) error { + dialect, err := SelectDialect(p.tableName, d) + if err != nil { + return err + } + p.dialect = dialect return nil } +// BaseDialect struct. +type BaseDialect struct { + TableName string +} + +func (bd *BaseDialect) SetTableName(name string) { + bd.TableName = name +} + //////////////////////////// // Postgres //////////////////////////// // PostgresDialect struct. -type PostgresDialect struct{} +type PostgresDialect struct{ BaseDialect } -func (pg PostgresDialect) createVersionTableSQL() string { +func (d PostgresDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id serial NOT NULL, version_id bigint NOT NULL, is_applied boolean NOT NULL, tstamp timestamp NULL default now(), PRIMARY KEY(id) - );`, TableName()) + );`, d.TableName) } -func (pg PostgresDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName()) +func (d PostgresDialect) insertVersionSQL() string { + return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", d.TableName) } -func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) +func (d PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", d.TableName)) if err != nil { return nil, err } @@ -76,12 +112,12 @@ func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } -func (m PostgresDialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName()) +func (d PostgresDialect) migrationSQL() string { + return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", d.TableName) } -func (pg PostgresDialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName()) +func (d PostgresDialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", d.TableName) } //////////////////////////// @@ -89,24 +125,24 @@ func (pg PostgresDialect) deleteVersionSQL() string { //////////////////////////// // MySQLDialect struct. -type MySQLDialect struct{} +type MySQLDialect struct{ BaseDialect } -func (m MySQLDialect) createVersionTableSQL() string { +func (d MySQLDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id serial NOT NULL, version_id bigint NOT NULL, is_applied boolean NOT NULL, tstamp timestamp NULL default now(), PRIMARY KEY(id) - );`, TableName()) + );`, d.TableName) } -func (m MySQLDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) +func (d MySQLDialect) insertVersionSQL() string { + return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", d.TableName) } -func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) +func (d MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", d.TableName)) if err != nil { return nil, err } @@ -114,12 +150,12 @@ func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } -func (m MySQLDialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) +func (d MySQLDialect) migrationSQL() string { + return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", d.TableName) } -func (m MySQLDialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) +func (d MySQLDialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", d.TableName) } //////////////////////////// @@ -127,23 +163,23 @@ func (m MySQLDialect) deleteVersionSQL() string { //////////////////////////// // SqlServerDialect struct. -type SqlServerDialect struct{} +type SqlServerDialect struct{ BaseDialect } -func (m SqlServerDialect) createVersionTableSQL() string { +func (d SqlServerDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, version_id BIGINT NOT NULL, is_applied BIT NOT NULL, tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP - );`, TableName()) + );`, d.TableName) } -func (m SqlServerDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2);", TableName()) +func (d SqlServerDialect) insertVersionSQL() string { + return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2);", d.TableName) } -func (m SqlServerDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY id DESC", TableName())) +func (d SqlServerDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY id DESC", d.TableName)) if err != nil { return nil, err } @@ -151,7 +187,7 @@ func (m SqlServerDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } -func (m SqlServerDialect) migrationSQL() string { +func (d SqlServerDialect) migrationSQL() string { const tpl = ` WITH Migrations AS ( @@ -165,11 +201,11 @@ FROM Migrations WHERE RowNumber BETWEEN 1 AND 2 ORDER BY tstamp DESC ` - return fmt.Sprintf(tpl, TableName()) + return fmt.Sprintf(tpl, d.TableName) } -func (m SqlServerDialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=@p1;", TableName()) +func (d SqlServerDialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=@p1;", d.TableName) } //////////////////////////// @@ -177,23 +213,23 @@ func (m SqlServerDialect) deleteVersionSQL() string { //////////////////////////// // Sqlite3Dialect struct. -type Sqlite3Dialect struct{} +type Sqlite3Dialect struct{ BaseDialect } -func (m Sqlite3Dialect) createVersionTableSQL() string { +func (d Sqlite3Dialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id INTEGER PRIMARY KEY AUTOINCREMENT, version_id INTEGER NOT NULL, is_applied INTEGER NOT NULL, tstamp TIMESTAMP DEFAULT (datetime('now')) - );`, TableName()) + );`, d.TableName) } -func (m Sqlite3Dialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) +func (d Sqlite3Dialect) insertVersionSQL() string { + return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", d.TableName) } -func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) +func (d Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", d.TableName)) if err != nil { return nil, err } @@ -201,12 +237,12 @@ func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } -func (m Sqlite3Dialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) +func (d Sqlite3Dialect) migrationSQL() string { + return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", d.TableName) } -func (m Sqlite3Dialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) +func (d Sqlite3Dialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", d.TableName) } //////////////////////////// @@ -214,24 +250,24 @@ func (m Sqlite3Dialect) deleteVersionSQL() string { //////////////////////////// // RedshiftDialect struct. -type RedshiftDialect struct{} +type RedshiftDialect struct{ BaseDialect } -func (rs RedshiftDialect) createVersionTableSQL() string { +func (d RedshiftDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id integer NOT NULL identity(1, 1), version_id bigint NOT NULL, is_applied boolean NOT NULL, tstamp timestamp NULL default sysdate, PRIMARY KEY(id) - );`, TableName()) + );`, d.TableName) } -func (rs RedshiftDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName()) +func (d RedshiftDialect) insertVersionSQL() string { + return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", d.TableName) } -func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) +func (d RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", d.TableName)) if err != nil { return nil, err } @@ -239,12 +275,12 @@ func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } -func (m RedshiftDialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName()) +func (d RedshiftDialect) migrationSQL() string { + return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", d.TableName) } -func (rs RedshiftDialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName()) +func (d RedshiftDialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", d.TableName) } //////////////////////////// @@ -252,24 +288,24 @@ func (rs RedshiftDialect) deleteVersionSQL() string { //////////////////////////// // TiDBDialect struct. -type TiDBDialect struct{} +type TiDBDialect struct{ BaseDialect } -func (m TiDBDialect) createVersionTableSQL() string { +func (d TiDBDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE, version_id bigint NOT NULL, is_applied boolean NOT NULL, tstamp timestamp NULL default now(), PRIMARY KEY(id) - );`, TableName()) + );`, d.TableName) } -func (m TiDBDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) +func (d TiDBDialect) insertVersionSQL() string { + return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", d.TableName) } -func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) +func (d TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", d.TableName)) if err != nil { return nil, err } @@ -277,12 +313,12 @@ func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } -func (m TiDBDialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) +func (d TiDBDialect) migrationSQL() string { + return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", d.TableName) } -func (m TiDBDialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) +func (d TiDBDialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", d.TableName) } //////////////////////////// @@ -290,33 +326,33 @@ func (m TiDBDialect) deleteVersionSQL() string { //////////////////////////// // ClickHouseDialect struct. -type ClickHouseDialect struct{} +type ClickHouseDialect struct{ BaseDialect } -func (m ClickHouseDialect) createVersionTableSQL() string { +func (d ClickHouseDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( version_id Int64, is_applied UInt8, date Date default now(), tstamp DateTime default now() - ) Engine = MergeTree(date, (date), 8192)`, TableName()) + ) Engine = MergeTree(date, (date), 8192)`, d.TableName) } -func (m ClickHouseDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY version_id DESC", TableName())) +func (d ClickHouseDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY tstamp DESC", d.TableName)) if err != nil { return nil, err } return rows, err } -func (m ClickHouseDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)", TableName()) +func (d ClickHouseDialect) insertVersionSQL() string { + return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)", d.TableName) } -func (m ClickHouseDialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1", TableName()) +func (d ClickHouseDialect) migrationSQL() string { + return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1", d.TableName) } -func (m ClickHouseDialect) deleteVersionSQL() string { - return fmt.Sprintf("ALTER TABLE %s DELETE WHERE version_id = $1", TableName()) +func (d ClickHouseDialect) deleteVersionSQL() string { + return fmt.Sprintf("ALTER TABLE %s DELETE WHERE version_id = $1", d.TableName) } diff --git a/down.go b/down.go index c58c2144c..7fac3bff4 100644 --- a/down.go +++ b/down.go @@ -7,11 +7,13 @@ import ( // Down rolls back a single migration from the current version. func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { - option := &options{} - for _, f := range opts { - f(option) - } - migrations, err := CollectMigrations(dir, minVersion, maxVersion) + return defaultProvider.Down(db, dir, opts...) +} + +// Down rolls back a single migration from the current version. +func (p *Provider) Down(db *sql.DB, dir string, opts ...OptionsFunc) error { + option := applyOptions(opts) + migrations, err := p.CollectMigrations(dir, minVersion, maxVersion) if err != nil { return err } @@ -21,9 +23,9 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { } currentVersion := migrations[len(migrations)-1].Version // Migrate only the latest migration down. - return downToNoVersioning(db, migrations, currentVersion-1) + return downToNoVersioning(p, db, migrations, currentVersion-1) } - currentVersion, err := GetDBVersion(db) + currentVersion, err := p.GetDBVersion(db) if err != nil { return err } @@ -31,45 +33,47 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { if err != nil { return fmt.Errorf("no migration %v", currentVersion) } - return current.Down(db) + return current.DownWithProvider(p, db) } // DownTo rolls back migrations to a specific version. func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { - option := &options{} - for _, f := range opts { - f(option) - } - migrations, err := CollectMigrations(dir, minVersion, maxVersion) + return defaultProvider.DownTo(db, dir, version, opts...) +} + +// DownTo rolls back migrations to a specific version. +func (p *Provider) DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { + option := applyOptions(opts) + migrations, err := p.CollectMigrations(dir, minVersion, maxVersion) if err != nil { return err } if option.noVersioning { - return downToNoVersioning(db, migrations, version) + return downToNoVersioning(p, db, migrations, version) } for { - currentVersion, err := GetDBVersion(db) + currentVersion, err := p.GetDBVersion(db) if err != nil { return err } if currentVersion == 0 { - log.Printf("goose: no migrations to run. current version: %d\n", currentVersion) + p.log.Printf("goose: no migrations to run. current version: %d\n", currentVersion) return nil } current, err := migrations.Current(currentVersion) if err != nil { - log.Printf("goose: migration file not found for current version (%d), error: %s\n", currentVersion, err) + p.log.Printf("goose: migration file not found for current version (%d), error: %s\n", currentVersion, err) return err } if current.Version <= version { - log.Printf("goose: no migrations to run. current version: %d\n", currentVersion) + p.log.Printf("goose: no migrations to run. current version: %d\n", currentVersion) return nil } - if err = current.Down(db); err != nil { + if err = current.DownWithProvider(p, db); err != nil { return err } } @@ -77,7 +81,10 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { // downToNoVersioning applies down migrations down to, but not including, the // target version. -func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { +func downToNoVersioning(p *Provider, db *sql.DB, migrations Migrations, version int64) error { + if p == nil { + p = defaultProvider + } var finalVersion int64 for i := len(migrations) - 1; i >= 0; i-- { if version >= migrations[i].Version { @@ -85,10 +92,10 @@ func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error break } migrations[i].noVersioning = true - if err := migrations[i].Down(db); err != nil { + if err := migrations[i].DownWithProvider(p, db); err != nil { return err } } - log.Printf("goose: down to current file version: %d\n", finalVersion) + p.log.Printf("goose: down to current file version: %d\n", finalVersion) return nil } diff --git a/examples/multi-migrations/testdata/migrations/.keep b/examples/multi-migrations/testdata/migrations/.keep new file mode 100644 index 000000000..e69de29bb diff --git a/fix.go b/fix.go index 7bc7ed5d6..6112252d0 100644 --- a/fix.go +++ b/fix.go @@ -9,9 +9,14 @@ import ( const seqVersionTemplate = "%05v" -func Fix(dir string) error { +func Fix(dir string) error { return defaultProvider.Fix(dir) } + +func (p *Provider) Fix(dir string) error { + if p.baseDir != "" && (dir == "" || dir == ".") { + dir = p.baseDir + } // always use osFS here because it's modifying operation - migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion) + migrations, err := p.collectMigrationsFS(osFS{}, dir, minVersion, maxVersion) if err != nil { return err } @@ -32,13 +37,17 @@ func Fix(dir string) error { version = last.Version + 1 } + seqVerTemplate := p.seqVersionTemplate + if seqVerTemplate == "" { + seqVerTemplate = seqVersionTemplate + } // fix filenames by replacing timestamps with sequential versions for _, tsm := range tsMigrations { oldPath := tsm.Source newPath := strings.Replace( oldPath, fmt.Sprintf("%d", tsm.Version), - fmt.Sprintf(seqVersionTemplate, version), + fmt.Sprintf(seqVerTemplate, version), 1, ) @@ -46,7 +55,7 @@ func Fix(dir string) error { return err } - log.Printf("RENAMED %s => %s", filepath.Base(oldPath), filepath.Base(newPath)) + p.log.Printf("RENAMED %s => %s", filepath.Base(oldPath), filepath.Base(newPath)) version++ } diff --git a/goose.go b/goose.go index bcf1df48f..5c1469571 100644 --- a/goose.go +++ b/goose.go @@ -14,26 +14,33 @@ var ( minVersion = int64(0) maxVersion = int64((1 << 63) - 1) timestampFormat = "20060102150405" - verbose = false - - // base fs to lookup migrations - baseFS fs.FS = osFS{} ) // SetVerbose set the goose verbosity mode func SetVerbose(v bool) { - verbose = v + defaultProvider.SetVerbose(v) } +// SetVerbose set the goose verbosity mode +func (p *Provider) SetVerbose(v bool) { p.verbose = v } + // SetBaseFS sets a base FS to discover migrations. It can be used with 'embed' package. // Calling with 'nil' argument leads to default behaviour: discovering migrations from os filesystem. // Note that modifying operations like Create will use os filesystem anyway. func SetBaseFS(fsys fs.FS) { + defaultProvider.SetBaseFS(fsys) +} + +// SetBaseFS sets a base FS to discover migrations. It can be used with the `embed` package. +// Calling with `nil` argument leads to the default behavior: discovering migrations from the os +// filesystem. +// +// Note: that modifying operations like Create will use os filesystem anyway. +func (p *Provider) SetBaseFS(fsys fs.FS) { if fsys == nil { fsys = osFS{} } - - baseFS = fsys + p.baseFS = fsys } // Run runs a goose command. @@ -41,7 +48,7 @@ func Run(command string, db *sql.DB, dir string, args ...string) error { return run(command, db, dir, args) } -// Run runs a goose command with options. +// RunWithOptions runs a goose command with options. func RunWithOptions(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error { return run(command, db, dir, args, options...) } diff --git a/log.go b/log.go index 7f531a270..da8309088 100644 --- a/log.go +++ b/log.go @@ -17,9 +17,12 @@ type Logger interface { // SetLogger sets the logger for package output func SetLogger(l Logger) { - log = l + defaultProvider.SetLogger(l) } +// SetLogger sets the logger for package output +func (p *Provider) SetLogger(l Logger) { p.log = l } + // stdLogger is a default logger that outputs to a stdlib's log.std logger. type stdLogger struct{} diff --git a/migrate.go b/migrate.go index 91da07dcc..6a6ff7102 100644 --- a/migrate.go +++ b/migrate.go @@ -18,8 +18,6 @@ var ( ErrNoNextVersion = errors.New("no next version found") // MaxVersion is the maximum allowed version. MaxVersion int64 = 9223372036854775807 // max(int64) - - registeredGoMigrations = map[int64]*Migration{} ) // Migrations slice. @@ -83,7 +81,7 @@ func (ms Migrations) versioned() (Migrations, error) { // assume that the user will never have more than 19700101000000 migrations for _, m := range ms { - // parse version as timestmap + // parse version as timestamp versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version)) if versionTime.Before(time.Unix(0, 0)) || err != nil { @@ -100,7 +98,7 @@ func (ms Migrations) timestamped() (Migrations, error) { // assume that the user will never have more than 19700101000000 migrations for _, m := range ms { - // parse version as timestmap + // parse version as timestamp versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version)) if err != nil { // probably not a timestamp @@ -128,19 +126,30 @@ func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { AddNamedMigration(filename, up, down) } +func (p *Provider) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { + _, filename, _, _ := runtime.Caller(1) + p.AddNamedMigration(filename, up, down) +} + // AddNamedMigration : Add a named migration. func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) { + defaultProvider.AddNamedMigration(filename, up, down) + return +} + +// AddNamedMigration : Add a named migration. +func (p *Provider) AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) { v, _ := NumericComponent(filename) migration := &Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename} - if existing, ok := registeredGoMigrations[v]; ok { + if existing, ok := p.registeredGoMigrations[v]; ok { panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source)) } - registeredGoMigrations[v] = migration + p.registeredGoMigrations[v] = migration } -func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Migrations, error) { +func (p *Provider) collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Migrations, error) { if _, err := fs.Stat(fsys, dirpath); errors.Is(err, fs.ErrNotExist) { return nil, fmt.Errorf("%s directory does not exist", dirpath) } @@ -164,7 +173,8 @@ func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Mig } // Go migrations registered via goose.AddMigration(). - for _, migration := range registeredGoMigrations { + for _, migration := range p.registeredGoMigrations { + p.verboseInfo("registered go Migration: ", migration.Source) v, err := NumericComponent(migration.Source) if err != nil { return nil, fmt.Errorf("could not parse go migration file %q: %w", migration.Source, err) @@ -186,7 +196,7 @@ func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Mig } // Skip migrations already existing migrations registered via goose.AddMigration(). - if _, ok := registeredGoMigrations[v]; ok { + if _, ok := p.registeredGoMigrations[v]; ok { continue } @@ -203,8 +213,12 @@ func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Mig // CollectMigrations returns all the valid looking migration scripts in the // migrations folder and go func registry, and key them by version. -func CollectMigrations(dirpath string, current, target int64) (Migrations, error) { - return collectMigrationsFS(baseFS, dirpath, current, target) +func CollectMigrations(dirPath string, current, target int64) (Migrations, error) { + return defaultProvider.collectMigrationsFS(defaultProvider.baseFS, dirPath, current, target) +} + +func (p *Provider) CollectMigrations(dirPath string, current, target int64) (Migrations, error) { + return p.collectMigrationsFS(p.baseFS, dirPath, current, target) } func sortAndConnectMigrations(migrations Migrations) Migrations { @@ -240,9 +254,16 @@ func versionFilter(v, current, target int64) bool { // EnsureDBVersion retrieves the current version for this DB. // Create and initialize the DB version table if it doesn't exist. func EnsureDBVersion(db *sql.DB) (int64, error) { - rows, err := GetDialect().dbVersionQuery(db) + return defaultProvider.EnsureDBVersion(db) +} + +// EnsureDBVersion retrieves the current version for this DB. +// Create and initialize the DB version table if it doesn't exist. +func (p *Provider) EnsureDBVersion(db *sql.DB) (int64, error) { + dialect := p.dialect + rows, err := dialect.dbVersionQuery(db) if err != nil { - return 0, createVersionTable(db) + return 0, createVersionTable(dialect, db) } defer rows.Close() @@ -284,18 +305,17 @@ func EnsureDBVersion(db *sql.DB) (int64, error) { } return 0, ErrNoNextVersion + } // Create the db version table // and insert the initial 0 value into it -func createVersionTable(db *sql.DB) error { +func createVersionTable(d SQLDialect, db *sql.DB) error { txn, err := db.Begin() if err != nil { return err } - d := GetDialect() - if _, err := txn.Exec(d.createVersionTableSQL()); err != nil { txn.Rollback() return err @@ -313,7 +333,12 @@ func createVersionTable(db *sql.DB) error { // GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error. func GetDBVersion(db *sql.DB) (int64, error) { - version, err := EnsureDBVersion(db) + return defaultProvider.GetDBVersion(db) +} + +// GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error. +func (p *Provider) GetDBVersion(db *sql.DB) (int64, error) { + version, err := p.EnsureDBVersion(db) if err != nil { return -1, err } diff --git a/migration.go b/migration.go index 775dc779c..05e42366d 100644 --- a/migration.go +++ b/migration.go @@ -34,43 +34,51 @@ func (m *Migration) String() string { } // Up runs an up migration. +// Deprecated: please use UpWithFS func (m *Migration) Up(db *sql.DB) error { - if err := m.run(db, true); err != nil { - return err - } - return nil + return m.UpWithProvider(defaultProvider, db) +} + +func (m *Migration) UpWithProvider(p *Provider, db *sql.DB) error { + return m.run(p, db, true) } // Down runs a down migration. +// Deprecated: please use DownWithFS func (m *Migration) Down(db *sql.DB) error { - if err := m.run(db, false); err != nil { - return err - } - return nil + return m.DownWithProvider(defaultProvider, db) +} + +func (m *Migration) DownWithProvider(p *Provider, db *sql.DB) error { + return m.run(p, db, false) } -func (m *Migration) run(db *sql.DB, direction bool) error { +func (m *Migration) run(p *Provider, db *sql.DB, direction bool) error { + if p == nil { + p = defaultProvider + } + switch filepath.Ext(m.Source) { case ".sql": - f, err := baseFS.Open(m.Source) + f, err := p.baseFS.Open(m.Source) if err != nil { return fmt.Errorf("ERROR %v: failed to open SQL migration file: %w", filepath.Base(m.Source), err) } defer f.Close() - statements, useTx, err := parseSQLMigration(f, direction) + statements, useTx, err := parseSQLMigration(p, f, direction) if err != nil { return fmt.Errorf("ERROR %v: failed to parse SQL migration file: %w", filepath.Base(m.Source), err) } - if err := runSQLMigration(db, statements, useTx, m.Version, direction, m.noVersioning); err != nil { + if err := runSQLMigration(p, db, statements, useTx, m.Version, direction, m.noVersioning); err != nil { return fmt.Errorf("ERROR %v: failed to run SQL migration: %w", filepath.Base(m.Source), err) } if len(statements) > 0 { - log.Println("OK ", filepath.Base(m.Source)) + p.log.Println("OK ", filepath.Base(m.Source)) } else { - log.Println("EMPTY", filepath.Base(m.Source)) + p.log.Println("EMPTY", filepath.Base(m.Source)) } case ".go": @@ -96,12 +104,12 @@ func (m *Migration) run(db *sql.DB, direction bool) error { } if !m.noVersioning { if direction { - if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil { + if _, err := tx.Exec(p.dialect.insertVersionSQL(), m.Version, direction); err != nil { tx.Rollback() return fmt.Errorf("ERROR failed to execute transaction: %w", err) } } else { - if _, err := tx.Exec(GetDialect().deleteVersionSQL(), m.Version); err != nil { + if _, err := tx.Exec(p.dialect.deleteVersionSQL(), m.Version); err != nil { tx.Rollback() return fmt.Errorf("ERROR failed to execute transaction: %w", err) } @@ -113,9 +121,9 @@ func (m *Migration) run(db *sql.DB, direction bool) error { } if fn != nil { - log.Println("OK ", filepath.Base(m.Source)) + p.log.Println("OK ", filepath.Base(m.Source)) } else { - log.Println("EMPTY", filepath.Base(m.Source)) + p.log.Println("EMPTY", filepath.Base(m.Source)) } return nil @@ -125,7 +133,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error { } // NumericComponent looks for migration scripts with names in the form: -// XXX_descriptivename.ext where XXX specifies the version number +// XXX_descriptive_name.ext where XXX specifies the version number // and ext specifies the type of migration func NumericComponent(name string) (int64, error) { base := filepath.Base(name) diff --git a/migration_sql.go b/migration_sql.go index 359ebf6be..35e665b5e 100644 --- a/migration_sql.go +++ b/migration_sql.go @@ -15,11 +15,14 @@ import ( // // All statements following an Up or Down directive are grouped together // until another direction directive is found. -func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direction bool, noVersioning bool) error { +func runSQLMigration(p *Provider, db *sql.DB, statements []string, useTx bool, v int64, direction bool, noVersioning bool) error { + if p == nil { + p = defaultProvider + } if useTx { // TRANSACTION. - verboseInfo("Begin transaction") + p.verboseInfo("Begin transaction") tx, err := db.Begin() if err != nil { @@ -27,9 +30,9 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc } for _, query := range statements { - verboseInfo("Executing statement: %s\n", clearStatement(query)) - if err = execQuery(tx.Exec, query); err != nil { - verboseInfo("Rollback transaction") + p.verboseInfo("Executing statement: %s\n", clearStatement(query)) + if err = p.execQuery(tx.Exec, query); err != nil { + p.verboseInfo("Rollback transaction") tx.Rollback() return fmt.Errorf("failed to execute SQL query %q: %w", clearStatement(query), err) } @@ -37,21 +40,21 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc if !noVersioning { if direction { - if err := execQuery(tx.Exec, GetDialect().insertVersionSQL(), v, direction); err != nil { - verboseInfo("Rollback transaction") + if err := p.execQuery(tx.Exec, p.dialect.insertVersionSQL(), v, direction); err != nil { + p.verboseInfo("Rollback transaction") tx.Rollback() return fmt.Errorf("failed to insert new goose version: %w", err) } } else { - if err := execQuery(tx.Exec, GetDialect().deleteVersionSQL(), v); err != nil { - verboseInfo("Rollback transaction") + if err := p.execQuery(tx.Exec, p.dialect.deleteVersionSQL(), v); err != nil { + p.verboseInfo("Rollback transaction") tx.Rollback() return fmt.Errorf("failed to delete goose version: %w", err) } } } - verboseInfo("Commit transaction") + p.verboseInfo("Commit transaction") if err := tx.Commit(); err != nil { return fmt.Errorf("failed to commit transaction: %w", err) } @@ -61,18 +64,18 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc // NO TRANSACTION. for _, query := range statements { - verboseInfo("Executing statement: %s", clearStatement(query)) - if err := execQuery(db.Exec, query); err != nil { + p.verboseInfo("Executing statement: %s", clearStatement(query)) + if err := p.execQuery(db.Exec, query); err != nil { return fmt.Errorf("failed to execute SQL query %q: %w", clearStatement(query), err) } } if !noVersioning { if direction { - if err := execQuery(db.Exec, GetDialect().insertVersionSQL(), v, direction); err != nil { + if err := p.execQuery(db.Exec, p.dialect.insertVersionSQL(), v, direction); err != nil { return fmt.Errorf("failed to insert new goose version: %w", err) } } else { - if err := execQuery(db.Exec, GetDialect().deleteVersionSQL(), v); err != nil { + if err := p.execQuery(db.Exec, p.dialect.deleteVersionSQL(), v); err != nil { return fmt.Errorf("failed to delete goose version: %w", err) } } @@ -81,8 +84,11 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc return nil } -func execQuery(fn func(string, ...interface{}) (sql.Result, error), query string, args ...interface{}) error { - if !verbose { +func (p *Provider) execQuery(fn func(string, ...interface{}) (sql.Result, error), query string, args ...interface{}) error { + if p == nil { + p = defaultProvider + } + if !p.verbose { _, err := fn(query, args...) return err } @@ -101,7 +107,7 @@ func execQuery(fn func(string, ...interface{}) (sql.Result, error), query string case err := <-ch: return err case <-time.Tick(time.Minute): - verboseInfo("Executing statement still in progress for %v", time.Since(t).Round(time.Second)) + p.verboseInfo("Executing statement still in progress for %v", time.Since(t).Round(time.Second)) } } } @@ -111,9 +117,12 @@ const ( resetColor = "\033[00m" ) -func verboseInfo(s string, args ...interface{}) { - if verbose { - log.Printf(grayColor+s+resetColor, args...) +func (p *Provider) verboseInfo(s string, args ...interface{}) { + if p == nil { + p = defaultProvider + } + if p.verbose { + p.log.Printf(grayColor+s+resetColor, args...) } } diff --git a/provider.go b/provider.go new file mode 100644 index 000000000..25cc0564b --- /dev/null +++ b/provider.go @@ -0,0 +1,165 @@ +package goose + +import ( + "io/fs" + "path/filepath" + "runtime" + "time" +) + +const ( + defaultProviderPackage = "migrations" + defaultTableName = "goose_db_version" + defaultTimestampFormat = "20060102150405" +) + +// defaultProvider is the provider the general functions use. +var defaultProvider = NewProvider() + +type providerOptions func(p *Provider) + +// TimestampFormat sets the timestamp format for the provider +func TimestampFormat(format string) func(p *Provider) { + return func(p *Provider) { + p.timestampFormat = format + } +} + +// TimeFunction sets the time function used to get the time for timestamp numbers +// defaults to time.Now +func TimeFunction(fn func() time.Time) func(p *Provider) { + if fn == nil { + fn = time.Now + } + return func(p *Provider) { + p.timeFn = fn + } +} + +// Verbose sets the verbose on the provider +func Verbose(b bool) func(p *Provider) { + return func(p *Provider) { + p.verbose = b + } +} + +// SequentialVersion make the provider use sequential versioning +func SequentialVersion(versionTemplate string) func(p *Provider) { + return func(p *Provider) { + p.sequential = true + if versionTemplate != "" { + p.seqVersionTemplate = versionTemplate + } + } +} + +// TimestampVersion make the provider use sequential versioning +func TimestampVersion(p *Provider) { + p.sequential = false +} +func Filesystem(baseFS fs.FS) func(p *Provider) { + return func(p *Provider) { + p.baseFS = baseFS + } +} + +func Log(log Logger) func(p *Provider) { + return func(p *Provider) { + p.log = log + } +} + +func Dialect(dialect string) func(p *Provider) { + return func(p *Provider) { + dialect, err := SelectDialect(p.tableName, dialect) + if err != nil { + p.log.Fatal(err) + } + p.dialect = dialect + } +} + +// dirPath finds the directory path of the calling function's caller +func dirPath() string { + _, filename, _, _ := runtime.Caller(2) + return filepath.Dir(filename) +} + +// BaseDir will set the base directory, if an empty string is passed +// the directory of the package that called BaseDir is used instead +// this is only useful for Create* and Fix functions +func BaseDir(dir string) func(p *Provider) { + if dir == "" { + dir = dirPath() + } + return func(p *Provider) { + p.baseDir = dir + } +} + +func DialectObject(dialect SQLDialect) func(p *Provider) { + return func(p *Provider) { + p.dialect = dialect + p.dialect.SetTableName(p.tableName) + } +} + +func Tablename(tablename string) func(p *Provider) { + return func(p *Provider) { + p.tableName = tablename + p.dialect.SetTableName(tablename) + } +} + +// ProviderPackage sets the packageName and providerVar used in templates +func ProviderPackage(packageName, providerVar string) func(p *Provider) { + if packageName == "" { + packageName = defaultProviderPackage + } + return func(p *Provider) { + p.packageName = packageName + p.providerVarName = providerVar + } +} + +type Provider struct { + timestampFormat string + // defaults to time.Now + timeFn func() time.Time + verbose bool + // whether to use sequential versioning instead of timestamp based versioning + sequential bool + baseFS fs.FS + log Logger + dialect SQLDialect + registeredGoMigrations map[int64]*Migration + tableName string + // seqVersionTemplate sets the template system will use this to format the digit of the sequence number + // by default it %05d, see seqVersionTemplate for actually default value. + seqVersionTemplate string + // packageName is the name of the package to use for Create functions + packageName string + // providerVarName is the name of the provider var for create functions + providerVarName string + // This is used for Create/Fix if the dir is not passed. + baseDir string +} + +func NewProvider(options ...providerOptions) *Provider { + p := &Provider{ + timestampFormat: defaultTimestampFormat, + timeFn: time.Now, + verbose: false, + sequential: false, + baseFS: osFS{}, + log: log, + dialect: &PostgresDialect{}, + registeredGoMigrations: map[int64]*Migration{}, + tableName: defaultTableName, + packageName: defaultProviderPackage, + } + for _, opt := range options { + opt(p) + } + return p +} diff --git a/redo.go b/redo.go index c485f9f67..91ba88a34 100644 --- a/redo.go +++ b/redo.go @@ -6,11 +6,13 @@ import ( // Redo rolls back the most recently applied migration, then runs it again. func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { - option := &options{} - for _, f := range opts { - f(option) - } - migrations, err := CollectMigrations(dir, minVersion, maxVersion) + return defaultProvider.Redo(db, dir, opts...) +} + +// Redo rolls back the most recently applied migration, then runs it again. +func (p *Provider) Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { + option := applyOptions(opts) + migrations, err := p.CollectMigrations(dir, minVersion, maxVersion) if err != nil { return err } @@ -23,7 +25,7 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { } currentVersion = migrations[len(migrations)-1].Version } else { - if currentVersion, err = GetDBVersion(db); err != nil { + if currentVersion, err = p.GetDBVersion(db); err != nil { return err } } @@ -34,10 +36,10 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { } current.noVersioning = option.noVersioning - if err := current.Down(db); err != nil { + if err := current.DownWithProvider(p, db); err != nil { return err } - if err := current.Up(db); err != nil { + if err := current.UpWithProvider(p, db); err != nil { return err } return nil diff --git a/reset.go b/reset.go index 258841fad..d38840f89 100644 --- a/reset.go +++ b/reset.go @@ -8,11 +8,13 @@ import ( // Reset rolls back all migrations func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { - option := &options{} - for _, f := range opts { - f(option) - } - migrations, err := CollectMigrations(dir, minVersion, maxVersion) + return defaultProvider.Reset(db, dir, opts...) +} + +// Reset rolls back all migrations +func (p *Provider) Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { + option := applyOptions(opts) + migrations, err := p.CollectMigrations(dir, minVersion, maxVersion) if err != nil { return fmt.Errorf("failed to collect migrations: %w", err) } @@ -20,7 +22,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { return DownTo(db, dir, minVersion, opts...) } - statuses, err := dbMigrationsStatus(db) + statuses, err := dbMigrationsStatus(p.dialect, db) if err != nil { return fmt.Errorf("failed to get status of migrations: %w", err) } @@ -30,7 +32,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { if !statuses[migration.Version] { continue } - if err = migration.Down(db); err != nil { + if err = migration.DownWithProvider(p, db); err != nil { return fmt.Errorf("failed to db-down: %w", err) } } @@ -38,8 +40,8 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { return nil } -func dbMigrationsStatus(db *sql.DB) (map[int64]bool, error) { - rows, err := GetDialect().dbVersionQuery(db) +func dbMigrationsStatus(dialect SQLDialect, db *sql.DB) (map[int64]bool, error) { + rows, err := dialect.dbVersionQuery(db) if err != nil { return map[int64]bool{}, nil } diff --git a/sql_parser.go b/sql_parser.go index 91bc84da0..3906e4912 100644 --- a/sql_parser.go +++ b/sql_parser.go @@ -30,7 +30,9 @@ func (s *stateMachine) Get() parserState { return parserState(*s) } func (s *stateMachine) Set(new parserState) { - verboseInfo("StateMachine: %v => %v", *s, new) + // Do not want to break contract, so for this one, we will just + // use the defaultProvider + defaultProvider.verboseInfo("StateMachine: %v => %v", *s, new) *s = stateMachine(new) } @@ -44,7 +46,7 @@ var bufferPool = sync.Pool{ }, } -// Split given SQL script into individual statements and return +// parseSQLMigration will split the given SQL-script into individual statements and return // SQL statements for given direction (up=true, down=false). // // The base case is to simply split on semicolons, as these @@ -54,7 +56,10 @@ var bufferPool = sync.Pool{ // within a statement. For these cases, we provide the explicit annotations // 'StatementBegin' and 'StatementEnd' to allow the script to // tell us to ignore semicolons. -func parseSQLMigration(r io.Reader, direction bool) (stmts []string, useTx bool, err error) { +func parseSQLMigration(p *Provider, r io.Reader, direction bool) (stmts []string, useTx bool, err error) { + if p == nil { + p = defaultProvider + } var buf bytes.Buffer scanBuf := bufferPool.Get().([]byte) defer bufferPool.Put(scanBuf) @@ -67,8 +72,8 @@ func parseSQLMigration(r io.Reader, direction bool) (stmts []string, useTx bool, for scanner.Scan() { line := scanner.Text() - if verbose { - log.Println(line) + if p.verbose { + p.log.Println(line) } if strings.HasPrefix(line, "--") { @@ -120,14 +125,14 @@ func parseSQLMigration(r io.Reader, direction bool) (stmts []string, useTx bool, default: // Ignore comments. - verboseInfo("StateMachine: ignore comment") + p.verboseInfo("StateMachine: ignore comment") continue } } // Ignore empty lines. if matchEmptyLines.MatchString(line) { - verboseInfo("StateMachine: ignore empty line") + p.verboseInfo("StateMachine: ignore empty line") continue } @@ -145,13 +150,13 @@ func parseSQLMigration(r io.Reader, direction bool) (stmts []string, useTx bool, case gooseUp, gooseStatementBeginUp, gooseStatementEndUp: if !direction /*down*/ { buf.Reset() - verboseInfo("StateMachine: ignore down") + p.verboseInfo("StateMachine: ignore down") continue } case gooseDown, gooseStatementBeginDown, gooseStatementEndDown: if direction /*up*/ { buf.Reset() - verboseInfo("StateMachine: ignore up") + p.verboseInfo("StateMachine: ignore up") continue } default: @@ -163,23 +168,23 @@ func parseSQLMigration(r io.Reader, direction bool) (stmts []string, useTx bool, if endsWithSemicolon(line) { stmts = append(stmts, buf.String()) buf.Reset() - verboseInfo("StateMachine: store simple Up query") + p.verboseInfo("StateMachine: store simple Up query") } case gooseDown: if endsWithSemicolon(line) { stmts = append(stmts, buf.String()) buf.Reset() - verboseInfo("StateMachine: store simple Down query") + p.verboseInfo("StateMachine: store simple Down query") } case gooseStatementEndUp: stmts = append(stmts, buf.String()) buf.Reset() - verboseInfo("StateMachine: store Up statement") + p.verboseInfo("StateMachine: store Up statement") stateMachine.Set(gooseUp) case gooseStatementEndDown: stmts = append(stmts, buf.String()) buf.Reset() - verboseInfo("StateMachine: store Down statement") + p.verboseInfo("StateMachine: store Down statement") stateMachine.Set(gooseDown) } } diff --git a/sql_parser_test.go b/sql_parser_test.go index 6420ebb03..8cffaede8 100644 --- a/sql_parser_test.go +++ b/sql_parser_test.go @@ -55,7 +55,7 @@ func TestSplitStatements(t *testing.T) { for i, test := range tt { // up - stmts, _, err := parseSQLMigration(strings.NewReader(test.sql), true) + stmts, _, err := parseSQLMigration(nil, strings.NewReader(test.sql), true) if err != nil { t.Error(fmt.Errorf("tt[%v] unexpected error: %w", i, err)) } @@ -64,7 +64,7 @@ func TestSplitStatements(t *testing.T) { } // down - stmts, _, err = parseSQLMigration(strings.NewReader(test.sql), false) + stmts, _, err = parseSQLMigration(nil, strings.NewReader(test.sql), false) if err != nil { t.Error(fmt.Errorf("tt[%v] unexpected error: %w", i, err)) } @@ -93,7 +93,7 @@ func TestUseTransactions(t *testing.T) { if err != nil { t.Error(err) } - _, useTx, err := parseSQLMigration(f, true) + _, useTx, err := parseSQLMigration(nil, f, true) if err != nil { t.Error(err) } @@ -113,7 +113,7 @@ func TestParsingErrors(t *testing.T) { downFirst, } for i, sql := range tt { - _, _, err := parseSQLMigration(strings.NewReader(sql), true) + _, _, err := parseSQLMigration(nil, strings.NewReader(sql), true) if err == nil { t.Errorf("expected error on tt[%v] %q", i, sql) } diff --git a/status.go b/status.go index f53f1bece..085a7d28d 100644 --- a/status.go +++ b/status.go @@ -9,32 +9,33 @@ import ( // Status prints the status of all migrations. func Status(db *sql.DB, dir string, opts ...OptionsFunc) error { - option := &options{} - for _, f := range opts { - f(option) - } - migrations, err := CollectMigrations(dir, minVersion, maxVersion) + return defaultProvider.Status(db, dir, opts...) +} + +func (p *Provider) Status(db *sql.DB, dir string, opts ...OptionsFunc) error { + option := applyOptions(opts) + migrations, err := p.CollectMigrations(dir, minVersion, maxVersion) if err != nil { return fmt.Errorf("failed to collect migrations: %w", err) } if option.noVersioning { - log.Println(" Applied At Migration") - log.Println(" =======================================") + p.log.Println(" Applied At Migration") + p.log.Println(" =======================================") for _, current := range migrations { - log.Printf(" %-24s -- %v\n", "no versioning", filepath.Base(current.Source)) + p.log.Printf(" %-24s -- %v\n", "no versioning", filepath.Base(current.Source)) } return nil } // must ensure that the version table exists if we're running on a pristine DB - if _, err := EnsureDBVersion(db); err != nil { + if _, err := p.EnsureDBVersion(db); err != nil { return fmt.Errorf("failed to ensure DB version: %w", err) } - log.Println(" Applied At Migration") - log.Println(" =======================================") + p.log.Println(" Applied At Migration") + p.log.Println(" =======================================") for _, migration := range migrations { - if err := printMigrationStatus(db, migration.Version, filepath.Base(migration.Source)); err != nil { + if err := p.printMigrationStatus(db, migration.Version, filepath.Base(migration.Source)); err != nil { return fmt.Errorf("failed to print status: %w", err) } } @@ -42,8 +43,8 @@ func Status(db *sql.DB, dir string, opts ...OptionsFunc) error { return nil } -func printMigrationStatus(db *sql.DB, version int64, script string) error { - q := GetDialect().migrationSQL() +func (p *Provider) printMigrationStatus(db *sql.DB, version int64, script string) error { + q := p.dialect.migrationSQL() var row MigrationRecord @@ -59,6 +60,6 @@ func printMigrationStatus(db *sql.DB, version int64, script string) error { appliedAt = "Pending" } - log.Printf(" %-24s -- %v\n", appliedAt, script) + p.log.Printf(" %-24s -- %v\n", appliedAt, script) return nil } diff --git a/up.go b/up.go index d8d19cfe7..3078f004b 100644 --- a/up.go +++ b/up.go @@ -2,10 +2,11 @@ package goose import ( "database/sql" - "errors" "fmt" "sort" "strings" + + "github.com/pkg/errors" ) type options struct { @@ -28,13 +29,22 @@ func withApplyUpByOne() OptionsFunc { return func(o *options) { o.applyUpByOne = true } } -// UpTo migrates up to a specific version. -func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { - option := &options{} +func applyOptions(opts []OptionsFunc) *options { + option := new(options) for _, f := range opts { f(option) } - foundMigrations, err := CollectMigrations(dir, minVersion, version) + return option +} + +// UpTo migrates up to a specific version. +func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { + return defaultProvider.UpTo(db, dir, version, opts...) +} + +func (p *Provider) UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) (err error) { + option := applyOptions(opts) + foundMigrations, err := p.CollectMigrations(dir, minVersion, version) if err != nil { return err } @@ -48,13 +58,14 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { // migration over and over. version = foundMigrations[0].Version } - return upToNoVersioning(db, foundMigrations, version) + return p.upToNoVersioning(db, foundMigrations, version) } - if _, err := EnsureDBVersion(db); err != nil { + if _, err := p.EnsureDBVersion(db); err != nil { return err } - dbMigrations, err := listAllDBVersions(db) + //dbMigrations, err := listAllDBVersions(p.dialect, db) + dbMigrations, err := listAllDBVersions(p.dialect, db) if err != nil { return err } @@ -75,7 +86,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { } if option.allowMissing { - return upWithMissing( + return p.upWithMissing( db, missingMigrations, foundMigrations, @@ -86,8 +97,8 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { var current int64 for { - var err error - current, err = GetDBVersion(db) + p.log.Println("GetDBVersion") + current, err = p.GetDBVersion(db) if err != nil { return err } @@ -98,7 +109,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { } return fmt.Errorf("failed to find next migration: %v", err) } - if err := next.Up(db); err != nil { + if err := next.UpWithProvider(p, db); err != nil { return err } if option.applyUpByOne { @@ -109,7 +120,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { // the following behaviour: // UpByOne returns an error to signifying there are no more migrations. // Up and UpTo return nil - log.Printf("goose: no migrations to run. current version: %d\n", current) + p.log.Printf("goose: no migrations to run. current version: %d\n", current) if option.applyUpByOne { return ErrNoNextVersion } @@ -118,23 +129,23 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { // upToNoVersioning applies up migrations up to, and including, the // target version. -func upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { +func (p *Provider) upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { var finalVersion int64 for _, current := range migrations { if current.Version > version { break } current.noVersioning = true - if err := current.Up(db); err != nil { + if err := current.UpWithProvider(p, db); err != nil { return err } finalVersion = current.Version } - log.Printf("goose: up to current file version: %d\n", finalVersion) + p.log.Printf("goose: up to current file version: %d\n", finalVersion) return nil } -func upWithMissing( +func (p *Provider) upWithMissing( db *sql.DB, missingMigrations Migrations, foundMigrations Migrations, @@ -148,7 +159,7 @@ func upWithMissing( // Apply all missing migrations first. for _, missing := range missingMigrations { - if err := missing.Up(db); err != nil { + if err := missing.UpWithProvider(p, db); err != nil { return err } // Apply one migration and return early. @@ -159,7 +170,7 @@ func upWithMissing( // want to keep it as a safe-guard. Maybe we should instead have // the underlying query (if possible) return the current version as // part of the same transaction. - current, err := GetDBVersion(db) + current, err := p.GetDBVersion(db) if err != nil { return err } @@ -183,14 +194,14 @@ func upWithMissing( if lookupApplied[found.Version] { continue } - if err := found.Up(db); err != nil { + if err := found.UpWithProvider(p, db); err != nil { return err } if option.applyUpByOne { return nil } } - current, err := GetDBVersion(db) + current, err := p.GetDBVersion(db) if err != nil { return err } @@ -198,7 +209,7 @@ func upWithMissing( // the following behaviour: // UpByOne returns an error to signifying there are no more migrations. // Up and UpTo return nil - log.Printf("goose: no migrations to run. current version: %d\n", current) + p.log.Printf("goose: no migrations to run. current version: %d\n", current) if option.applyUpByOne { return ErrNoNextVersion } @@ -207,21 +218,32 @@ func upWithMissing( // Up applies all available migrations. func Up(db *sql.DB, dir string, opts ...OptionsFunc) error { - return UpTo(db, dir, maxVersion, opts...) + return defaultProvider.UpTo(db, dir, maxVersion, opts...) +} + +// Up applies all available migrations. +func (p *Provider) Up(db *sql.DB, dir string, opts ...OptionsFunc) error { + return p.UpTo(db, dir, maxVersion, opts...) } // UpByOne migrates up by a single version. func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error { opts = append(opts, withApplyUpByOne()) - return UpTo(db, dir, maxVersion, opts...) + return defaultProvider.UpTo(db, dir, maxVersion, opts...) +} + +// UpByOne migrates up by a single version. +func (p *Provider) UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error { + opts = append(opts, withApplyUpByOne()) + return p.UpTo(db, dir, maxVersion, opts...) } // listAllDBVersions returns a list of all migrations, ordered ascending. // TODO(mf): fairly cheap, but a nice-to-have is pagination support. -func listAllDBVersions(db *sql.DB) (Migrations, error) { - rows, err := GetDialect().dbVersionQuery(db) +func listAllDBVersions(dialect SQLDialect, db *sql.DB) (Migrations, error) { + rows, err := dialect.dbVersionQuery(db) if err != nil { - return nil, createVersionTable(db) + return nil, createVersionTable(dialect, db) } var all Migrations for rows.Next() { @@ -247,7 +269,7 @@ func listAllDBVersions(db *sql.DB) (Migrations, error) { } // findMissingMigrations migrations returns all missing migrations. -// A migrations is considered missing if it has a version less than the +// A migration is considered missing if it has a version less than the // current known max version. func findMissingMigrations(knownMigrations, newMigrations Migrations) Migrations { max := knownMigrations[len(knownMigrations)-1].Version @@ -256,9 +278,9 @@ func findMissingMigrations(knownMigrations, newMigrations Migrations) Migrations existing[known.Version] = true } var missing Migrations - for _, new := range newMigrations { - if !existing[new.Version] && new.Version < max { - missing = append(missing, new) + for _, newMigration := range newMigrations { + if !existing[newMigration.Version] && newMigration.Version < max { + missing = append(missing, newMigration) } } sort.SliceStable(missing, func(i, j int) bool { diff --git a/version.go b/version.go index 47765f728..942715c06 100644 --- a/version.go +++ b/version.go @@ -7,39 +7,50 @@ import ( // Version prints the current version of the database. func Version(db *sql.DB, dir string, opts ...OptionsFunc) error { - option := &options{} - for _, f := range opts { - f(option) - } + return defaultProvider.Version(db, dir, opts...) +} + +// Version prints the current version of the database. +func (p *Provider) Version(db *sql.DB, dir string, opts ...OptionsFunc) error { + option := applyOptions(opts) if option.noVersioning { var current int64 - migrations, err := CollectMigrations(dir, minVersion, maxVersion) + migrations, err := p.CollectMigrations(dir, minVersion, maxVersion) if err != nil { return fmt.Errorf("failed to collect migrations: %w", err) } if len(migrations) > 0 { current = migrations[len(migrations)-1].Version } - log.Printf("goose: file version %v\n", current) + p.log.Printf("goose: file version %v\n", current) return nil } - current, err := GetDBVersion(db) + current, err := p.GetDBVersion(db) if err != nil { return err } - log.Printf("goose: version %v\n", current) + p.log.Printf("goose: version %v\n", current) return nil } -var tableName = "goose_db_version" - // TableName returns goose db version table name func TableName() string { - return tableName + return defaultProvider.tableName +} + +// TableName returns goose db version table name +func (p *Provider) TableName() string { + return p.tableName } // SetTableName set goose db version table name func SetTableName(n string) { - tableName = n + defaultProvider.SetTableName(n) +} + +// SetTableName set goose db version table name +func (p *Provider) SetTableName(n string) { + p.tableName = n + p.dialect.SetTableName(n) }