From f844252197274b3d50df8059782766b42a38c634 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Sat, 4 Nov 2023 13:12:25 -0400 Subject: [PATCH 01/13] checkpoint 1 - move provider into package goose --- globals.go | 60 +- globals_test.go | 63 ++- internal/provider/migration.go | 186 ------- internal/provider/misc.go | 68 --- internal/provider/provider_test.go | 153 ------ internal/provider/run.go | 516 ------------------ internal/provider/types.go | 70 --- migration.go | 54 +- osfs.go | 8 + internal/provider/provider.go => provider.go | 67 +-- .../collect.go => provider_collect.go | 102 ++-- ...ollect_test.go => provider_collect_test.go | 126 ++--- .../provider/errors.go => provider_errors.go | 5 +- provider_impl.go | 206 +++++++ provider_migrate.go | 90 +++ ...provider_options.go => provider_options.go | 90 ++- ...ptions_test.go => provider_options_test.go | 30 +- provider_run.go | 356 ++++++++++++ .../run_test.go => provider_run_test.go | 253 ++++----- provider_test.go | 79 +++ register.go | 2 +- .../no-versioning/migrations/00001_a.sql | 0 .../no-versioning/migrations/00002_b.sql | 0 .../no-versioning/migrations/00003_c.sql | 0 .../no-versioning/seed/00001_a.sql | 0 .../no-versioning/seed/00002_b.sql | 0 types.go | 50 +- 27 files changed, 1183 insertions(+), 1451 deletions(-) delete mode 100644 internal/provider/migration.go delete mode 100644 internal/provider/misc.go delete mode 100644 internal/provider/provider_test.go delete mode 100644 internal/provider/run.go delete mode 100644 internal/provider/types.go rename internal/provider/provider.go => provider.go (84%) rename internal/provider/collect.go => provider_collect.go (75%) rename internal/provider/collect_test.go => provider_collect_test.go (77%) rename internal/provider/errors.go => provider_errors.go (89%) create mode 100644 provider_impl.go create mode 100644 provider_migrate.go rename internal/provider/provider_options.go => provider_options.go (63%) rename internal/provider/provider_options_test.go => provider_options_test.go (61%) create mode 100644 provider_run.go rename internal/provider/run_test.go => provider_run_test.go (77%) create mode 100644 provider_test.go rename {internal/provider/testdata => testdata}/no-versioning/migrations/00001_a.sql (100%) rename {internal/provider/testdata => testdata}/no-versioning/migrations/00002_b.sql (100%) rename {internal/provider/testdata => testdata}/no-versioning/migrations/00003_c.sql (100%) rename {internal/provider/testdata => testdata}/no-versioning/seed/00001_a.sql (100%) rename {internal/provider/testdata => testdata}/no-versioning/seed/00002_b.sql (100%) diff --git a/globals.go b/globals.go index e2d55faa8..535f0ff27 100644 --- a/globals.go +++ b/globals.go @@ -22,13 +22,12 @@ func ResetGlobalMigrations() { // [NewGoMigration] function. // // Not safe for concurrent use. -func SetGlobalMigrations(migrations ...Migration) error { - for _, migration := range migrations { - m := &migration +func SetGlobalMigrations(migrations ...*Migration) error { + for _, m := range migrations { if _, ok := registeredGoMigrations[m.Version]; ok { return fmt.Errorf("go migration with version %d already registered", m.Version) } - if err := checkMigration(m); err != nil { + if err := checkGoMigration(m); err != nil { return fmt.Errorf("invalid go migration: %w", err) } registeredGoMigrations[m.Version] = m @@ -36,7 +35,7 @@ func SetGlobalMigrations(migrations ...Migration) error { return nil } -func checkMigration(m *Migration) error { +func checkGoMigration(m *Migration) error { if !m.construct { return errors.New("must use NewGoMigration to construct migrations") } @@ -63,10 +62,10 @@ func checkMigration(m *Migration) error { return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source) } } - if err := setGoFunc(m.goUp); err != nil { + if err := checkGoFunc(m.goUp); err != nil { return fmt.Errorf("up function: %w", err) } - if err := setGoFunc(m.goDown); err != nil { + if err := checkGoFunc(m.goDown); err != nil { return fmt.Errorf("down function: %w", err) } if m.UpFnContext != nil && m.UpFnNoTxContext != nil { @@ -84,47 +83,22 @@ func checkMigration(m *Migration) error { return nil } -func setGoFunc(f *GoFunc) error { - if f == nil { - f = &GoFunc{Mode: TransactionEnabled} - return nil - } +func checkGoFunc(f *GoFunc) error { if f.RunTx != nil && f.RunDB != nil { return errors.New("must specify exactly one of RunTx or RunDB") } - if f.RunTx == nil && f.RunDB == nil { - switch f.Mode { - case 0: - // Default to TransactionEnabled ONLY if mode is not set explicitly. - f.Mode = TransactionEnabled - case TransactionEnabled, TransactionDisabled: - // No functions but mode is set. This is not an error. It means the user wants to record - // a version with the given mode but not run any functions. - default: - return fmt.Errorf("invalid mode: %d", f.Mode) - } - return nil - } - if f.RunDB != nil { - switch f.Mode { - case 0, TransactionDisabled: - f.Mode = TransactionDisabled - default: - return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set") - } + switch f.Mode { + case TransactionEnabled, TransactionDisabled: + // No functions, but mode is set. This is not an error. It means the user wants to + // record a version with the given mode but not run any functions. + default: + return fmt.Errorf("invalid mode: %d", f.Mode) } - if f.RunTx != nil { - switch f.Mode { - case 0, TransactionEnabled: - f.Mode = TransactionEnabled - default: - return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set") - } + if f.RunDB != nil && f.Mode != TransactionDisabled { + return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set") } - // This is a defensive check. If the mode is still 0, it means we failed to infer the mode from - // the functions or return an error. This should never happen. - if f.Mode == 0 { - return errors.New("failed to infer transaction mode") + if f.RunTx != nil && f.Mode != TransactionEnabled { + return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set") } return nil } diff --git a/globals_test.go b/globals_test.go index f40f01ff3..5ea5a0024 100644 --- a/globals_test.go +++ b/globals_test.go @@ -104,13 +104,15 @@ func TestTransactionMode(t *testing.T) { // reset so we can check the default is set migration2.goUp.Mode, migration2.goDown.Mode = 0, 0 err = SetGlobalMigrations(migration2) - check.NoError(t, err) - check.Number(t, len(registeredGoMigrations), 2) - registered = registeredGoMigrations[2] - check.Bool(t, registered.goUp != nil, true) - check.Bool(t, registered.goDown != nil, true) - check.Equal(t, registered.goUp.Mode, TransactionEnabled) - check.Equal(t, registered.goDown.Mode, TransactionEnabled) + check.HasError(t, err) + check.Contains(t, err.Error(), "invalid go migration: up function: invalid mode: 0") + + migration3 := NewGoMigration(3, nil, nil) + // reset so we can check the default is set + migration3.goDown.Mode = 0 + err = SetGlobalMigrations(migration3) + check.HasError(t, err) + check.Contains(t, err.Error(), "invalid go migration: down function: invalid mode: 0") }) t.Run("unknown_mode", func(t *testing.T) { m := NewGoMigration(1, nil, nil) @@ -192,7 +194,7 @@ func TestGlobalRegister(t *testing.T) { runTx := func(context.Context, *sql.Tx) error { return nil } // Success. - err := SetGlobalMigrations([]Migration{}...) + err := SetGlobalMigrations([]*Migration{}...) check.NoError(t, err) err = SetGlobalMigrations( NewGoMigration(1, &GoFunc{RunTx: runTx}, nil), @@ -204,62 +206,79 @@ func TestGlobalRegister(t *testing.T) { ) check.HasError(t, err) check.Contains(t, err.Error(), "go migration with version 1 already registered") - err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo}) + err = SetGlobalMigrations(&Migration{Registered: true, Version: 2, Type: TypeGo}) check.HasError(t, err) check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations") } func TestCheckMigration(t *testing.T) { + // Success. + err := checkGoMigration(NewGoMigration(1, nil, nil)) + check.NoError(t, err) // Failures. - err := checkMigration(&Migration{}) + err = checkGoMigration(&Migration{}) check.HasError(t, err) check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations") - err = checkMigration(&Migration{construct: true}) + err = checkGoMigration(&Migration{construct: true}) check.HasError(t, err) check.Contains(t, err.Error(), "must be registered") - err = checkMigration(&Migration{construct: true, Registered: true}) + err = checkGoMigration(&Migration{construct: true, Registered: true}) check.HasError(t, err) check.Contains(t, err.Error(), `type must be "go"`) - err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo}) + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo}) check.HasError(t, err) check.Contains(t, err.Error(), "version must be greater than zero") + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{}, goDown: &GoFunc{}}) + check.HasError(t, err) + check.Contains(t, err.Error(), "up function: invalid mode: 0") + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{}}) + check.HasError(t, err) + check.Contains(t, err.Error(), "down function: invalid mode: 0") // Success. - err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1}) + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}}) check.NoError(t, err) // Failures. - err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"}) + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"}) check.HasError(t, err) check.Contains(t, err.Error(), `source must have .go extension: "foo"`) - err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"}) + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"}) check.HasError(t, err) check.Contains(t, err.Error(), `no filename separator '_' found`) - err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"}) + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"}) check.HasError(t, err) check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`) - err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"}) + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"}) check.HasError(t, err) check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`) - err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, UpFnContext: func(context.Context, *sql.Tx) error { return nil }, UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil }, + goUp: &GoFunc{Mode: TransactionEnabled}, + goDown: &GoFunc{Mode: TransactionEnabled}, }) check.HasError(t, err) check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext") - err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, DownFnContext: func(context.Context, *sql.Tx) error { return nil }, DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil }, + goUp: &GoFunc{Mode: TransactionEnabled}, + goDown: &GoFunc{Mode: TransactionEnabled}, }) check.HasError(t, err) check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext") - err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, UpFn: func(*sql.Tx) error { return nil }, UpFnNoTx: func(*sql.DB) error { return nil }, + goUp: &GoFunc{Mode: TransactionEnabled}, + goDown: &GoFunc{Mode: TransactionEnabled}, }) check.HasError(t, err) check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx") - err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, DownFn: func(*sql.Tx) error { return nil }, DownFnNoTx: func(*sql.DB) error { return nil }, + goUp: &GoFunc{Mode: TransactionEnabled}, + goDown: &GoFunc{Mode: TransactionEnabled}, }) check.HasError(t, err) check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx") diff --git a/internal/provider/migration.go b/internal/provider/migration.go deleted file mode 100644 index 07508ff68..000000000 --- a/internal/provider/migration.go +++ /dev/null @@ -1,186 +0,0 @@ -package provider - -import ( - "context" - "database/sql" - "fmt" - "path/filepath" - - "github.com/pressly/goose/v3/database" -) - -type migration struct { - Source Source - // A migration is either a Go migration or a SQL migration, but never both. - // - // Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is - // an optimization to avoid parsing the SQL migration if it is never required. Also, the - // majority of the time migrations are incremental, so it is likely that the user will only want - // to run the last few migrations, and there is no need to parse ALL prior migrations. - // - // Exactly one of these fields will be set: - Go *goMigration - // -- OR -- - SQL *sqlMigration -} - -func (m *migration) useTx(direction bool) bool { - switch m.Source.Type { - case TypeSQL: - return m.SQL.UseTx - case TypeGo: - if m.Go == nil || m.Go.isEmpty(direction) { - return false - } - if direction { - return m.Go.up.Run != nil - } - return m.Go.down.Run != nil - } - // This should never happen. - return false -} - -func (m *migration) isEmpty(direction bool) bool { - switch m.Source.Type { - case TypeSQL: - return m.SQL == nil || m.SQL.isEmpty(direction) - case TypeGo: - return m.Go == nil || m.Go.isEmpty(direction) - } - return true -} - -func (m *migration) filename() string { - return filepath.Base(m.Source.Path) -} - -// run runs the migration inside of a transaction. -func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error { - switch m.Source.Type { - case TypeSQL: - if m.SQL == nil { - return fmt.Errorf("tx: sql migration has not been parsed") - } - return m.SQL.run(ctx, tx, direction) - case TypeGo: - return m.Go.run(ctx, tx, direction) - } - // This should never happen. - return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path)) -} - -// runNoTx runs the migration without a transaction. -func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { - switch m.Source.Type { - case TypeSQL: - if m.SQL == nil { - return fmt.Errorf("db: sql migration has not been parsed") - } - return m.SQL.run(ctx, db, direction) - case TypeGo: - return m.Go.runNoTx(ctx, db, direction) - } - // This should never happen. - return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path)) -} - -// runConn runs the migration without a transaction using the provided connection. -func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool) error { - switch m.Source.Type { - case TypeSQL: - if m.SQL == nil { - return fmt.Errorf("conn: sql migration has not been parsed") - } - return m.SQL.run(ctx, conn, direction) - case TypeGo: - return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") - } - // This should never happen. - return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path)) -} - -type goMigration struct { - fullpath string - up, down *GoMigrationFunc -} - -func (g *goMigration) isEmpty(direction bool) bool { - if g.up == nil && g.down == nil { - panic("go migration has no up or down") - } - if direction { - return g.up == nil - } - return g.down == nil -} - -func newGoMigration(fullpath string, up, down *GoMigrationFunc) *goMigration { - return &goMigration{ - fullpath: fullpath, - up: up, - down: down, - } -} - -func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error { - if g == nil { - return nil - } - var fn func(context.Context, *sql.Tx) error - if direction && g.up != nil { - fn = g.up.Run - } - if !direction && g.down != nil { - fn = g.down.Run - } - if fn != nil { - return fn(ctx, tx) - } - return nil -} - -func (g *goMigration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { - if g == nil { - return nil - } - var fn func(context.Context, *sql.DB) error - if direction && g.up != nil { - fn = g.up.RunNoTx - } - if !direction && g.down != nil { - fn = g.down.RunNoTx - } - if fn != nil { - return fn(ctx, db) - } - return nil -} - -type sqlMigration struct { - UseTx bool - UpStatements []string - DownStatements []string -} - -func (s *sqlMigration) isEmpty(direction bool) bool { - if direction { - return len(s.UpStatements) == 0 - } - return len(s.DownStatements) == 0 -} - -func (s *sqlMigration) run(ctx context.Context, db database.DBTxConn, direction bool) error { - var statements []string - if direction { - statements = s.UpStatements - } else { - statements = s.DownStatements - } - for _, stmt := range statements { - if _, err := db.ExecContext(ctx, stmt); err != nil { - return err - } - } - return nil -} diff --git a/internal/provider/misc.go b/internal/provider/misc.go deleted file mode 100644 index e20fbad18..000000000 --- a/internal/provider/misc.go +++ /dev/null @@ -1,68 +0,0 @@ -package provider - -import ( - "context" - "database/sql" - "errors" - "fmt" - - "github.com/pressly/goose/v3" -) - -type MigrationCopy struct { - Version int64 - Source string // path to .sql script or go file - Registered bool - UpFnContext, DownFnContext func(context.Context, *sql.Tx) error - UpFnNoTxContext, DownFnNoTxContext func(context.Context, *sql.DB) error -} - -var registeredGoMigrations = make(map[int64]*MigrationCopy) - -// SetGlobalGoMigrations registers the given go migrations globally. It returns an error if any of -// the migrations are nil or if a migration with the same version has already been registered. -// -// Not safe for concurrent use. -func SetGlobalGoMigrations(migrations []*MigrationCopy) error { - for _, m := range migrations { - if m == nil { - return errors.New("cannot register nil go migration") - } - if m.Version < 1 { - return errors.New("migration versions must be greater than zero") - } - if !m.Registered { - return errors.New("migration must be registered") - } - if m.Source != "" { - // If the source is set, expect it to be a file path with a numeric component that - // matches the version. - version, err := goose.NumericComponent(m.Source) - if err != nil { - return err - } - if version != m.Version { - return fmt.Errorf("migration version %d does not match source %q", m.Version, m.Source) - } - } - // It's valid for all of these to be nil. - if m.UpFnContext != nil && m.UpFnNoTxContext != nil { - return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext") - } - if m.DownFnContext != nil && m.DownFnNoTxContext != nil { - return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext") - } - if _, ok := registeredGoMigrations[m.Version]; ok { - return fmt.Errorf("go migration with version %d already registered", m.Version) - } - registeredGoMigrations[m.Version] = m - } - return nil -} - -// ResetGlobalGoMigrations resets the global go migrations registry. -// -// Not safe for concurrent use. -func ResetGlobalGoMigrations() { - registeredGoMigrations = make(map[int64]*MigrationCopy) -} diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go deleted file mode 100644 index 3c1268854..000000000 --- a/internal/provider/provider_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package provider_test - -import ( - "context" - "database/sql" - "errors" - "io/fs" - "path/filepath" - "testing" - "testing/fstest" - - "github.com/pressly/goose/v3/database" - "github.com/pressly/goose/v3/internal/check" - "github.com/pressly/goose/v3/internal/provider" - _ "modernc.org/sqlite" -) - -func TestProvider(t *testing.T) { - dir := t.TempDir() - db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) - check.NoError(t, err) - t.Run("empty", func(t *testing.T) { - _, err := provider.NewProvider(database.DialectSQLite3, db, fstest.MapFS{}) - check.HasError(t, err) - check.Bool(t, errors.Is(err, provider.ErrNoMigrations), true) - }) - - mapFS := fstest.MapFS{ - "migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)}, - "migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)}, - } - fsys, err := fs.Sub(mapFS, "migrations") - check.NoError(t, err) - p, err := provider.NewProvider(database.DialectSQLite3, db, fsys) - check.NoError(t, err) - sources := p.ListSources() - check.Equal(t, len(sources), 2) - check.Equal(t, sources[0], newSource(provider.TypeSQL, "001_foo.sql", 1)) - check.Equal(t, sources[1], newSource(provider.TypeSQL, "002_bar.sql", 2)) - - t.Run("duplicate_go", func(t *testing.T) { - // Not parallel because it modifies global state. - register := []*provider.MigrationCopy{ - { - Version: 1, Source: "00001_users_table.go", Registered: true, - UpFnContext: nil, - DownFnContext: nil, - }, - } - err := provider.SetGlobalGoMigrations(register) - check.NoError(t, err) - t.Cleanup(provider.ResetGlobalGoMigrations) - - db := newDB(t) - _, err = provider.NewProvider(database.DialectSQLite3, db, nil, - provider.WithGoMigration(1, nil, nil), - ) - check.HasError(t, err) - check.Equal(t, err.Error(), "go migration with version 1 already registered") - }) - t.Run("empty_go", func(t *testing.T) { - db := newDB(t) - // explicit - _, err := provider.NewProvider(database.DialectSQLite3, db, nil, - provider.WithGoMigration(1, &provider.GoMigrationFunc{Run: nil}, &provider.GoMigrationFunc{Run: nil}), - ) - check.HasError(t, err) - check.Contains(t, err.Error(), "go migration with version 1 must have an up function") - }) - t.Run("duplicate_up", func(t *testing.T) { - err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{ - { - Version: 1, Source: "00001_users_table.go", Registered: true, - UpFnContext: func(context.Context, *sql.Tx) error { return nil }, - UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, - }, - }) - t.Cleanup(provider.ResetGlobalGoMigrations) - check.HasError(t, err) - check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext") - }) - t.Run("duplicate_down", func(t *testing.T) { - err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{ - { - Version: 1, Source: "00001_users_table.go", Registered: true, - DownFnContext: func(context.Context, *sql.Tx) error { return nil }, - DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, - }, - }) - t.Cleanup(provider.ResetGlobalGoMigrations) - check.HasError(t, err) - check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext") - }) - t.Run("not_registered", func(t *testing.T) { - err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{ - { - Version: 1, Source: "00001_users_table.go", - }, - }) - t.Cleanup(provider.ResetGlobalGoMigrations) - check.HasError(t, err) - check.Contains(t, err.Error(), "migration must be registered") - }) - t.Run("zero_not_allowed", func(t *testing.T) { - err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{ - { - Version: 0, - }, - }) - t.Cleanup(provider.ResetGlobalGoMigrations) - check.HasError(t, err) - check.Contains(t, err.Error(), "migration versions must be greater than zero") - }) -} - -var ( - migration1 = ` --- +goose Up -CREATE TABLE foo (id INTEGER PRIMARY KEY); --- +goose Down -DROP TABLE foo; -` - migration2 = ` --- +goose Up -ALTER TABLE foo ADD COLUMN name TEXT; --- +goose Down -ALTER TABLE foo DROP COLUMN name; -` - migration3 = ` --- +goose Up -CREATE TABLE bar ( - id INTEGER PRIMARY KEY, - description TEXT -); --- +goose Down -DROP TABLE bar; -` - migration4 = ` --- +goose Up --- Rename the 'foo' table to 'my_foo' -ALTER TABLE foo RENAME TO my_foo; - --- Add a new column 'timestamp' to 'my_foo' -ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP; - --- +goose Down --- Remove the 'timestamp' column from 'my_foo' -ALTER TABLE my_foo DROP COLUMN timestamp; - --- Rename the 'my_foo' table back to 'foo' -ALTER TABLE my_foo RENAME TO foo; -` -) diff --git a/internal/provider/run.go b/internal/provider/run.go deleted file mode 100644 index c5f63f13b..000000000 --- a/internal/provider/run.go +++ /dev/null @@ -1,516 +0,0 @@ -package provider - -import ( - "context" - "database/sql" - "errors" - "fmt" - "io/fs" - "sort" - "strings" - "time" - - "github.com/pressly/goose/v3/database" - "github.com/pressly/goose/v3/internal/sqlparser" - "go.uber.org/multierr" -) - -var ( - errMissingZeroVersion = errors.New("missing zero version migration") -) - -func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) { - if version < 1 { - return nil, errors.New("version must be greater than zero") - } - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return nil, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - if len(p.migrations) == 0 { - return nil, nil - } - var apply []*migration - if p.cfg.disableVersioning { - apply = p.migrations - } else { - // optimize(mf): Listing all migrations from the database isn't great. This is only required - // to support the allow missing (out-of-order) feature. For users that don't use this - // feature, we could just query the database for the current max version and then apply - // migrations greater than that version. - dbMigrations, err := p.store.ListMigrations(ctx, conn) - if err != nil { - return nil, err - } - if len(dbMigrations) == 0 { - return nil, errMissingZeroVersion - } - apply, err = p.resolveUpMigrations(dbMigrations, version) - if err != nil { - return nil, err - } - } - // feat(mf): this is where can (optionally) group multiple migrations to be run in a single - // transaction. The default is to apply each migration sequentially on its own. - // https://github.com/pressly/goose/issues/222 - // - // Careful, we can't use a single transaction for all migrations because some may have to be run - // in their own transaction. - return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne) -} - -func (p *Provider) resolveUpMigrations( - dbVersions []*database.ListMigrationsResult, - version int64, -) ([]*migration, error) { - var apply []*migration - var dbMaxVersion int64 - // dbAppliedVersions is a map of all applied migrations in the database. - dbAppliedVersions := make(map[int64]bool, len(dbVersions)) - for _, m := range dbVersions { - dbAppliedVersions[m.Version] = true - if m.Version > dbMaxVersion { - dbMaxVersion = m.Version - } - } - missingMigrations := checkMissingMigrations(dbVersions, p.migrations) - // feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing - // migrations entirely. At the moment this is not supported, but leaving this comment because - // that's where that logic would be handled. - // - // For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. Not - // sure if this is a common use case, but it's possible. - if len(missingMigrations) > 0 && !p.cfg.allowMissing { - var collected []string - for _, v := range missingMigrations { - collected = append(collected, v.filename) - } - msg := "migration" - if len(collected) > 1 { - msg += "s" - } - return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]", - len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","), - ) - } - for _, v := range missingMigrations { - m, err := p.getMigration(v.versionID) - if err != nil { - return nil, err - } - apply = append(apply, m) - } - // filter all migrations with a version greater than the supplied version (min) and less than or - // equal to the requested version (max). Skip any migrations that have already been applied. - for _, m := range p.migrations { - if dbAppliedVersions[m.Source.Version] { - continue - } - if m.Source.Version > dbMaxVersion && m.Source.Version <= version { - apply = append(apply, m) - } - } - return apply, nil -} - -func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) { - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return nil, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - if len(p.migrations) == 0 { - return nil, nil - } - if p.cfg.disableVersioning { - downMigrations := p.migrations - if downByOne { - last := p.migrations[len(p.migrations)-1] - downMigrations = []*migration{last} - } - return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) - } - dbMigrations, err := p.store.ListMigrations(ctx, conn) - if err != nil { - return nil, err - } - if len(dbMigrations) == 0 { - return nil, errMissingZeroVersion - } - if dbMigrations[0].Version == 0 { - return nil, nil - } - var downMigrations []*migration - for _, dbMigration := range dbMigrations { - if dbMigration.Version <= version { - break - } - m, err := p.getMigration(dbMigration.Version) - if err != nil { - return nil, err - } - downMigrations = append(downMigrations, m) - } - return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) -} - -// runMigrations runs migrations sequentially in the given direction. -// -// If the migrations list is empty, return nil without error. -func (p *Provider) runMigrations( - ctx context.Context, - conn *sql.Conn, - migrations []*migration, - direction sqlparser.Direction, - byOne bool, -) ([]*MigrationResult, error) { - if len(migrations) == 0 { - return nil, nil - } - apply := migrations - if byOne { - apply = migrations[:1] - } - // Lazily parse SQL migrations (if any) in both directions. We do this before running any - // migrations so that we can fail fast if there are any errors and avoid leaving the database in - // a partially migrated state. - if err := parseSQL(p.fsys, false, apply); err != nil { - return nil, err - } - // feat(mf): If we decide to add support for advisory locks at the transaction level, this may - // be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe - // to run in a transaction. - - // bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but - // are locking the database with *sql.Conn. If the caller sets max open connections to 1, then - // this will deadlock because the Go migration will try to acquire a connection from the pool, - // but the pool is locked. - // - // A potential solution is to expose a third Go register function *sql.Conn. Or continue to use - // *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of - // an edge case. - if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 { - for _, m := range apply { - switch m.Source.Type { - case TypeGo: - if m.Go != nil && m.useTx(direction.ToBool()) { - return nil, errors.New("potential deadlock detected: cannot run Go migrations without a transaction when max open connections set to 1") - } - } - } - } - - // Avoid allocating a slice because we may have a partial migration error. - // 1. Avoid giving the impression that N migrations were applied when in fact some were not - // 2. Avoid the caller having to check for nil results - var results []*MigrationResult - for _, m := range apply { - current := &MigrationResult{ - Source: m.Source, - Direction: direction.String(), - Empty: m.isEmpty(direction.ToBool()), - } - start := time.Now() - if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil { - // TODO(mf): we should also return the pending migrations here, the remaining items in - // the apply slice. - current.Error = err - current.Duration = time.Since(start) - return nil, &PartialError{ - Applied: results, - Failed: current, - Err: err, - } - } - current.Duration = time.Since(start) - results = append(results, current) - } - return results, nil -} - -func (p *Provider) runIndividually( - ctx context.Context, - conn *sql.Conn, - direction bool, - m *migration, -) error { - if m.useTx(direction) { - // Run the migration in a transaction. - return p.beginTx(ctx, conn, func(tx *sql.Tx) error { - if err := m.run(ctx, tx, direction); err != nil { - return err - } - if p.cfg.disableVersioning { - return nil - } - if direction { - return p.store.Insert(ctx, tx, database.InsertRequest{Version: m.Source.Version}) - } - return p.store.Delete(ctx, tx, m.Source.Version) - }) - } - // Run the migration outside of a transaction. - switch m.Source.Type { - case TypeGo: - // Note, we're using *sql.DB instead of *sql.Conn because it's the contract of the - // GoMigrationNoTx function. This may be a deadlock scenario if the caller sets max open - // connections to 1. See the comment in runMigrations for more details. - if err := m.runNoTx(ctx, p.db, direction); err != nil { - return err - } - case TypeSQL: - if err := m.runConn(ctx, conn, direction); err != nil { - return err - } - } - if p.cfg.disableVersioning { - return nil - } - if direction { - return p.store.Insert(ctx, conn, database.InsertRequest{Version: m.Source.Version}) - } - return p.store.Delete(ctx, conn, m.Source.Version) -} - -// beginTx begins a transaction and runs the given function. If the function returns an error, the -// transaction is rolled back. Otherwise, the transaction is committed. -// -// If the provider is configured to use versioning, this function also inserts or deletes the -// migration version. -func (p *Provider) beginTx( - ctx context.Context, - conn *sql.Conn, - fn func(tx *sql.Tx) error, -) (retErr error) { - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - return err - } - defer func() { - if retErr != nil { - retErr = multierr.Append(retErr, tx.Rollback()) - } - }() - if err := fn(tx); err != nil { - return err - } - return tx.Commit() -} - -func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) { - p.mu.Lock() - conn, err := p.db.Conn(ctx) - if err != nil { - p.mu.Unlock() - return nil, nil, err - } - // cleanup is a function that cleans up the connection, and optionally, the session lock. - cleanup := func() error { - p.mu.Unlock() - return conn.Close() - } - if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled { - if err := l.SessionLock(ctx, conn); err != nil { - return nil, nil, multierr.Append(err, cleanup()) - } - cleanup = func() error { - p.mu.Unlock() - // Use a detached context to unlock the session. This is because the context passed to - // SessionLock may have been canceled, and we don't want to cancel the unlock. TODO(mf): - // use [context.WithoutCancel] added in go1.21 - detachedCtx := context.Background() - return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close()) - } - } - // If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't - // need the version table because there is no versioning. - if !p.cfg.disableVersioning { - if err := p.ensureVersionTable(ctx, conn); err != nil { - return nil, nil, multierr.Append(err, cleanup()) - } - } - return conn, cleanup, nil -} - -// parseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it -// will not be parsed again. -// -// Important: This function will mutate SQL migrations and is not safe for concurrent use. -func parseSQL(fsys fs.FS, debug bool, migrations []*migration) error { - for _, m := range migrations { - // If the migration is a SQL migration, and it has not been parsed, parse it. - if m.Source.Type == TypeSQL && m.SQL == nil { - parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Path, debug) - if err != nil { - return err - } - m.SQL = &sqlMigration{ - UseTx: parsed.UseTx, - UpStatements: parsed.Up, - DownStatements: parsed.Down, - } - } - } - return nil -} - -func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) { - // feat(mf): this is where we can check if the version table exists instead of trying to fetch - // from a table that may not exist. https://github.com/pressly/goose/issues/461 - res, err := p.store.GetMigration(ctx, conn, 0) - if err == nil && res != nil { - return nil - } - return p.beginTx(ctx, conn, func(tx *sql.Tx) error { - if err := p.store.CreateVersionTable(ctx, tx); err != nil { - return err - } - if p.cfg.disableVersioning { - return nil - } - return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0}) - }) -} - -type missingMigration struct { - versionID int64 - filename string -} - -// checkMissingMigrations returns a list of migrations that are missing from the database. A missing -// migration is one that has a version less than the max version in the database. -func checkMissingMigrations( - dbMigrations []*database.ListMigrationsResult, - fsMigrations []*migration, -) []missingMigration { - existing := make(map[int64]bool) - var dbMaxVersion int64 - for _, m := range dbMigrations { - existing[m.Version] = true - if m.Version > dbMaxVersion { - dbMaxVersion = m.Version - } - } - var missing []missingMigration - for _, m := range fsMigrations { - version := m.Source.Version - if !existing[version] && version < dbMaxVersion { - missing = append(missing, missingMigration{ - versionID: version, - filename: m.filename(), - }) - } - } - sort.Slice(missing, func(i, j int) bool { - return missing[i].versionID < missing[j].versionID - }) - return missing -} - -// getMigration returns the migration with the given version. If no migration is found, then -// ErrVersionNotFound is returned. -func (p *Provider) getMigration(version int64) (*migration, error) { - for _, m := range p.migrations { - if m.Source.Version == version { - return m, nil - } - } - return nil, ErrVersionNotFound -} - -func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) { - m, err := p.getMigration(version) - if err != nil { - return nil, err - } - - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return nil, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - - result, err := p.store.GetMigration(ctx, conn, version) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return nil, err - } - // If the migration has already been applied, return an error, unless the migration is being - // applied in the opposite direction. In that case, we allow the migration to be applied again. - if result != nil && direction { - return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) - } - - d := sqlparser.DirectionDown - if direction { - d = sqlparser.DirectionUp - } - results, err := p.runMigrations(ctx, conn, []*migration{m}, d, true) - if err != nil { - return nil, err - } - if len(results) == 0 { - return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) - } - return results[0], nil -} - -func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return nil, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - - // TODO(mf): add support for limit and order. Also would be nice to refactor the list query to - // support limiting the set. - - status := make([]*MigrationStatus, 0, len(p.migrations)) - for _, m := range p.migrations { - migrationStatus := &MigrationStatus{ - Source: m.Source, - State: StatePending, - } - dbResult, err := p.store.GetMigration(ctx, conn, m.Source.Version) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return nil, err - } - if dbResult != nil { - migrationStatus.State = StateApplied - migrationStatus.AppliedAt = dbResult.Timestamp - } - status = append(status, migrationStatus) - } - - return status, nil -} - -func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) { - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return 0, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - - res, err := p.store.ListMigrations(ctx, conn) - if err != nil { - return 0, err - } - if len(res) == 0 { - return 0, nil - } - sort.Slice(res, func(i, j int) bool { - return res[i].Version > res[j].Version - }) - return res[0].Version, nil -} diff --git a/internal/provider/types.go b/internal/provider/types.go deleted file mode 100644 index 3480c7498..000000000 --- a/internal/provider/types.go +++ /dev/null @@ -1,70 +0,0 @@ -package provider - -import ( - "fmt" - "time" -) - -// MigrationType is the type of migration. -type MigrationType int - -const ( - TypeGo MigrationType = iota + 1 - TypeSQL -) - -func (t MigrationType) String() string { - switch t { - case TypeGo: - return "go" - case TypeSQL: - return "sql" - default: - // This should never happen. - return fmt.Sprintf("unknown (%d)", t) - } -} - -// Source represents a single migration source. -// -// The Path field may be empty if the migration was registered manually. This is typically the case -// for Go migrations registered using the [WithGoMigration] option. -type Source struct { - Type MigrationType - Path string - Version int64 -} - -// MigrationResult is the result of a single migration operation. -type MigrationResult struct { - Source Source - Duration time.Duration - Direction string - // Empty indicates no action was taken during the migration, but it was still versioned. For - // SQL, it means no statements; for Go, it's a nil function. - Empty bool - // Error is only set if the migration failed. - Error error -} - -// State represents the state of a migration. -type State string - -const ( - // StatePending is a migration that exists on the filesystem, but not in the database. - StatePending State = "pending" - // StateApplied is a migration that has been applied to the database and exists on the - // filesystem. - StateApplied State = "applied" - - // TODO(mf): we could also add a third state for untracked migrations. This would be useful for - // migrations that were manually applied to the database, but not versioned. Or the Source was - // deleted, but the migration still exists in the database. StateUntracked State = "untracked" -) - -// MigrationStatus represents the status of a single migration. -type MigrationStatus struct { - Source Source - State State - AppliedAt time.Time -} diff --git a/migration.go b/migration.go index be23d038f..d732504cd 100644 --- a/migration.go +++ b/migration.go @@ -18,22 +18,36 @@ import ( // Both up and down functions may be nil, in which case the migration will be recorded in the // versions table but no functions will be run. This is useful for recording (up) or deleting (down) // a version without running any functions. See [GoFunc] for more details. -func NewGoMigration(version int64, up, down *GoFunc) Migration { - m := Migration{ +func NewGoMigration(version int64, up, down *GoFunc) *Migration { + m := &Migration{ Type: TypeGo, Registered: true, Version: version, Next: -1, Previous: -1, - goUp: up, - goDown: down, + goUp: &GoFunc{Mode: TransactionEnabled}, + goDown: &GoFunc{Mode: TransactionEnabled}, construct: true, } + updateMode := func(f *GoFunc) *GoFunc { + // infer mode from function + if f.Mode == 0 { + if f.RunTx != nil && f.RunDB == nil { + f.Mode = TransactionEnabled + } + if f.RunTx == nil && f.RunDB != nil { + f.Mode = TransactionDisabled + } + } + return f + } // To maintain backwards compatibility, we set ALL legacy functions. In a future major version, // we will remove these fields in favor of [GoFunc]. // // Note, this function does not do any validation. Validation is lazily done when the migration // is registered. if up != nil { + m.goUp = updateMode(up) + if up.RunDB != nil { m.UpFnNoTxContext = up.RunDB // func(context.Context, *sql.DB) error m.UpFnNoTx = withoutContext(up.RunDB) // func(*sql.DB) error @@ -45,6 +59,8 @@ func NewGoMigration(version int64, up, down *GoFunc) Migration { } } if down != nil { + m.goDown = updateMode(down) + if down.RunDB != nil { m.DownFnNoTxContext = down.RunDB // func(context.Context, *sql.DB) error m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error @@ -55,12 +71,6 @@ func NewGoMigration(version int64, up, down *GoFunc) Migration { m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error } } - if m.goUp == nil { - m.goUp = &GoFunc{Mode: TransactionEnabled} - } - if m.goDown == nil { - m.goDown = &GoFunc{Mode: TransactionEnabled} - } return m } @@ -76,10 +86,6 @@ type Migration struct { UpFnContext, DownFnContext GoMigrationContext UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext - // These fields are used internally by goose and users are not expected to set them. Instead, - // use [NewGoMigration] to create a new go migration. - construct bool - goUp, goDown *GoFunc // These fields will be removed in a future major version. They are here for backwards // compatibility and are an implementation detail. @@ -98,6 +104,26 @@ type Migration struct { DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead. noVersioning bool + + // These fields are used internally by goose and users are not expected to set them. Instead, + // use [NewGoMigration] to create a new go migration. + construct bool + goUp, goDown *GoFunc + + sql sqlMigration +} + +type sqlMigration struct { + // The Parsed field is used to track whether the SQL migration has been parsed. It serves as an + // optimization to avoid parsing migrations that may never be needed. Typically, migrations are + // incremental, and users often run only the most recent ones, making parsing of prior + // migrations unnecessary in most cases. + Parsed bool + + // Parsed must be set to true before the following fields are used. + UseTx bool + Up []string + Down []string } // GoFunc represents a Go migration function. diff --git a/osfs.go b/osfs.go index e64f7706b..420f95ff9 100644 --- a/osfs.go +++ b/osfs.go @@ -18,3 +18,11 @@ func (osFS) Stat(name string) (fs.FileInfo, error) { return os.Stat(filepath.Fro func (osFS) ReadFile(name string) ([]byte, error) { return os.ReadFile(filepath.FromSlash(name)) } func (osFS) Glob(pattern string) ([]string, error) { return filepath.Glob(filepath.FromSlash(pattern)) } + +type noopFS struct{} + +var _ fs.FS = noopFS{} + +func (f noopFS) Open(name string) (fs.File, error) { + return nil, os.ErrNotExist +} diff --git a/internal/provider/provider.go b/provider.go similarity index 84% rename from internal/provider/provider.go rename to provider.go index bd68e2ff1..d67a01bfb 100644 --- a/internal/provider/provider.go +++ b/provider.go @@ -1,4 +1,4 @@ -package provider +package goose import ( "context" @@ -12,7 +12,7 @@ import ( "github.com/pressly/goose/v3/database" ) -// Provider is a goose migration provider. +// Provider is a goose migration goose. type Provider struct { // mu protects all accesses to the provider and must be held when calling operations on the // database. @@ -24,10 +24,10 @@ type Provider struct { store database.Store // migrations are ordered by version in ascending order. - migrations []*migration + migrations []*Migration } -// NewProvider returns a new goose Provider. +// NewProvider returns a new goose goose. // // The caller is responsible for matching the database dialect with the database/sql driver. For // example, if the database dialect is "postgres", the database/sql driver could be @@ -40,7 +40,7 @@ type Provider struct { // However, it is possible to use a different "filesystem", such as [embed.FS] or filter out // migrations using [fs.Sub]. // -// See [ProviderOption] for more information on configuring the provider. +// See [ProviderOption] for more information on configuring the goose. // // Unless otherwise specified, all methods on Provider are safe for concurrent use. // @@ -53,8 +53,9 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi fsys = noopFS{} } cfg := config{ - registered: make(map[int64]*goMigration), - excludes: make(map[string]bool), + registered: make(map[int64]*Migration), + excludePaths: make(map[string]bool), + excludeVersions: make(map[int64]bool), } for _, opt := range opts { if err := opt.apply(&cfg); err != nil { @@ -90,7 +91,7 @@ func newProvider( store database.Store, fsys fs.FS, cfg config, - global map[int64]*MigrationCopy, + global map[int64]*Migration, ) (*Provider, error) { // Collect migrations from the filesystem and merge with registered migrations. // @@ -100,54 +101,24 @@ func newProvider( // TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to // return an error if there are any SQL parsing errors. This adds a bit overhead to startup // though, so we should make it optional. - filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludes) + filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions) if err != nil { return nil, err } - registered := make(map[int64]*goMigration) + versionToGoMigration := make(map[int64]*Migration) // Add user-registered Go migrations. for version, m := range cfg.registered { - registered[version] = newGoMigration("", m.up, m.down) + versionToGoMigration[version] = m } // Add init() functions. This is a bit ugly because we need to convert from the old Migration // struct to the new goMigration struct. for version, m := range global { - if _, ok := registered[version]; ok { - return nil, fmt.Errorf("go migration with version %d already registered", version) + if _, ok := versionToGoMigration[version]; ok { + return nil, fmt.Errorf("global go migration with version %d already registered with provider", version) } - if m == nil { - return nil, errors.New("registered migration with nil init function") - } - g := newGoMigration(m.Source, nil, nil) - if m.UpFnContext != nil && m.UpFnNoTxContext != nil { - return nil, errors.New("registered migration with both UpFnContext and UpFnNoTxContext") - } - if m.DownFnContext != nil && m.DownFnNoTxContext != nil { - return nil, errors.New("registered migration with both DownFnContext and DownFnNoTxContext") - } - // Up - if m.UpFnContext != nil { - g.up = &GoMigrationFunc{ - Run: m.UpFnContext, - } - } else if m.UpFnNoTxContext != nil { - g.up = &GoMigrationFunc{ - RunNoTx: m.UpFnNoTxContext, - } - } - // Down - if m.DownFnContext != nil { - g.down = &GoMigrationFunc{ - Run: m.DownFnContext, - } - } else if m.DownFnNoTxContext != nil { - g.down = &GoMigrationFunc{ - RunNoTx: m.DownFnNoTxContext, - } - } - registered[version] = g + versionToGoMigration[version] = m } - migrations, err := merge(filesystemSources, registered) + migrations, err := merge(filesystemSources, versionToGoMigration) if err != nil { return nil, err } @@ -181,7 +152,11 @@ func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { func (p *Provider) ListSources() []Source { sources := make([]Source, 0, len(p.migrations)) for _, m := range p.migrations { - sources = append(sources, m.Source) + sources = append(sources, Source{ + Type: m.Type, + Path: m.Source, + Version: m.Version, + }) } return sources } diff --git a/internal/provider/collect.go b/provider_collect.go similarity index 75% rename from internal/provider/collect.go rename to provider_collect.go index 345da0d06..5f14fff13 100644 --- a/internal/provider/collect.go +++ b/provider_collect.go @@ -1,15 +1,12 @@ -package provider +package goose import ( "errors" "fmt" "io/fs" - "os" "path/filepath" "sort" "strings" - - "github.com/pressly/goose/v3" ) // fileSources represents a collection of migration files on the filesystem. @@ -18,25 +15,6 @@ type fileSources struct { goSources []Source } -// TODO(mf): remove? -func (s *fileSources) lookup(t MigrationType, version int64) *Source { - switch t { - case TypeGo: - for _, source := range s.goSources { - if source.Version == version { - return &source - } - } - case TypeSQL: - for _, source := range s.sqlSources { - if source.Version == version { - return &source - } - } - } - return nil -} - // collectFilesystemSources scans the file system for migration files that have a numeric prefix // (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may // be nil, in which case an empty fileSources is returned. @@ -46,7 +24,12 @@ func (s *fileSources) lookup(t MigrationType, version int64) *Source { // // This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects // migration sources from the filesystem. -func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) { +func collectFilesystemSources( + fsys fs.FS, + strict bool, + excludePaths map[string]bool, + excludeVersions map[int64]bool, +) (*fileSources, error) { if fsys == nil { return new(fileSources), nil } @@ -62,8 +45,11 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool) } for _, fullpath := range files { base := filepath.Base(fullpath) - // Skip explicit excludes or Go test files. - if excludes[base] || strings.HasSuffix(base, "_test.go") { + if strings.HasSuffix(base, "_test.go") { + continue + } + if excludePaths[base] { + // TODO(mf): log this? continue } // If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use @@ -71,13 +57,17 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool) // filenames, but still have versioned migrations within the same directory. For // example, a user could have a helpers.go file which contains unexported helper // functions for migrations. - version, err := goose.NumericComponent(base) + version, err := NumericComponent(base) if err != nil { if strict { return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err) } continue } + if excludeVersions[version] { + // TODO: log this? + continue + } // Ensure there are no duplicate versions. if existing, ok := versionToBaseLookup[version]; ok { return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", @@ -101,7 +91,7 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool) }) default: // Should never happen since we already filtered out all other file types. - return nil, fmt.Errorf("unknown migration type: %s", base) + return nil, fmt.Errorf("invalid file extension: %q", base) } // Add the version to the lookup map. versionToBaseLookup[version] = base @@ -110,15 +100,25 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool) return sources, nil } -func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration, error) { - var migrations []*migration - migrationLookup := make(map[int64]*migration) +func newSQLMigration(source Source) *Migration { + return &Migration{ + Type: source.Type, + Version: source.Version, + Source: source.Path, + construct: true, + Next: -1, Previous: -1, + sql: sqlMigration{ + Parsed: false, // SQL migrations are parsed lazily. + }, + } +} + +func merge(sources *fileSources, registerd map[int64]*Migration) ([]*Migration, error) { + var migrations []*Migration + migrationLookup := make(map[int64]*Migration) // Add all SQL migrations to the list of migrations. for _, source := range sources.sqlSources { - m := &migration{ - Source: source, - SQL: nil, // SQL migrations are parsed lazily. - } + m := newSQLMigration(source) migrations = append(migrations, m) migrationLookup[source.Version] = m } @@ -147,38 +147,24 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration // wholesale as part of migrations. This allows users to build a custom binary that only embeds // the SQL migration files. for version, r := range registerd { - fullpath := r.fullpath - if fullpath == "" { - if s := sources.lookup(TypeGo, version); s != nil { - fullpath = s.Path - } - } // Ensure there are no duplicate versions. if existing, ok := migrationLookup[version]; ok { - fullpath := r.fullpath + fullpath := r.Source if fullpath == "" { fullpath = "manually registered (no source)" } return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", version, - existing.Source.Path, + existing.Source, fullpath, ) } - m := &migration{ - Source: Source{ - Type: TypeGo, - Path: fullpath, // May be empty if migration was registered manually. - Version: version, - }, - Go: r, - } - migrations = append(migrations, m) - migrationLookup[version] = m + migrations = append(migrations, r) + migrationLookup[version] = r } // Sort migrations by version in ascending order. sort.Slice(migrations, func(i, j int) bool { - return migrations[i].Source.Version < migrations[j].Source.Version + return migrations[i].Version < migrations[j].Version }) return migrations, nil } @@ -203,11 +189,3 @@ func unregisteredError(unregistered []string) error { return errors.New(b.String()) } - -type noopFS struct{} - -var _ fs.FS = noopFS{} - -func (f noopFS) Open(name string) (fs.File, error) { - return nil, os.ErrNotExist -} diff --git a/internal/provider/collect_test.go b/provider_collect_test.go similarity index 77% rename from internal/provider/collect_test.go rename to provider_collect_test.go index e696ab005..d3ee08a10 100644 --- a/internal/provider/collect_test.go +++ b/provider_collect_test.go @@ -1,4 +1,4 @@ -package provider +package goose import ( "io/fs" @@ -12,21 +12,21 @@ import ( func TestCollectFileSources(t *testing.T) { t.Parallel() t.Run("nil_fsys", func(t *testing.T) { - sources, err := collectFilesystemSources(nil, false, nil) + sources, err := collectFilesystemSources(nil, false, nil, nil) check.NoError(t, err) check.Bool(t, sources != nil, true) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) }) t.Run("noop_fsys", func(t *testing.T) { - sources, err := collectFilesystemSources(noopFS{}, false, nil) + sources, err := collectFilesystemSources(noopFS{}, false, nil, nil) check.NoError(t, err) check.Bool(t, sources != nil, true) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) }) t.Run("empty_fsys", func(t *testing.T) { - sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil) + sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil) check.NoError(t, err) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) @@ -37,19 +37,19 @@ func TestCollectFileSources(t *testing.T) { "00000_foo.sql": sqlMapFile, } // strict disable - should not error - sources, err := collectFilesystemSources(mapFS, false, nil) + sources, err := collectFilesystemSources(mapFS, false, nil, nil) check.NoError(t, err) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) // strict enabled - should error - _, err = collectFilesystemSources(mapFS, true, nil) + _, err = collectFilesystemSources(mapFS, true, nil, nil) check.HasError(t, err) check.Contains(t, err.Error(), "migration version must be greater than zero") }) t.Run("collect", func(t *testing.T) { fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") check.NoError(t, err) - sources, err := collectFilesystemSources(fsys, false, nil) + sources, err := collectFilesystemSources(fsys, false, nil, nil) check.NoError(t, err) check.Number(t, len(sources.sqlSources), 4) check.Number(t, len(sources.goSources), 0) @@ -76,6 +76,7 @@ func TestCollectFileSources(t *testing.T) { "00002_bar.sql": true, "00110_qux.sql": true, }, + nil, ) check.NoError(t, err) check.Number(t, len(sources.sqlSources), 2) @@ -96,7 +97,7 @@ func TestCollectFileSources(t *testing.T) { mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")} fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - _, err = collectFilesystemSources(fsys, true, nil) + _, err = collectFilesystemSources(fsys, true, nil, nil) check.HasError(t, err) check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`) }) @@ -108,7 +109,7 @@ func TestCollectFileSources(t *testing.T) { "4_qux.sql": sqlMapFile, "5_foo_test.go": {Data: []byte(`package goose_test`)}, } - sources, err := collectFilesystemSources(mapFS, false, nil) + sources, err := collectFilesystemSources(mapFS, false, nil, nil) check.NoError(t, err) check.Number(t, len(sources.sqlSources), 4) check.Number(t, len(sources.goSources), 0) @@ -123,7 +124,7 @@ func TestCollectFileSources(t *testing.T) { "no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)}, "some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)}, } - sources, err := collectFilesystemSources(mapFS, false, nil) + sources, err := collectFilesystemSources(mapFS, false, nil, nil) check.NoError(t, err) check.Number(t, len(sources.sqlSources), 2) check.Number(t, len(sources.goSources), 1) @@ -142,7 +143,7 @@ func TestCollectFileSources(t *testing.T) { "001_foo.sql": sqlMapFile, "01_bar.sql": sqlMapFile, } - _, err := collectFilesystemSources(mapFS, false, nil) + _, err := collectFilesystemSources(mapFS, false, nil, nil) check.HasError(t, err) check.Contains(t, err.Error(), "found duplicate migration version 1") }) @@ -158,7 +159,7 @@ func TestCollectFileSources(t *testing.T) { t.Helper() f, err := fs.Sub(mapFS, dirpath) check.NoError(t, err) - got, err := collectFilesystemSources(f, false, nil) + got, err := collectFilesystemSources(f, false, nil, nil) check.NoError(t, err) check.Number(t, len(got.sqlSources), len(sqlSources)) check.Number(t, len(got.goSources), 0) @@ -194,27 +195,21 @@ func TestMerge(t *testing.T) { } fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - sources, err := collectFilesystemSources(fsys, false, nil) + sources, err := collectFilesystemSources(fsys, false, nil, nil) check.NoError(t, err) check.Equal(t, len(sources.sqlSources), 1) check.Equal(t, len(sources.goSources), 2) - src1 := sources.lookup(TypeSQL, 1) - check.Bool(t, src1 != nil, true) - src2 := sources.lookup(TypeGo, 2) - check.Bool(t, src2 != nil, true) - src3 := sources.lookup(TypeGo, 3) - check.Bool(t, src3 != nil, true) - t.Run("valid", func(t *testing.T) { - migrations, err := merge(sources, map[int64]*goMigration{ - 2: newGoMigration("", nil, nil), - 3: newGoMigration("", nil, nil), - }) + registered := map[int64]*Migration{ + 2: NewGoMigration(2, nil, nil), + 3: NewGoMigration(3, nil, nil), + } + migrations, err := merge(sources, registered) check.NoError(t, err) check.Number(t, len(migrations), 3) assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) - assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) - assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3)) + assertMigration(t, migrations[1], newSource(TypeGo, "", 2)) + assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) }) t.Run("unregistered_all", func(t *testing.T) { _, err := merge(sources, nil) @@ -224,18 +219,16 @@ func TestMerge(t *testing.T) { check.Contains(t, err.Error(), "00003_baz.go") }) t.Run("unregistered_some", func(t *testing.T) { - _, err := merge(sources, map[int64]*goMigration{ - 2: newGoMigration("", nil, nil), - }) + _, err := merge(sources, map[int64]*Migration{2: NewGoMigration(2, nil, nil)}) check.HasError(t, err) check.Contains(t, err.Error(), "error: detected 1 unregistered Go file") check.Contains(t, err.Error(), "00003_baz.go") }) t.Run("duplicate_sql", func(t *testing.T) { - _, err := merge(sources, map[int64]*goMigration{ - 1: newGoMigration("", nil, nil), // duplicate. SQL already exists. - 2: newGoMigration("", nil, nil), - 3: newGoMigration("", nil, nil), + _, err := merge(sources, map[int64]*Migration{ + 1: NewGoMigration(1, nil, nil), // duplicate. SQL already exists. + 2: NewGoMigration(2, nil, nil), + 3: NewGoMigration(3, nil, nil), }) check.HasError(t, err) check.Contains(t, err.Error(), "found duplicate migration version 1") @@ -250,13 +243,13 @@ func TestMerge(t *testing.T) { } fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - sources, err := collectFilesystemSources(fsys, false, nil) + sources, err := collectFilesystemSources(fsys, false, nil, nil) check.NoError(t, err) t.Run("unregistered_all", func(t *testing.T) { - migrations, err := merge(sources, map[int64]*goMigration{ - 3: newGoMigration("", nil, nil), + migrations, err := merge(sources, map[int64]*Migration{ + 3: NewGoMigration(3, nil, nil), // 4 is missing - 6: newGoMigration("", nil, nil), + 6: NewGoMigration(6, nil, nil), }) check.NoError(t, err) check.Number(t, len(migrations), 5) @@ -274,20 +267,20 @@ func TestMerge(t *testing.T) { } fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - sources, err := collectFilesystemSources(fsys, false, nil) + sources, err := collectFilesystemSources(fsys, false, nil, nil) check.NoError(t, err) t.Run("unregistered_all", func(t *testing.T) { - migrations, err := merge(sources, map[int64]*goMigration{ + migrations, err := merge(sources, map[int64]*Migration{ // This is the only Go file on disk. - 2: newGoMigration("", nil, nil), + 2: NewGoMigration(2, nil, nil), // These are not on disk. Explicitly registered. - 3: newGoMigration("", nil, nil), - 6: newGoMigration("", nil, nil), + 3: NewGoMigration(3, nil, nil), + 6: NewGoMigration(6, nil, nil), }) check.NoError(t, err) check.Number(t, len(migrations), 4) assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) - assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[1], newSource(TypeGo, "", 2)) assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) assertMigration(t, migrations[3], newSource(TypeGo, "", 6)) }) @@ -308,15 +301,15 @@ func TestCheckMissingMigrations(t *testing.T) { {Version: 5}, {Version: 7}, // <-- database max version_id } - fsMigrations := []*migration{ - newMigrationVersion(1), - newMigrationVersion(2), // missing migration - newMigrationVersion(3), - newMigrationVersion(4), - newMigrationVersion(5), - newMigrationVersion(6), // missing migration - newMigrationVersion(7), // ----- database max version_id ----- - newMigrationVersion(8), // new migration + fsMigrations := []*Migration{ + newSQLMigration(Source{Version: 1}), + newSQLMigration(Source{Version: 2}), // missing migration + newSQLMigration(Source{Version: 3}), + newSQLMigration(Source{Version: 4}), + newSQLMigration(Source{Version: 5}), + newSQLMigration(Source{Version: 6}), // missing migration + newSQLMigration(Source{Version: 7}), // ----- database max version_id ----- + newSQLMigration(Source{Version: 8}), // new migration } got := checkMissingMigrations(dbMigrations, fsMigrations) check.Number(t, len(got), 2) @@ -334,9 +327,9 @@ func TestCheckMissingMigrations(t *testing.T) { {Version: 5}, {Version: 2}, } - fsMigrations := []*migration{ - newMigrationVersion(3), // new migration - newMigrationVersion(4), // new migration + fsMigrations := []*Migration{ + NewGoMigration(3, nil, nil), // new migration + NewGoMigration(4, nil, nil), // new migration } got := checkMissingMigrations(dbMigrations, fsMigrations) check.Number(t, len(got), 2) @@ -345,24 +338,19 @@ func TestCheckMissingMigrations(t *testing.T) { }) } -func newMigrationVersion(version int64) *migration { - return &migration{ - Source: Source{ - Version: version, - }, - } -} - -func assertMigration(t *testing.T, got *migration, want Source) { +func assertMigration(t *testing.T, got *Migration, want Source) { t.Helper() - check.Equal(t, got.Source, want) - switch got.Source.Type { + check.Equal(t, got.Type, want.Type) + check.Equal(t, got.Version, want.Version) + check.Equal(t, got.Source, want.Path) + switch got.Type { case TypeGo: - check.Bool(t, got.Go != nil, true) + check.Bool(t, got.goUp != nil, true) + check.Bool(t, got.goDown != nil, true) case TypeSQL: - check.Bool(t, got.SQL == nil, true) + check.Bool(t, got.sql.Parsed, false) default: - t.Fatalf("unknown migration type: %s", got.Source.Type) + t.Fatalf("unknown migration type: %s", got.Type) } } diff --git a/internal/provider/errors.go b/provider_errors.go similarity index 89% rename from internal/provider/errors.go rename to provider_errors.go index 16cdd3fb7..ecaf89801 100644 --- a/internal/provider/errors.go +++ b/provider_errors.go @@ -1,4 +1,4 @@ -package provider +package goose import ( "errors" @@ -15,9 +15,6 @@ var ( // ErrNoMigrations is returned by [NewProvider] when no migrations are found. ErrNoMigrations = errors.New("no migrations found") - - // ErrNoNextVersion when the next migration version is not found. - ErrNoNextVersion = errors.New("no next version found") ) // PartialError is returned when a migration fails, but some migrations already got applied. diff --git a/provider_impl.go b/provider_impl.go new file mode 100644 index 000000000..3e838d304 --- /dev/null +++ b/provider_impl.go @@ -0,0 +1,206 @@ +package goose + +import ( + "context" + "database/sql" + "errors" + "fmt" + "sort" + + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" +) + +func (p *Provider) up( + ctx context.Context, + upByOne bool, + version int64, +) (_ []*MigrationResult, retErr error) { + if version < 1 { + return nil, errors.New("version must be greater than zero") + } + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + if len(p.migrations) == 0 { + return nil, nil + } + var apply []*Migration + if p.cfg.disableVersioning { + apply = p.migrations + } else { + // optimize(mf): Listing all migrations from the database isn't great. This is only required + // to support the allow missing (out-of-order) feature. For users that don't use this + // feature, we could just query the database for the current max version and then apply + // migrations greater than that version. + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return nil, err + } + if len(dbMigrations) == 0 { + return nil, errMissingZeroVersion + } + apply, err = p.resolveUpMigrations(dbMigrations, version) + if err != nil { + return nil, err + } + } + // feat(mf): this is where can (optionally) group multiple migrations to be run in a single + // transaction. The default is to apply each migration sequentially on its own. + // https://github.com/pressly/goose/issues/222 + // + // Careful, we can't use a single transaction for all migrations because some may have to be run + // in their own transaction. + return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne) +} + +func (p *Provider) down( + ctx context.Context, + downByOne bool, + version int64, +) (_ []*MigrationResult, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + if len(p.migrations) == 0 { + return nil, nil + } + if p.cfg.disableVersioning { + downMigrations := p.migrations + if downByOne { + last := p.migrations[len(p.migrations)-1] + downMigrations = []*Migration{last} + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) + } + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return nil, err + } + if len(dbMigrations) == 0 { + return nil, errMissingZeroVersion + } + if dbMigrations[0].Version == 0 { + return nil, nil + } + var downMigrations []*Migration + for _, dbMigration := range dbMigrations { + if dbMigration.Version <= version { + break + } + m, err := p.getMigration(dbMigration.Version) + if err != nil { + return nil, err + } + downMigrations = append(downMigrations, m) + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) +} + +func (p *Provider) apply( + ctx context.Context, + version int64, + direction bool, +) (_ *MigrationResult, retErr error) { + m, err := p.getMigration(version) + if err != nil { + return nil, err + } + + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + result, err := p.store.GetMigration(ctx, conn, version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + // If the migration has already been applied, return an error, unless the migration is being + // applied in the opposite direction. In that case, we allow the migration to be applied again. + if result != nil && direction { + return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) + } + + d := sqlparser.DirectionDown + if direction { + d = sqlparser.DirectionUp + } + results, err := p.runMigrations(ctx, conn, []*Migration{m}, d, true) + if err != nil { + return nil, err + } + if len(results) == 0 { + return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) + } + return results[0], nil +} + +func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + // TODO(mf): add support for limit and order. Also would be nice to refactor the list query to + // support limiting the set. + + status := make([]*MigrationStatus, 0, len(p.migrations)) + for _, m := range p.migrations { + migrationStatus := &MigrationStatus{ + Source: Source{ + Type: m.Type, + Path: m.Source, + Version: m.Version, + }, + State: StatePending, + } + dbResult, err := p.store.GetMigration(ctx, conn, m.Version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + if dbResult != nil { + migrationStatus.State = StateApplied + migrationStatus.AppliedAt = dbResult.Timestamp + } + status = append(status, migrationStatus) + } + + return status, nil +} + +func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return 0, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + res, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return 0, err + } + if len(res) == 0 { + return 0, nil + } + sort.Slice(res, func(i, j int) bool { + return res[i].Version > res[j].Version + }) + return res[0].Version, nil +} diff --git a/provider_migrate.go b/provider_migrate.go new file mode 100644 index 000000000..1386f7406 --- /dev/null +++ b/provider_migrate.go @@ -0,0 +1,90 @@ +package goose + +import ( + "context" + "database/sql" + "fmt" + + "github.com/pressly/goose/v3/database" +) + +func (m *Migration) useTx(direction bool) bool { + switch m.Type { + case TypeGo: + if direction && m.goUp.Mode == TransactionEnabled { + return true + } + if !direction && m.goDown.Mode == TransactionEnabled { + return true + } + return false + case TypeSQL: + return m.sql.UseTx + } + // This should never happen. + panic(fmt.Sprintf("invalid migration type: %q", m.Type)) +} + +func (m *Migration) isEmpty(direction bool) bool { + switch m.Type { + case TypeGo: + if direction { + return m.goUp.RunTx == nil && m.goUp.RunDB == nil + } + return m.goDown.RunTx == nil && m.goDown.RunDB == nil + case TypeSQL: + if direction { + return len(m.sql.Up) == 0 + } + return len(m.sql.Down) == 0 + } + // This should never happen. + panic(fmt.Sprintf("invalid migration type: %q", m.Type)) +} + +func (m *Migration) apply(ctx context.Context, db database.DBTxConn, direction bool) error { + switch m.Type { + case TypeGo: + return runGo(ctx, db, m, direction) + case TypeSQL: + if direction { + return runSQL(ctx, db, m.sql.Up) + } + return runSQL(ctx, db, m.sql.Down) + } + // This should never happen. + panic(fmt.Sprintf("invalid migration type: %q", m.Type)) +} + +func runGo(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { + switch db := db.(type) { + case *sql.Conn: + return fmt.Errorf("go migrations are not supported with *sql.Conn") + case *sql.DB: + if direction && m.goUp.RunDB != nil { + return m.goUp.RunDB(ctx, db) + } + if !direction && m.goDown.RunDB != nil { + return m.goDown.RunDB(ctx, db) + } + return nil + case *sql.Tx: + if direction && m.goUp.RunTx != nil { + return m.goUp.RunTx(ctx, db) + } + if !direction && m.goDown.RunTx != nil { + return m.goDown.RunTx(ctx, db) + } + return nil + } + return fmt.Errorf("invalid database connection type: %T", db) +} + +func runSQL(ctx context.Context, db database.DBTxConn, statements []string) error { + for _, stmt := range statements { + if _, err := db.ExecContext(ctx, stmt); err != nil { + return err + } + } + return nil +} diff --git a/internal/provider/provider_options.go b/provider_options.go similarity index 63% rename from internal/provider/provider_options.go rename to provider_options.go index dd29ee4a9..03c6f7c62 100644 --- a/internal/provider/provider_options.go +++ b/provider_options.go @@ -1,8 +1,6 @@ -package provider +package goose import ( - "context" - "database/sql" "errors" "fmt" @@ -12,12 +10,11 @@ import ( const ( // DefaultTablename is the default name of the database table used to track history of applied - // migrations. It can be overridden using the [WithTableName] option when creating a new - // provider. + // migrations. It can be overridden using the [WithTableName] option when creating a new goose. DefaultTablename = "goose_db_version" ) -// ProviderOption is a configuration option for a goose provider. +// ProviderOption is a configuration option for a goose goose. type ProviderOption interface { apply(*config) error } @@ -84,60 +81,48 @@ func WithSessionLocker(locker lock.SessionLocker) ProviderOption { }) } -// WithExcludes excludes the given file names from the list of migrations. -// -// If WithExcludes is called multiple times, the list of excludes is merged. -func WithExcludes(excludes []string) ProviderOption { +// WithExcludeNames excludes the given file name from the list of migrations. If called multiple +// times, the list of excludes is merged. +func WithExcludeNames(excludes []string) ProviderOption { return configFunc(func(c *config) error { for _, name := range excludes { - c.excludes[name] = true + if _, ok := c.excludePaths[name]; ok { + return fmt.Errorf("duplicate exclude file name: %s", name) + } + c.excludePaths[name] = true } return nil }) } -// GoMigrationFunc is a user-defined Go migration, registered using the option [WithGoMigration]. -type GoMigrationFunc struct { - // One of the following must be set: - Run func(context.Context, *sql.Tx) error - // -- OR -- - RunNoTx func(context.Context, *sql.DB) error +// WithExcludeVersions excludes the given versions from the list of migrations. If called multiple +// times, the list of excludes is merged. +func WithExcludeVersions(versions []int64) ProviderOption { + return configFunc(func(c *config) error { + for _, version := range versions { + if _, ok := c.excludeVersions[version]; ok { + return fmt.Errorf("duplicate excludes version: %d", version) + } + c.excludeVersions[version] = true + } + return nil + }) } -// WithGoMigration registers a Go migration with the given version. +// WithGoMigrations registers Go migrations with the provider. If a Go migration with the same +// version has already been registered, an error will be returned. // -// If WithGoMigration is called multiple times with the same version, an error is returned. Both up -// and down [GoMigration] may be nil. But if set, exactly one of Run or RunNoTx functions must be -// set. -func WithGoMigration(version int64, up, down *GoMigrationFunc) ProviderOption { +// Go migrations must be constructed using the [NewGoMigration] function. +func WithGoMigrations(migrations ...*Migration) ProviderOption { return configFunc(func(c *config) error { - if version < 1 { - return errors.New("version must be greater than zero") - } - if _, ok := c.registered[version]; ok { - return fmt.Errorf("go migration with version %d already registered", version) - } - // Allow nil up/down functions. This enables users to apply "no-op" migrations, while - // versioning them. - if up != nil { - if up.Run == nil && up.RunNoTx == nil { - return fmt.Errorf("go migration with version %d must have an up function", version) - } - if up.Run != nil && up.RunNoTx != nil { - return fmt.Errorf("go migration with version %d must not have both an up and upNoTx function", version) + for _, m := range migrations { + if _, ok := c.registered[m.Version]; ok { + return fmt.Errorf("go migration with version %d already registered", m.Version) } - } - if down != nil { - if down.Run == nil && down.RunNoTx == nil { - return fmt.Errorf("go migration with version %d must have a down function", version) - } - if down.Run != nil && down.RunNoTx != nil { - return fmt.Errorf("go migration with version %d must not have both a down and downNoTx function", version) + if err := checkGoMigration(m); err != nil { + return fmt.Errorf("invalid go migration: %w", err) } - } - c.registered[version] = &goMigration{ - up: up, - down: down, + c.registered[m.Version] = m } return nil }) @@ -171,12 +156,13 @@ func WithDisabledVersioning(b bool) ProviderOption { type config struct { store database.Store - verbose bool - excludes map[string]bool + verbose bool + excludePaths map[string]bool + excludeVersions map[int64]bool - // Go migrations registered by the user. These will be merged/resolved with migrations from the - // filesystem and init() functions. - registered map[int64]*goMigration + // Go migrations registered by the user. These will be merged/resolved against the globally + // registered migrations. + registered map[int64]*Migration // Locking options lockEnabled bool diff --git a/internal/provider/provider_options_test.go b/provider_options_test.go similarity index 61% rename from internal/provider/provider_options_test.go rename to provider_options_test.go index e524b63ab..0575cf417 100644 --- a/internal/provider/provider_options_test.go +++ b/provider_options_test.go @@ -1,4 +1,4 @@ -package provider_test +package goose_test import ( "database/sql" @@ -6,9 +6,9 @@ import ( "testing" "testing/fstest" + "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/check" - "github.com/pressly/goose/v3/internal/provider" _ "modernc.org/sqlite" ) @@ -24,45 +24,45 @@ func TestNewProvider(t *testing.T) { } t.Run("invalid", func(t *testing.T) { // Empty dialect not allowed - _, err = provider.NewProvider("", db, fsys) + _, err = goose.NewProvider("", db, fsys) check.HasError(t, err) // Invalid dialect not allowed - _, err = provider.NewProvider("unknown-dialect", db, fsys) + _, err = goose.NewProvider("unknown-dialect", db, fsys) check.HasError(t, err) // Nil db not allowed - _, err = provider.NewProvider(database.DialectSQLite3, nil, fsys) + _, err = goose.NewProvider(database.DialectSQLite3, nil, fsys) check.HasError(t, err) // Nil fsys not allowed - _, err = provider.NewProvider(database.DialectSQLite3, db, nil) + _, err = goose.NewProvider(database.DialectSQLite3, db, nil) check.HasError(t, err) // Nil store not allowed - _, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithStore(nil)) + _, err = goose.NewProvider(database.DialectSQLite3, db, nil, goose.WithStore(nil)) check.HasError(t, err) // Cannot set both dialect and store store, err := database.NewStore(database.DialectSQLite3, "custom_table") check.NoError(t, err) - _, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithStore(store)) + _, err = goose.NewProvider(database.DialectSQLite3, db, nil, goose.WithStore(store)) check.HasError(t, err) // Multiple stores not allowed - _, err = provider.NewProvider(database.DialectSQLite3, db, nil, - provider.WithStore(store), - provider.WithStore(store), + _, err = goose.NewProvider(database.DialectSQLite3, db, nil, + goose.WithStore(store), + goose.WithStore(store), ) check.HasError(t, err) }) t.Run("valid", func(t *testing.T) { // Valid dialect, db, and fsys allowed - _, err = provider.NewProvider(database.DialectSQLite3, db, fsys) + _, err = goose.NewProvider(database.DialectSQLite3, db, fsys) check.NoError(t, err) // Valid dialect, db, fsys, and verbose allowed - _, err = provider.NewProvider(database.DialectSQLite3, db, fsys, - provider.WithVerbose(testing.Verbose()), + _, err = goose.NewProvider(database.DialectSQLite3, db, fsys, + goose.WithVerbose(testing.Verbose()), ) check.NoError(t, err) // Custom store allowed store, err := database.NewStore(database.DialectSQLite3, "custom_table") check.NoError(t, err) - _, err = provider.NewProvider("", db, nil, provider.WithStore(store)) + _, err = goose.NewProvider("", db, nil, goose.WithStore(store)) check.HasError(t, err) }) } diff --git a/provider_run.go b/provider_run.go new file mode 100644 index 000000000..10cc63364 --- /dev/null +++ b/provider_run.go @@ -0,0 +1,356 @@ +package goose + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io/fs" + "sort" + "strings" + "time" + + "github.com/pressly/goose/v3/database" + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" +) + +var ( + errMissingZeroVersion = errors.New("missing zero version migration") +) + +func (p *Provider) resolveUpMigrations( + dbVersions []*database.ListMigrationsResult, + version int64, +) ([]*Migration, error) { + var apply []*Migration + var dbMaxVersion int64 + // dbAppliedVersions is a map of all applied migrations in the database. + dbAppliedVersions := make(map[int64]bool, len(dbVersions)) + for _, m := range dbVersions { + dbAppliedVersions[m.Version] = true + if m.Version > dbMaxVersion { + dbMaxVersion = m.Version + } + } + missingMigrations := checkMissingMigrations(dbVersions, p.migrations) + // feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing + // migrations entirely. At the moment this is not supported, but leaving this comment because + // that's where that logic would be handled. + // + // For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. Not + // sure if this is a common use case, but it's possible. + if len(missingMigrations) > 0 && !p.cfg.allowMissing { + var collected []string + for _, v := range missingMigrations { + collected = append(collected, fmt.Sprintf("%d", v.versionID)) + } + msg := "migration" + if len(collected) > 1 { + msg += "s" + } + return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]", + len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","), + ) + } + for _, v := range missingMigrations { + m, err := p.getMigration(v.versionID) + if err != nil { + return nil, err + } + apply = append(apply, m) + } + // filter all migrations with a version greater than the supplied version (min) and less than or + // equal to the requested version (max). Skip any migrations that have already been applied. + for _, m := range p.migrations { + if dbAppliedVersions[m.Version] { + continue + } + if m.Version > dbMaxVersion && m.Version <= version { + apply = append(apply, m) + } + } + return apply, nil +} + +func (p *Provider) prepareMigration(fsys fs.FS, m *Migration, direction bool) error { + switch m.Type { + case TypeGo: + if m.goUp.Mode == 0 { + return errors.New("go up migration mode is not set") + } + if m.goDown.Mode == 0 { + return errors.New("go down migration mode is not set") + } + var useTx bool + if direction { + useTx = m.goUp.Mode == TransactionEnabled + } else { + useTx = m.goDown.Mode == TransactionEnabled + } + // bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, + // but are locking the database with *sql.Conn. If the caller sets max open connections to + // 1, then this will deadlock because the Go migration will try to acquire a connection from + // the pool, but the pool is exhausted because the lock is held. + // + // A potential solution is to expose a third Go register function *sql.Conn. Or continue to + // use *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is + // a bit of an edge case. For now, we guard against this scenario by checking the max open + // connections and returning an error. + if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 { + if !useTx { + return errors.New("potential deadlock detected: cannot run Go migration without a transaction when max open connections set to 1") + } + } + return nil + case TypeSQL: + if m.sql.Parsed { + return nil + } + parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source, false) + if err != nil { + return err + } + m.sql.Parsed = true + m.sql.UseTx = parsed.UseTx + m.sql.Up, m.sql.Down = parsed.Up, parsed.Down + return nil + } + return fmt.Errorf("invalid migration type: %+v", m) +} + +// runMigrations runs migrations sequentially in the given direction. If the migrations list is +// empty, return nil without error. +func (p *Provider) runMigrations( + ctx context.Context, + conn *sql.Conn, + migrations []*Migration, + direction sqlparser.Direction, + byOne bool, +) ([]*MigrationResult, error) { + if len(migrations) == 0 { + return nil, nil + } + apply := migrations + if byOne { + apply = migrations[:1] + } + + // SQL migrations are lazily parsed in both directions. This is done before attempting to run + // any migrations to catch errors early and prevent leaving the database in an incomplete state. + + for _, m := range apply { + if err := p.prepareMigration(p.fsys, m, direction.ToBool()); err != nil { + return nil, err + } + } + + // feat(mf): If we decide to add support for advisory locks at the transaction level, this may + // be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe + // to run in a transaction. + + // feat(mf): this is where we can (optionally) group multiple migrations to be run in a single + // transaction. The default is to apply each migration sequentially on its own. See the + // following issues for more details: + // - https://github.com/pressly/goose/issues/485 + // - https://github.com/pressly/goose/issues/222 + + var results []*MigrationResult + for _, m := range apply { + current := &MigrationResult{ + Source: Source{ + Type: m.Type, + Path: m.Source, + Version: m.Version, + }, + Direction: direction.String(), + Empty: m.isEmpty(direction.ToBool()), + } + start := time.Now() + if err := p.runIndividually(ctx, conn, m, direction.ToBool()); err != nil { + // TODO(mf): we should also return the pending migrations here, the remaining items in + // the apply slice. + current.Error = err + current.Duration = time.Since(start) + return nil, &PartialError{ + Applied: results, + Failed: current, + Err: err, + } + } + current.Duration = time.Since(start) + results = append(results, current) + } + + return results, nil +} + +func (p *Provider) runIndividually( + ctx context.Context, + conn *sql.Conn, + m *Migration, + direction bool, +) error { + if m.useTx(direction) { + return beginTx(ctx, conn, func(tx *sql.Tx) error { + if err := m.apply(ctx, tx, direction); err != nil { + return err + } + return p.maybeInsertOrDelete(ctx, tx, m.Version, direction) + }) + } + switch m.Type { + case TypeGo: + // Note, we are using *sql.DB instead of *sql.Conn because it's the Go migration contract. + // This may be a deadlock scenario if max open connections is set to 1 AND a lock is + // acquired on the database. In this case, the migration will block forever unable to + // acquire a connection from the pool. + // + // TODO(mf): we can detect this scenario and return a more helpful error message. + if err := m.apply(ctx, p.db, direction); err != nil { + return err + } + return p.maybeInsertOrDelete(ctx, p.db, m.Version, direction) + case TypeSQL: + if err := m.apply(ctx, conn, direction); err != nil { + return err + } + return p.maybeInsertOrDelete(ctx, conn, m.Version, direction) + } + // + // This should never happen!! + // + return fmt.Errorf("failed to run individual migration: neither sql or go: %v", m) +} + +func (p *Provider) maybeInsertOrDelete( + ctx context.Context, + db database.DBTxConn, + version int64, + direction bool, +) error { + // If versioning is disabled, we don't need to insert or delete the migration version. + if p.cfg.disableVersioning { + return nil + } + if direction { + return p.store.Insert(ctx, db, database.InsertRequest{Version: version}) + } + return p.store.Delete(ctx, db, version) +} + +// beginTx begins a transaction and runs the given function. If the function returns an error, the +// transaction is rolled back. Otherwise, the transaction is committed. +func beginTx(ctx context.Context, conn *sql.Conn, fn func(tx *sql.Tx) error) (retErr error) { + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, tx.Rollback()) + } + }() + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) { + p.mu.Lock() + conn, err := p.db.Conn(ctx) + if err != nil { + p.mu.Unlock() + return nil, nil, err + } + // cleanup is a function that cleans up the connection, and optionally, the session lock. + cleanup := func() error { + p.mu.Unlock() + return conn.Close() + } + if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled { + if err := l.SessionLock(ctx, conn); err != nil { + return nil, nil, multierr.Append(err, cleanup()) + } + cleanup = func() error { + p.mu.Unlock() + // Use a detached context to unlock the session. This is because the context passed to + // SessionLock may have been canceled, and we don't want to cancel the unlock. + // + // TODO(mf): use [context.WithoutCancel] added in go1.21 + detachedCtx := context.Background() + return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close()) + } + } + // If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't + // need the version table because there is no versioning. + if !p.cfg.disableVersioning { + if err := p.ensureVersionTable(ctx, conn); err != nil { + return nil, nil, multierr.Append(err, cleanup()) + } + } + return conn, cleanup, nil +} + +func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) { + // feat(mf): this is where we can check if the version table exists instead of trying to fetch + // from a table that may not exist. https://github.com/pressly/goose/issues/461 + res, err := p.store.GetMigration(ctx, conn, 0) + if err == nil && res != nil { + return nil + } + return beginTx(ctx, conn, func(tx *sql.Tx) error { + if err := p.store.CreateVersionTable(ctx, tx); err != nil { + return err + } + if p.cfg.disableVersioning { + return nil + } + return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0}) + }) +} + +type missingMigration struct { + versionID int64 +} + +// checkMissingMigrations returns a list of migrations that are missing from the database. A missing +// migration is one that has a version less than the max version in the database. +func checkMissingMigrations( + dbMigrations []*database.ListMigrationsResult, + fsMigrations []*Migration, +) []missingMigration { + existing := make(map[int64]bool) + var dbMaxVersion int64 + for _, m := range dbMigrations { + existing[m.Version] = true + if m.Version > dbMaxVersion { + dbMaxVersion = m.Version + } + } + var missing []missingMigration + for _, m := range fsMigrations { + version := m.Version + if !existing[version] && version < dbMaxVersion { + missing = append(missing, missingMigration{ + versionID: version, + }) + } + } + sort.Slice(missing, func(i, j int) bool { + return missing[i].versionID < missing[j].versionID + }) + return missing +} + +// getMigration returns the migration with the given version. If no migration is found, then +// ErrVersionNotFound is returned. +func (p *Provider) getMigration(version int64) (*Migration, error) { + for _, m := range p.migrations { + if m.Version == version { + return m, nil + } + } + return nil, ErrVersionNotFound +} diff --git a/internal/provider/run_test.go b/provider_run_test.go similarity index 77% rename from internal/provider/run_test.go rename to provider_run_test.go index 4d12a0db7..bbb018d4b 100644 --- a/internal/provider/run_test.go +++ b/provider_run_test.go @@ -1,4 +1,4 @@ -package provider_test +package goose_test import ( "context" @@ -16,9 +16,9 @@ import ( "testing" "testing/fstest" + "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/check" - "github.com/pressly/goose/v3/internal/provider" "github.com/pressly/goose/v3/internal/testdb" "github.com/pressly/goose/v3/lock" "golang.org/x/sync/errgroup" @@ -45,10 +45,10 @@ func TestProviderRun(t *testing.T) { p, _ := newProviderWithDB(t) _, err := p.ApplyVersion(context.Background(), 999, true) check.HasError(t, err) - check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true) + check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true) _, err = p.ApplyVersion(context.Background(), 999, false) check.HasError(t, err) - check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true) + check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true) }) t.Run("run_zero", func(t *testing.T) { p, _ := newProviderWithDB(t) @@ -72,30 +72,30 @@ func TestProviderRun(t *testing.T) { check.Number(t, len(sources), numCount) // Ensure only SQL migrations are returned for _, s := range sources { - check.Equal(t, s.Type, provider.TypeSQL) + check.Equal(t, s.Type, goose.TypeSQL) } // Test Up res, err := p.Up(ctx) check.NoError(t, err) check.Number(t, len(res), numCount) - assertResult(t, res[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) - assertResult(t, res[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false) - assertResult(t, res[2], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false) - assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false) - assertResult(t, res[4], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false) - assertResult(t, res[5], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true) - assertResult(t, res[6], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true) + assertResult(t, res[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false) + assertResult(t, res[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false) + assertResult(t, res[2], newSource(goose.TypeSQL, "00003_comments_table.sql", 3), "up", false) + assertResult(t, res[3], newSource(goose.TypeSQL, "00004_insert_data.sql", 4), "up", false) + assertResult(t, res[4], newSource(goose.TypeSQL, "00005_posts_view.sql", 5), "up", false) + assertResult(t, res[5], newSource(goose.TypeSQL, "00006_empty_up.sql", 6), "up", true) + assertResult(t, res[6], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "up", true) // Test Down res, err = p.DownTo(ctx, 0) check.NoError(t, err) check.Number(t, len(res), numCount) - assertResult(t, res[0], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true) - assertResult(t, res[1], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true) - assertResult(t, res[2], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false) - assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false) - assertResult(t, res[4], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false) - assertResult(t, res[5], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false) - assertResult(t, res[6], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false) + assertResult(t, res[0], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "down", true) + assertResult(t, res[1], newSource(goose.TypeSQL, "00006_empty_up.sql", 6), "down", true) + assertResult(t, res[2], newSource(goose.TypeSQL, "00005_posts_view.sql", 5), "down", false) + assertResult(t, res[3], newSource(goose.TypeSQL, "00004_insert_data.sql", 4), "down", false) + assertResult(t, res[4], newSource(goose.TypeSQL, "00003_comments_table.sql", 3), "down", false) + assertResult(t, res[5], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "down", false) + assertResult(t, res[6], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "down", false) }) t.Run("up_and_down_by_one", func(t *testing.T) { ctx := context.Background() @@ -107,8 +107,8 @@ func TestProviderRun(t *testing.T) { res, err := p.UpByOne(ctx) counter++ if counter > maxVersion { - if !errors.Is(err, provider.ErrNoNextVersion) { - t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion) + if !errors.Is(err, goose.ErrNoNextVersion) { + t.Fatalf("incorrect error: got:%v want:%v", err, goose.ErrNoNextVersion) } break } @@ -126,8 +126,8 @@ func TestProviderRun(t *testing.T) { res, err := p.Down(ctx) counter++ if counter > maxVersion { - if !errors.Is(err, provider.ErrNoNextVersion) { - t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion) + if !errors.Is(err, goose.ErrNoNextVersion) { + t.Fatalf("incorrect error: got:%v want:%v", err, goose.ErrNoNextVersion) } break } @@ -149,14 +149,14 @@ func TestProviderRun(t *testing.T) { results, err := p.UpTo(ctx, upToVersion) check.NoError(t, err) check.Number(t, len(results), upToVersion) - assertResult(t, results[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) - assertResult(t, results[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false) + assertResult(t, results[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false) + assertResult(t, results[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false) // Fetch the goose version from DB currentVersion, err := p.GetDBVersion(ctx) check.NoError(t, err) check.Number(t, currentVersion, upToVersion) // Validate the version actually matches what goose claims it is - gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + gotVersion, err := getMaxVersionID(db, goose.DefaultTablename) check.NoError(t, err) check.Number(t, gotVersion, upToVersion) }) @@ -197,7 +197,7 @@ func TestProviderRun(t *testing.T) { check.NoError(t, err) check.Number(t, currentVersion, p.ListSources()[len(sources)-1].Version) // Validate the db migration version actually matches what goose claims it is - gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + gotVersion, err := getMaxVersionID(db, goose.DefaultTablename) check.NoError(t, err) check.Number(t, gotVersion, currentVersion) tables, err := getTableNames(db) @@ -213,13 +213,13 @@ func TestProviderRun(t *testing.T) { downResult, err := p.DownTo(ctx, 0) check.NoError(t, err) check.Number(t, len(downResult), len(sources)) - gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + gotVersion, err := getMaxVersionID(db, goose.DefaultTablename) check.NoError(t, err) check.Number(t, gotVersion, 0) // Should only be left with a single table, the default goose table tables, err := getTableNames(db) check.NoError(t, err) - knownTables := []string{provider.DefaultTablename, "sqlite_sequence"} + knownTables := []string{goose.DefaultTablename, "sqlite_sequence"} if !reflect.DeepEqual(tables, knownTables) { t.Logf("got tables: %v", tables) t.Logf("known tables: %v", knownTables) @@ -261,7 +261,7 @@ func TestProviderRun(t *testing.T) { check.NoError(t, err) _, err = p.ApplyVersion(ctx, 1, true) check.HasError(t, err) - check.Bool(t, errors.Is(err, provider.ErrAlreadyApplied), true) + check.Bool(t, errors.Is(err, goose.ErrAlreadyApplied), true) check.Contains(t, err.Error(), "version 1: already applied") }) t.Run("status", func(t *testing.T) { @@ -272,26 +272,26 @@ func TestProviderRun(t *testing.T) { status, err := p.Status(ctx) check.NoError(t, err) check.Number(t, len(status), numCount) - assertStatus(t, status[0], provider.StatePending, newSource(provider.TypeSQL, "00001_users_table.sql", 1), true) - assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), true) - assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), true) - assertStatus(t, status[3], provider.StatePending, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), true) - assertStatus(t, status[4], provider.StatePending, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), true) - assertStatus(t, status[5], provider.StatePending, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), true) - assertStatus(t, status[6], provider.StatePending, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), true) + assertStatus(t, status[0], goose.StatePending, newSource(goose.TypeSQL, "00001_users_table.sql", 1), true) + assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), true) + assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), true) + assertStatus(t, status[3], goose.StatePending, newSource(goose.TypeSQL, "00004_insert_data.sql", 4), true) + assertStatus(t, status[4], goose.StatePending, newSource(goose.TypeSQL, "00005_posts_view.sql", 5), true) + assertStatus(t, status[5], goose.StatePending, newSource(goose.TypeSQL, "00006_empty_up.sql", 6), true) + assertStatus(t, status[6], goose.StatePending, newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), true) // Apply all migrations _, err = p.Up(ctx) check.NoError(t, err) status, err = p.Status(ctx) check.NoError(t, err) check.Number(t, len(status), numCount) - assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false) - assertStatus(t, status[1], provider.StateApplied, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), false) - assertStatus(t, status[2], provider.StateApplied, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), false) - assertStatus(t, status[3], provider.StateApplied, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), false) - assertStatus(t, status[4], provider.StateApplied, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), false) - assertStatus(t, status[5], provider.StateApplied, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), false) - assertStatus(t, status[6], provider.StateApplied, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), false) + assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false) + assertStatus(t, status[1], goose.StateApplied, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), false) + assertStatus(t, status[2], goose.StateApplied, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), false) + assertStatus(t, status[3], goose.StateApplied, newSource(goose.TypeSQL, "00004_insert_data.sql", 4), false) + assertStatus(t, status[4], goose.StateApplied, newSource(goose.TypeSQL, "00005_posts_view.sql", 5), false) + assertStatus(t, status[5], goose.StateApplied, newSource(goose.TypeSQL, "00006_empty_up.sql", 6), false) + assertStatus(t, status[6], goose.StateApplied, newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), false) }) t.Run("tx_partial_errors", func(t *testing.T) { countOwners := func(db *sql.DB) (int, error) { @@ -321,22 +321,22 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-2'); INSERT INTO owners (owner_name) VALUES ('seed-user-3'); `), } - p, err := provider.NewProvider(database.DialectSQLite3, db, mapFS) + p, err := goose.NewProvider(database.DialectSQLite3, db, mapFS) check.NoError(t, err) _, err = p.Up(ctx) check.HasError(t, err) check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)") - var expected *provider.PartialError + var expected *goose.PartialError check.Bool(t, errors.As(err, &expected), true) // Check Err field check.Bool(t, expected.Err != nil, true) check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)") // Check Results field check.Number(t, len(expected.Applied), 1) - assertResult(t, expected.Applied[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) + assertResult(t, expected.Applied[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false) // Check Failed field check.Bool(t, expected.Failed != nil, true) - assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2) + assertSource(t, expected.Failed.Source, goose.TypeSQL, "00002_partial_error.sql", 2) check.Bool(t, expected.Failed.Empty, false) check.Bool(t, expected.Failed.Error != nil, true) check.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)") @@ -351,9 +351,9 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3'); status, err := p.Status(ctx) check.NoError(t, err) check.Number(t, len(status), 3) - assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false) - assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_partial_error.sql", 2), true) - assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_insert_data.sql", 3), true) + assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false) + assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_partial_error.sql", 2), true) + assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_insert_data.sql", 3), true) }) } @@ -415,7 +415,7 @@ func TestConcurrentProvider(t *testing.T) { check.NoError(t, err) check.Number(t, currentVersion, maxVersion) - ch := make(chan []*provider.MigrationResult) + ch := make(chan []*goose.MigrationResult) var wg sync.WaitGroup for i := 0; i < maxVersion; i++ { wg.Add(1) @@ -435,8 +435,8 @@ func TestConcurrentProvider(t *testing.T) { close(ch) }() var ( - valid [][]*provider.MigrationResult - empty [][]*provider.MigrationResult + valid [][]*goose.MigrationResult + empty [][]*goose.MigrationResult ) for results := range ch { if len(results) == 0 { @@ -486,9 +486,9 @@ func TestNoVersioning(t *testing.T) { // These are owners created by migration files. wantOwnerCount = 4 ) - p, err := provider.NewProvider(database.DialectSQLite3, db, fsys, - provider.WithVerbose(testing.Verbose()), - provider.WithDisabledVersioning(false), // This is the default. + p, err := goose.NewProvider(database.DialectSQLite3, db, fsys, + goose.WithVerbose(testing.Verbose()), + goose.WithDisabledVersioning(false), // This is the default. ) check.Number(t, len(p.ListSources()), 3) check.NoError(t, err) @@ -499,9 +499,9 @@ func TestNoVersioning(t *testing.T) { check.Number(t, baseVersion, 3) t.Run("seed-up-down-to-zero", func(t *testing.T) { fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed")) - p, err := provider.NewProvider(database.DialectSQLite3, db, fsys, - provider.WithVerbose(testing.Verbose()), - provider.WithDisabledVersioning(true), // Provider with no versioning. + p, err := goose.NewProvider(database.DialectSQLite3, db, fsys, + goose.WithVerbose(testing.Verbose()), + goose.WithDisabledVersioning(true), // Provider with no versioning. ) check.NoError(t, err) check.Number(t, len(p.ListSources()), 2) @@ -552,8 +552,8 @@ func TestAllowMissing(t *testing.T) { t.Run("missing_now_allowed", func(t *testing.T) { db := newDB(t) - p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), - provider.WithAllowedMissing(false), + p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(), + goose.WithAllowedMissing(false), ) check.NoError(t, err) @@ -607,8 +607,8 @@ func TestAllowMissing(t *testing.T) { t.Run("missing_allowed", func(t *testing.T) { db := newDB(t) - p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), - provider.WithAllowedMissing(true), + p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(), + goose.WithAllowedMissing(true), ) check.NoError(t, err) @@ -640,7 +640,7 @@ func TestAllowMissing(t *testing.T) { check.Bool(t, upResult != nil, true) check.Number(t, upResult.Source.Version, 6) - count, err := getGooseVersionCount(db, provider.DefaultTablename) + count, err := getGooseVersionCount(db, goose.DefaultTablename) check.NoError(t, err) check.Number(t, count, 6) current, err := p.GetDBVersion(ctx) @@ -676,7 +676,7 @@ func TestAllowMissing(t *testing.T) { testDownAndVersion(1, 1) _, err = p.Down(ctx) check.HasError(t, err) - check.Bool(t, errors.Is(err, provider.ErrNoNextVersion), true) + check.Bool(t, errors.Is(err, goose.ErrNoNextVersion), true) }) } @@ -691,6 +691,7 @@ func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) { } func TestGoOnly(t *testing.T) { + t.Cleanup(goose.ResetGlobalMigrations) // Not parallel because each subtest modifies global state. countUser := func(db *sql.DB) int { @@ -703,99 +704,109 @@ func TestGoOnly(t *testing.T) { t.Run("with_tx", func(t *testing.T) { ctx := context.Background() - register := []*provider.MigrationCopy{ - { - Version: 1, Source: "00001_users_table.go", Registered: true, - UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), - DownFnContext: newTxFn("DROP TABLE users"), - }, + register := []*goose.Migration{ + goose.NewGoMigration( + 1, + &goose.GoFunc{RunTx: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)")}, + &goose.GoFunc{RunTx: newTxFn("DROP TABLE users")}, + ), } - err := provider.SetGlobalGoMigrations(register) + err := goose.SetGlobalMigrations(register...) check.NoError(t, err) - t.Cleanup(provider.ResetGlobalGoMigrations) + t.Cleanup(goose.ResetGlobalMigrations) db := newDB(t) - p, err := provider.NewProvider(database.DialectSQLite3, db, nil, - provider.WithGoMigration( + register = []*goose.Migration{ + goose.NewGoMigration( 2, - &provider.GoMigrationFunc{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, - &provider.GoMigrationFunc{Run: newTxFn("DELETE FROM users")}, + &goose.GoFunc{RunTx: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, + &goose.GoFunc{RunTx: newTxFn("DELETE FROM users")}, ), + } + p, err := goose.NewProvider(database.DialectSQLite3, db, nil, + goose.WithGoMigrations(register...), ) check.NoError(t, err) sources := p.ListSources() check.Number(t, len(p.ListSources()), 2) - assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1) - assertSource(t, sources[1], provider.TypeGo, "", 2) + assertSource(t, sources[0], goose.TypeGo, "", 1) + assertSource(t, sources[1], goose.TypeGo, "", 2) // Apply migration 1 res, err := p.UpByOne(ctx) check.NoError(t, err) - assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false) + assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false) check.Number(t, countUser(db), 0) check.Bool(t, tableExists(t, db, "users"), true) // Apply migration 2 res, err = p.UpByOne(ctx) check.NoError(t, err) - assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false) + assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false) check.Number(t, countUser(db), 3) // Rollback migration 2 res, err = p.Down(ctx) check.NoError(t, err) - assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false) + assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false) check.Number(t, countUser(db), 0) // Rollback migration 1 res, err = p.Down(ctx) check.NoError(t, err) - assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false) + assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false) // Check table does not exist check.Bool(t, tableExists(t, db, "users"), false) }) t.Run("with_db", func(t *testing.T) { ctx := context.Background() - register := []*provider.MigrationCopy{ - { - Version: 1, Source: "00001_users_table.go", Registered: true, - UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), - DownFnNoTxContext: newDBFn("DROP TABLE users"), - }, + register := []*goose.Migration{ + goose.NewGoMigration( + 1, + &goose.GoFunc{ + RunDB: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), + }, + &goose.GoFunc{ + RunDB: newDBFn("DROP TABLE users"), + }, + ), } - err := provider.SetGlobalGoMigrations(register) + err := goose.SetGlobalMigrations(register...) check.NoError(t, err) - t.Cleanup(provider.ResetGlobalGoMigrations) + t.Cleanup(goose.ResetGlobalMigrations) db := newDB(t) - p, err := provider.NewProvider(database.DialectSQLite3, db, nil, - provider.WithGoMigration( + register = []*goose.Migration{ + goose.NewGoMigration( 2, - &provider.GoMigrationFunc{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, - &provider.GoMigrationFunc{RunNoTx: newDBFn("DELETE FROM users")}, + &goose.GoFunc{RunDB: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, + &goose.GoFunc{RunDB: newDBFn("DELETE FROM users")}, ), + } + p, err := goose.NewProvider(database.DialectSQLite3, db, nil, + goose.WithGoMigrations(register...), ) check.NoError(t, err) sources := p.ListSources() check.Number(t, len(p.ListSources()), 2) - assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1) - assertSource(t, sources[1], provider.TypeGo, "", 2) + assertSource(t, sources[0], goose.TypeGo, "", 1) + assertSource(t, sources[1], goose.TypeGo, "", 2) // Apply migration 1 res, err := p.UpByOne(ctx) check.NoError(t, err) - assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false) + assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false) check.Number(t, countUser(db), 0) check.Bool(t, tableExists(t, db, "users"), true) // Apply migration 2 res, err = p.UpByOne(ctx) check.NoError(t, err) - assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false) + assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false) check.Number(t, countUser(db), 3) // Rollback migration 2 res, err = p.Down(ctx) check.NoError(t, err) - assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false) + assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false) check.Number(t, countUser(db), 0) // Rollback migration 1 res, err = p.Down(ctx) check.NoError(t, err) - assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false) + assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false) // Check table does not exist check.Bool(t, tableExists(t, db, "users"), false) }) @@ -818,12 +829,12 @@ func TestLockModeAdvisorySession(t *testing.T) { check.NoError(t, err) t.Cleanup(cleanup) - newProvider := func() *provider.Provider { + newProvider := func() *goose.Provider { sessionLocker, err := lock.NewPostgresSessionLocker() check.NoError(t, err) - p, err := provider.NewProvider(database.DialectPostgres, db, os.DirFS("../../testdata/migrations"), - provider.WithSessionLocker(sessionLocker), // Use advisory session lock mode. - provider.WithVerbose(testing.Verbose()), + p, err := goose.NewProvider(database.DialectPostgres, db, os.DirFS("testdata/migrations"), + goose.WithSessionLocker(sessionLocker), // Use advisory session lock mode. + goose.WithVerbose(testing.Verbose()), ) check.NoError(t, err) return p @@ -891,7 +902,7 @@ func TestLockModeAdvisorySession(t *testing.T) { for { result, err := provider1.UpByOne(context.Background()) if err != nil { - if errors.Is(err, provider.ErrNoNextVersion) { + if errors.Is(err, goose.ErrNoNextVersion) { return nil } return err @@ -907,7 +918,7 @@ func TestLockModeAdvisorySession(t *testing.T) { for { result, err := provider2.UpByOne(context.Background()) if err != nil { - if errors.Is(err, provider.ErrNoNextVersion) { + if errors.Is(err, goose.ErrNoNextVersion) { return nil } return err @@ -993,7 +1004,7 @@ func TestLockModeAdvisorySession(t *testing.T) { for { result, err := provider1.Down(context.Background()) if err != nil { - if errors.Is(err, provider.ErrNoNextVersion) { + if errors.Is(err, goose.ErrNoNextVersion) { return nil } return err @@ -1009,7 +1020,7 @@ func TestLockModeAdvisorySession(t *testing.T) { for { result, err := provider2.Down(context.Background()) if err != nil { - if errors.Is(err, provider.ErrNoNextVersion) { + if errors.Is(err, goose.ErrNoNextVersion) { return nil } return err @@ -1068,14 +1079,14 @@ func randomAlphaNumeric(length int) string { return string(b) } -func newProviderWithDB(t *testing.T, opts ...provider.ProviderOption) (*provider.Provider, *sql.DB) { +func newProviderWithDB(t *testing.T, opts ...goose.ProviderOption) (*goose.Provider, *sql.DB) { t.Helper() db := newDB(t) opts = append( opts, - provider.WithVerbose(testing.Verbose()), + goose.WithVerbose(testing.Verbose()), ) - p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), opts...) + p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(), opts...) check.NoError(t, err) return p, db } @@ -1118,14 +1129,14 @@ func getTableNames(db *sql.DB) ([]string, error) { return tables, nil } -func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.State, source provider.Source, appliedIsZero bool) { +func assertStatus(t *testing.T, got *goose.MigrationStatus, state goose.State, source goose.Source, appliedIsZero bool) { t.Helper() check.Equal(t, got.State, state) check.Equal(t, got.Source, source) check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero) } -func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string, isEmpty bool) { +func assertResult(t *testing.T, got *goose.MigrationResult, source goose.Source, direction string, isEmpty bool) { t.Helper() check.Bool(t, got != nil, true) check.Equal(t, got.Source, source) @@ -1135,21 +1146,15 @@ func assertResult(t *testing.T, got *provider.MigrationResult, source provider.S check.Bool(t, got.Duration > 0, true) } -func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) { +func assertSource(t *testing.T, got goose.Source, typ goose.MigrationType, name string, version int64) { t.Helper() check.Equal(t, got.Type, typ) check.Equal(t, got.Path, name) check.Equal(t, got.Version, version) - switch got.Type { - case provider.TypeGo: - check.Equal(t, got.Type.String(), "go") - case provider.TypeSQL: - check.Equal(t, got.Type.String(), "sql") - } } -func newSource(t provider.MigrationType, fullpath string, version int64) provider.Source { - return provider.Source{ +func newSource(t goose.MigrationType, fullpath string, version int64) goose.Source { + return goose.Source{ Type: t, Path: fullpath, Version: version, diff --git a/provider_test.go b/provider_test.go new file mode 100644 index 000000000..c1e84b17d --- /dev/null +++ b/provider_test.go @@ -0,0 +1,79 @@ +package goose_test + +import ( + "database/sql" + "errors" + "io/fs" + "path/filepath" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/database" + "github.com/pressly/goose/v3/internal/check" + _ "modernc.org/sqlite" +) + +func TestProvider(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + t.Run("empty", func(t *testing.T) { + _, err := goose.NewProvider(database.DialectSQLite3, db, fstest.MapFS{}) + check.HasError(t, err) + check.Bool(t, errors.Is(err, goose.ErrNoMigrations), true) + }) + + mapFS := fstest.MapFS{ + "migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)}, + "migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)}, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + p, err := goose.NewProvider(database.DialectSQLite3, db, fsys) + check.NoError(t, err) + sources := p.ListSources() + check.Equal(t, len(sources), 2) + check.Equal(t, sources[0], newSource(goose.TypeSQL, "001_foo.sql", 1)) + check.Equal(t, sources[1], newSource(goose.TypeSQL, "002_bar.sql", 2)) + +} + +var ( + migration1 = ` +-- +goose Up +CREATE TABLE foo (id INTEGER PRIMARY KEY); +-- +goose Down +DROP TABLE foo; +` + migration2 = ` +-- +goose Up +ALTER TABLE foo ADD COLUMN name TEXT; +-- +goose Down +ALTER TABLE foo DROP COLUMN name; +` + migration3 = ` +-- +goose Up +CREATE TABLE bar ( + id INTEGER PRIMARY KEY, + description TEXT +); +-- +goose Down +DROP TABLE bar; +` + migration4 = ` +-- +goose Up +-- Rename the 'foo' table to 'my_foo' +ALTER TABLE foo RENAME TO my_foo; + +-- Add a new column 'timestamp' to 'my_foo' +ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP; + +-- +goose Down +-- Remove the 'timestamp' column from 'my_foo' +ALTER TABLE my_foo DROP COLUMN timestamp; + +-- Rename the 'my_foo' table back to 'foo' +ALTER TABLE my_foo RENAME TO foo; +` +) diff --git a/register.go b/register.go index 89bd4c7a5..23a4f8353 100644 --- a/register.go +++ b/register.go @@ -66,7 +66,7 @@ func register(filename string, useTx bool, up, down *GoFunc) error { // We explicitly set transaction to maintain existing behavior. Both up and down may be nil, but // we know based on the register function what the user is requesting. m.UseTx = useTx - registeredGoMigrations[v] = &m + registeredGoMigrations[v] = m return nil } diff --git a/internal/provider/testdata/no-versioning/migrations/00001_a.sql b/testdata/no-versioning/migrations/00001_a.sql similarity index 100% rename from internal/provider/testdata/no-versioning/migrations/00001_a.sql rename to testdata/no-versioning/migrations/00001_a.sql diff --git a/internal/provider/testdata/no-versioning/migrations/00002_b.sql b/testdata/no-versioning/migrations/00002_b.sql similarity index 100% rename from internal/provider/testdata/no-versioning/migrations/00002_b.sql rename to testdata/no-versioning/migrations/00002_b.sql diff --git a/internal/provider/testdata/no-versioning/migrations/00003_c.sql b/testdata/no-versioning/migrations/00003_c.sql similarity index 100% rename from internal/provider/testdata/no-versioning/migrations/00003_c.sql rename to testdata/no-versioning/migrations/00003_c.sql diff --git a/internal/provider/testdata/no-versioning/seed/00001_a.sql b/testdata/no-versioning/seed/00001_a.sql similarity index 100% rename from internal/provider/testdata/no-versioning/seed/00001_a.sql rename to testdata/no-versioning/seed/00001_a.sql diff --git a/internal/provider/testdata/no-versioning/seed/00002_b.sql b/testdata/no-versioning/seed/00002_b.sql similarity index 100% rename from internal/provider/testdata/no-versioning/seed/00002_b.sql rename to testdata/no-versioning/seed/00002_b.sql diff --git a/types.go b/types.go index f0008d09e..616c7d808 100644 --- a/types.go +++ b/types.go @@ -1,5 +1,7 @@ package goose +import "time" + // MigrationType is the type of migration. type MigrationType string @@ -8,10 +10,46 @@ const ( TypeSQL MigrationType = "sql" ) -func (t MigrationType) String() string { - // This should never happen. - if t == "" { - return "unknown migration type" - } - return string(t) +// Source represents a single migration source. +// +// The Path field may be empty if the migration was registered manually. This is typically the case +// for Go migrations registered using the [WithGoMigration] option. +type Source struct { + Type MigrationType + Path string + Version int64 +} + +// MigrationResult is the result of a single migration operation. +type MigrationResult struct { + Source Source + Duration time.Duration + Direction string + // Empty indicates no action was taken during the migration, but it was still versioned. For + // SQL, it means no statements; for Go, it's a nil function. + Empty bool + // Error is only set if the migration failed. + Error error +} + +// State represents the state of a migration. +type State string + +const ( + // StatePending is a migration that exists on the filesystem, but not in the database. + StatePending State = "pending" + // StateApplied is a migration that has been applied to the database and exists on the + // filesystem. + StateApplied State = "applied" + + // TODO(mf): we could also add a third state for untracked migrations. This would be useful for + // migrations that were manually applied to the database, but not versioned. Or the Source was + // deleted, but the migration still exists in the database. StateUntracked State = "untracked" +) + +// MigrationStatus represents the status of a single migration. +type MigrationStatus struct { + Source Source + State State + AppliedAt time.Time } From 613a3ee6f3f6593c240b5cde7b10ba37f7f256d3 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Tue, 7 Nov 2023 07:58:02 -0500 Subject: [PATCH 02/13] wip --- provider_migrate.go | 26 +++++++++++++------------- provider_run.go | 6 +++++- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/provider_migrate.go b/provider_migrate.go index 1386f7406..aa621615e 100644 --- a/provider_migrate.go +++ b/provider_migrate.go @@ -8,21 +8,23 @@ import ( "github.com/pressly/goose/v3/database" ) -func (m *Migration) useTx(direction bool) bool { +func (m *Migration) useTx(direction bool) (bool, error) { switch m.Type { case TypeGo: - if direction && m.goUp.Mode == TransactionEnabled { - return true + if m.goUp.Mode == 0 || m.goDown.Mode == 0 { + return false, fmt.Errorf("go migrations must have a mode set") } - if !direction && m.goDown.Mode == TransactionEnabled { - return true + if direction { + return m.goUp.Mode == TransactionEnabled, nil } - return false + return m.goDown.Mode == TransactionEnabled, nil case TypeSQL: - return m.sql.UseTx + if !m.sql.Parsed { + return false, fmt.Errorf("sql migrations must be parsed") + } + return m.sql.UseTx, nil } - // This should never happen. - panic(fmt.Sprintf("invalid migration type: %q", m.Type)) + return false, fmt.Errorf("invalid migration type: %q", m.Type) } func (m *Migration) isEmpty(direction bool) bool { @@ -38,8 +40,7 @@ func (m *Migration) isEmpty(direction bool) bool { } return len(m.sql.Down) == 0 } - // This should never happen. - panic(fmt.Sprintf("invalid migration type: %q", m.Type)) + return true } func (m *Migration) apply(ctx context.Context, db database.DBTxConn, direction bool) error { @@ -52,8 +53,7 @@ func (m *Migration) apply(ctx context.Context, db database.DBTxConn, direction b } return runSQL(ctx, db, m.sql.Down) } - // This should never happen. - panic(fmt.Sprintf("invalid migration type: %q", m.Type)) + return fmt.Errorf("invalid migration type: %q", m.Type) } func runGo(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { diff --git a/provider_run.go b/provider_run.go index 10cc63364..2b07e0d51 100644 --- a/provider_run.go +++ b/provider_run.go @@ -191,7 +191,11 @@ func (p *Provider) runIndividually( m *Migration, direction bool, ) error { - if m.useTx(direction) { + useTx, err := m.useTx(direction) + if err != nil { + return err + } + if useTx { return beginTx(ctx, conn, func(tx *sql.Tx) error { if err := m.apply(ctx, tx, direction); err != nil { return err From 228b6e65252c1bfcaaa5b5b031631cb0abc42b7d Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Tue, 7 Nov 2023 08:01:54 -0500 Subject: [PATCH 03/13] wip --- provider_migrate.go | 6 +++--- provider_run.go | 13 +++++-------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/provider_migrate.go b/provider_migrate.go index aa621615e..016809968 100644 --- a/provider_migrate.go +++ b/provider_migrate.go @@ -8,7 +8,7 @@ import ( "github.com/pressly/goose/v3/database" ) -func (m *Migration) useTx(direction bool) (bool, error) { +func useTx(m *Migration, direction bool) (bool, error) { switch m.Type { case TypeGo: if m.goUp.Mode == 0 || m.goDown.Mode == 0 { @@ -27,7 +27,7 @@ func (m *Migration) useTx(direction bool) (bool, error) { return false, fmt.Errorf("invalid migration type: %q", m.Type) } -func (m *Migration) isEmpty(direction bool) bool { +func isEmpty(m *Migration, direction bool) bool { switch m.Type { case TypeGo: if direction { @@ -43,7 +43,7 @@ func (m *Migration) isEmpty(direction bool) bool { return true } -func (m *Migration) apply(ctx context.Context, db database.DBTxConn, direction bool) error { +func runMigration(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { switch m.Type { case TypeGo: return runGo(ctx, db, m, direction) diff --git a/provider_run.go b/provider_run.go index 2b07e0d51..a64b3aa40 100644 --- a/provider_run.go +++ b/provider_run.go @@ -164,7 +164,7 @@ func (p *Provider) runMigrations( Version: m.Version, }, Direction: direction.String(), - Empty: m.isEmpty(direction.ToBool()), + Empty: isEmpty(m, direction.ToBool()), } start := time.Now() if err := p.runIndividually(ctx, conn, m, direction.ToBool()); err != nil { @@ -191,13 +191,13 @@ func (p *Provider) runIndividually( m *Migration, direction bool, ) error { - useTx, err := m.useTx(direction) + useTx, err := useTx(m, direction) if err != nil { return err } if useTx { return beginTx(ctx, conn, func(tx *sql.Tx) error { - if err := m.apply(ctx, tx, direction); err != nil { + if err := runMigration(ctx, tx, m, direction); err != nil { return err } return p.maybeInsertOrDelete(ctx, tx, m.Version, direction) @@ -211,19 +211,16 @@ func (p *Provider) runIndividually( // acquire a connection from the pool. // // TODO(mf): we can detect this scenario and return a more helpful error message. - if err := m.apply(ctx, p.db, direction); err != nil { + if err := runMigration(ctx, p.db, m, direction); err != nil { return err } return p.maybeInsertOrDelete(ctx, p.db, m.Version, direction) case TypeSQL: - if err := m.apply(ctx, conn, direction); err != nil { + if err := runMigration(ctx, conn, m, direction); err != nil { return err } return p.maybeInsertOrDelete(ctx, conn, m.Version, direction) } - // - // This should never happen!! - // return fmt.Errorf("failed to run individual migration: neither sql or go: %v", m) } From ee858d95371be3c6cc10ecd09595aa2b289f18ac Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Tue, 7 Nov 2023 08:16:32 -0500 Subject: [PATCH 04/13] wip --- provider_migrate.go | 26 ++++++++++++++++++++------ provider_run.go | 3 ++- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/provider_migrate.go b/provider_migrate.go index 016809968..b446be701 100644 --- a/provider_migrate.go +++ b/provider_migrate.go @@ -8,6 +8,8 @@ import ( "github.com/pressly/goose/v3/database" ) +// useTx is a helper function that returns true if the migration should be run in a transaction. It +// must only be called after the migration has been parsed and initialized. func useTx(m *Migration, direction bool) (bool, error) { switch m.Type { case TypeGo: @@ -24,9 +26,11 @@ func useTx(m *Migration, direction bool) (bool, error) { } return m.sql.UseTx, nil } - return false, fmt.Errorf("invalid migration type: %q", m.Type) + return false, fmt.Errorf("use tx: invalid migration type: %q", m.Type) } +// isEmpty is a helper function that returns true if the migration has no functions or no statements +// to execute. It must only be called after the migration has been parsed and initialized. func isEmpty(m *Migration, direction bool) bool { switch m.Type { case TypeGo: @@ -43,15 +47,14 @@ func isEmpty(m *Migration, direction bool) bool { return true } +// runMigration is a helper function that runs the migration in the given direction. It must only be +// called after the migration has been parsed and initialized. func runMigration(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { switch m.Type { case TypeGo: return runGo(ctx, db, m, direction) case TypeSQL: - if direction { - return runSQL(ctx, db, m.sql.Up) - } - return runSQL(ctx, db, m.sql.Down) + return runSQL(ctx, db, m, direction) } return fmt.Errorf("invalid migration type: %q", m.Type) } @@ -80,7 +83,18 @@ func runGo(ctx context.Context, db database.DBTxConn, m *Migration, direction bo return fmt.Errorf("invalid database connection type: %T", db) } -func runSQL(ctx context.Context, db database.DBTxConn, statements []string) error { +// runSQL is a helper function that runs the given SQL statements in the given direction. It must +// only be called after the migration has been parsed. +func runSQL(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { + if !m.sql.Parsed { + return fmt.Errorf("sql migrations must be parsed") + } + var statements []string + if direction { + statements = m.sql.Up + } else { + statements = m.sql.Down + } for _, stmt := range statements { if _, err := db.ExecContext(ctx, stmt); err != nil { return err diff --git a/provider_run.go b/provider_run.go index a64b3aa40..859bb68af 100644 --- a/provider_run.go +++ b/provider_run.go @@ -210,7 +210,8 @@ func (p *Provider) runIndividually( // acquired on the database. In this case, the migration will block forever unable to // acquire a connection from the pool. // - // TODO(mf): we can detect this scenario and return a more helpful error message. + // For now, we guard against this scenario by checking the max open connections and + // returning an error in the prepareMigration function. if err := runMigration(ctx, p.db, m, direction); err != nil { return err } From e91d5157d5c2c1936aaa9bf69f241ed6c5d1b8a2 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Tue, 7 Nov 2023 08:19:02 -0500 Subject: [PATCH 05/13] wip --- provider_migrate.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/provider_migrate.go b/provider_migrate.go index b446be701..caa78f227 100644 --- a/provider_migrate.go +++ b/provider_migrate.go @@ -59,6 +59,8 @@ func runMigration(ctx context.Context, db database.DBTxConn, m *Migration, direc return fmt.Errorf("invalid migration type: %q", m.Type) } +// runGo is a helper function that runs the given Go functions in the given direction. It must only +// be called after the migration has been initialized. func runGo(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { switch db := db.(type) { case *sql.Conn: From 688b053a268d6e174c14d5107ac17b4edbba24b7 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Tue, 7 Nov 2023 08:22:43 -0500 Subject: [PATCH 06/13] wip --- provider.go | 199 ++++++++++++++++++++++++++++++++++++++++++ provider_impl.go | 206 -------------------------------------------- provider_migrate.go | 106 ----------------------- provider_run.go | 97 +++++++++++++++++++++ 4 files changed, 296 insertions(+), 312 deletions(-) delete mode 100644 provider_impl.go delete mode 100644 provider_migrate.go diff --git a/provider.go b/provider.go index d67a01bfb..dd3c21c62 100644 --- a/provider.go +++ b/provider.go @@ -7,9 +7,12 @@ import ( "fmt" "io/fs" "math" + "sort" "sync" "github.com/pressly/goose/v3/database" + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" ) // Provider is a goose migration goose. @@ -245,3 +248,199 @@ func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResul } return p.down(ctx, false, version) } + +// *** Internal methods *** + +func (p *Provider) up( + ctx context.Context, + upByOne bool, + version int64, +) (_ []*MigrationResult, retErr error) { + if version < 1 { + return nil, errors.New("version must be greater than zero") + } + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + if len(p.migrations) == 0 { + return nil, nil + } + var apply []*Migration + if p.cfg.disableVersioning { + apply = p.migrations + } else { + // optimize(mf): Listing all migrations from the database isn't great. This is only required + // to support the allow missing (out-of-order) feature. For users that don't use this + // feature, we could just query the database for the current max version and then apply + // migrations greater than that version. + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return nil, err + } + if len(dbMigrations) == 0 { + return nil, errMissingZeroVersion + } + apply, err = p.resolveUpMigrations(dbMigrations, version) + if err != nil { + return nil, err + } + } + // feat(mf): this is where can (optionally) group multiple migrations to be run in a single + // transaction. The default is to apply each migration sequentially on its own. + // https://github.com/pressly/goose/issues/222 + // + // Careful, we can't use a single transaction for all migrations because some may have to be run + // in their own transaction. + return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne) +} + +func (p *Provider) down( + ctx context.Context, + downByOne bool, + version int64, +) (_ []*MigrationResult, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + if len(p.migrations) == 0 { + return nil, nil + } + if p.cfg.disableVersioning { + downMigrations := p.migrations + if downByOne { + last := p.migrations[len(p.migrations)-1] + downMigrations = []*Migration{last} + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) + } + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return nil, err + } + if len(dbMigrations) == 0 { + return nil, errMissingZeroVersion + } + if dbMigrations[0].Version == 0 { + return nil, nil + } + var downMigrations []*Migration + for _, dbMigration := range dbMigrations { + if dbMigration.Version <= version { + break + } + m, err := p.getMigration(dbMigration.Version) + if err != nil { + return nil, err + } + downMigrations = append(downMigrations, m) + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) +} + +func (p *Provider) apply( + ctx context.Context, + version int64, + direction bool, +) (_ *MigrationResult, retErr error) { + m, err := p.getMigration(version) + if err != nil { + return nil, err + } + + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + result, err := p.store.GetMigration(ctx, conn, version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + // If the migration has already been applied, return an error, unless the migration is being + // applied in the opposite direction. In that case, we allow the migration to be applied again. + if result != nil && direction { + return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) + } + + d := sqlparser.DirectionDown + if direction { + d = sqlparser.DirectionUp + } + results, err := p.runMigrations(ctx, conn, []*Migration{m}, d, true) + if err != nil { + return nil, err + } + if len(results) == 0 { + return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) + } + return results[0], nil +} + +func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + // TODO(mf): add support for limit and order. Also would be nice to refactor the list query to + // support limiting the set. + + status := make([]*MigrationStatus, 0, len(p.migrations)) + for _, m := range p.migrations { + migrationStatus := &MigrationStatus{ + Source: Source{ + Type: m.Type, + Path: m.Source, + Version: m.Version, + }, + State: StatePending, + } + dbResult, err := p.store.GetMigration(ctx, conn, m.Version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + if dbResult != nil { + migrationStatus.State = StateApplied + migrationStatus.AppliedAt = dbResult.Timestamp + } + status = append(status, migrationStatus) + } + + return status, nil +} + +func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return 0, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + res, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return 0, err + } + if len(res) == 0 { + return 0, nil + } + sort.Slice(res, func(i, j int) bool { + return res[i].Version > res[j].Version + }) + return res[0].Version, nil +} diff --git a/provider_impl.go b/provider_impl.go deleted file mode 100644 index 3e838d304..000000000 --- a/provider_impl.go +++ /dev/null @@ -1,206 +0,0 @@ -package goose - -import ( - "context" - "database/sql" - "errors" - "fmt" - "sort" - - "github.com/pressly/goose/v3/internal/sqlparser" - "go.uber.org/multierr" -) - -func (p *Provider) up( - ctx context.Context, - upByOne bool, - version int64, -) (_ []*MigrationResult, retErr error) { - if version < 1 { - return nil, errors.New("version must be greater than zero") - } - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return nil, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - if len(p.migrations) == 0 { - return nil, nil - } - var apply []*Migration - if p.cfg.disableVersioning { - apply = p.migrations - } else { - // optimize(mf): Listing all migrations from the database isn't great. This is only required - // to support the allow missing (out-of-order) feature. For users that don't use this - // feature, we could just query the database for the current max version and then apply - // migrations greater than that version. - dbMigrations, err := p.store.ListMigrations(ctx, conn) - if err != nil { - return nil, err - } - if len(dbMigrations) == 0 { - return nil, errMissingZeroVersion - } - apply, err = p.resolveUpMigrations(dbMigrations, version) - if err != nil { - return nil, err - } - } - // feat(mf): this is where can (optionally) group multiple migrations to be run in a single - // transaction. The default is to apply each migration sequentially on its own. - // https://github.com/pressly/goose/issues/222 - // - // Careful, we can't use a single transaction for all migrations because some may have to be run - // in their own transaction. - return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne) -} - -func (p *Provider) down( - ctx context.Context, - downByOne bool, - version int64, -) (_ []*MigrationResult, retErr error) { - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return nil, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - if len(p.migrations) == 0 { - return nil, nil - } - if p.cfg.disableVersioning { - downMigrations := p.migrations - if downByOne { - last := p.migrations[len(p.migrations)-1] - downMigrations = []*Migration{last} - } - return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) - } - dbMigrations, err := p.store.ListMigrations(ctx, conn) - if err != nil { - return nil, err - } - if len(dbMigrations) == 0 { - return nil, errMissingZeroVersion - } - if dbMigrations[0].Version == 0 { - return nil, nil - } - var downMigrations []*Migration - for _, dbMigration := range dbMigrations { - if dbMigration.Version <= version { - break - } - m, err := p.getMigration(dbMigration.Version) - if err != nil { - return nil, err - } - downMigrations = append(downMigrations, m) - } - return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) -} - -func (p *Provider) apply( - ctx context.Context, - version int64, - direction bool, -) (_ *MigrationResult, retErr error) { - m, err := p.getMigration(version) - if err != nil { - return nil, err - } - - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return nil, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - - result, err := p.store.GetMigration(ctx, conn, version) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return nil, err - } - // If the migration has already been applied, return an error, unless the migration is being - // applied in the opposite direction. In that case, we allow the migration to be applied again. - if result != nil && direction { - return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) - } - - d := sqlparser.DirectionDown - if direction { - d = sqlparser.DirectionUp - } - results, err := p.runMigrations(ctx, conn, []*Migration{m}, d, true) - if err != nil { - return nil, err - } - if len(results) == 0 { - return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) - } - return results[0], nil -} - -func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return nil, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - - // TODO(mf): add support for limit and order. Also would be nice to refactor the list query to - // support limiting the set. - - status := make([]*MigrationStatus, 0, len(p.migrations)) - for _, m := range p.migrations { - migrationStatus := &MigrationStatus{ - Source: Source{ - Type: m.Type, - Path: m.Source, - Version: m.Version, - }, - State: StatePending, - } - dbResult, err := p.store.GetMigration(ctx, conn, m.Version) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return nil, err - } - if dbResult != nil { - migrationStatus.State = StateApplied - migrationStatus.AppliedAt = dbResult.Timestamp - } - status = append(status, migrationStatus) - } - - return status, nil -} - -func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) { - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return 0, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - - res, err := p.store.ListMigrations(ctx, conn) - if err != nil { - return 0, err - } - if len(res) == 0 { - return 0, nil - } - sort.Slice(res, func(i, j int) bool { - return res[i].Version > res[j].Version - }) - return res[0].Version, nil -} diff --git a/provider_migrate.go b/provider_migrate.go deleted file mode 100644 index caa78f227..000000000 --- a/provider_migrate.go +++ /dev/null @@ -1,106 +0,0 @@ -package goose - -import ( - "context" - "database/sql" - "fmt" - - "github.com/pressly/goose/v3/database" -) - -// useTx is a helper function that returns true if the migration should be run in a transaction. It -// must only be called after the migration has been parsed and initialized. -func useTx(m *Migration, direction bool) (bool, error) { - switch m.Type { - case TypeGo: - if m.goUp.Mode == 0 || m.goDown.Mode == 0 { - return false, fmt.Errorf("go migrations must have a mode set") - } - if direction { - return m.goUp.Mode == TransactionEnabled, nil - } - return m.goDown.Mode == TransactionEnabled, nil - case TypeSQL: - if !m.sql.Parsed { - return false, fmt.Errorf("sql migrations must be parsed") - } - return m.sql.UseTx, nil - } - return false, fmt.Errorf("use tx: invalid migration type: %q", m.Type) -} - -// isEmpty is a helper function that returns true if the migration has no functions or no statements -// to execute. It must only be called after the migration has been parsed and initialized. -func isEmpty(m *Migration, direction bool) bool { - switch m.Type { - case TypeGo: - if direction { - return m.goUp.RunTx == nil && m.goUp.RunDB == nil - } - return m.goDown.RunTx == nil && m.goDown.RunDB == nil - case TypeSQL: - if direction { - return len(m.sql.Up) == 0 - } - return len(m.sql.Down) == 0 - } - return true -} - -// runMigration is a helper function that runs the migration in the given direction. It must only be -// called after the migration has been parsed and initialized. -func runMigration(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { - switch m.Type { - case TypeGo: - return runGo(ctx, db, m, direction) - case TypeSQL: - return runSQL(ctx, db, m, direction) - } - return fmt.Errorf("invalid migration type: %q", m.Type) -} - -// runGo is a helper function that runs the given Go functions in the given direction. It must only -// be called after the migration has been initialized. -func runGo(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { - switch db := db.(type) { - case *sql.Conn: - return fmt.Errorf("go migrations are not supported with *sql.Conn") - case *sql.DB: - if direction && m.goUp.RunDB != nil { - return m.goUp.RunDB(ctx, db) - } - if !direction && m.goDown.RunDB != nil { - return m.goDown.RunDB(ctx, db) - } - return nil - case *sql.Tx: - if direction && m.goUp.RunTx != nil { - return m.goUp.RunTx(ctx, db) - } - if !direction && m.goDown.RunTx != nil { - return m.goDown.RunTx(ctx, db) - } - return nil - } - return fmt.Errorf("invalid database connection type: %T", db) -} - -// runSQL is a helper function that runs the given SQL statements in the given direction. It must -// only be called after the migration has been parsed. -func runSQL(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { - if !m.sql.Parsed { - return fmt.Errorf("sql migrations must be parsed") - } - var statements []string - if direction { - statements = m.sql.Up - } else { - statements = m.sql.Down - } - for _, stmt := range statements { - if _, err := db.ExecContext(ctx, stmt); err != nil { - return err - } - } - return nil -} diff --git a/provider_run.go b/provider_run.go index 859bb68af..d50d19950 100644 --- a/provider_run.go +++ b/provider_run.go @@ -356,3 +356,100 @@ func (p *Provider) getMigration(version int64) (*Migration, error) { } return nil, ErrVersionNotFound } + +// useTx is a helper function that returns true if the migration should be run in a transaction. It +// must only be called after the migration has been parsed and initialized. +func useTx(m *Migration, direction bool) (bool, error) { + switch m.Type { + case TypeGo: + if m.goUp.Mode == 0 || m.goDown.Mode == 0 { + return false, fmt.Errorf("go migrations must have a mode set") + } + if direction { + return m.goUp.Mode == TransactionEnabled, nil + } + return m.goDown.Mode == TransactionEnabled, nil + case TypeSQL: + if !m.sql.Parsed { + return false, fmt.Errorf("sql migrations must be parsed") + } + return m.sql.UseTx, nil + } + return false, fmt.Errorf("use tx: invalid migration type: %q", m.Type) +} + +// isEmpty is a helper function that returns true if the migration has no functions or no statements +// to execute. It must only be called after the migration has been parsed and initialized. +func isEmpty(m *Migration, direction bool) bool { + switch m.Type { + case TypeGo: + if direction { + return m.goUp.RunTx == nil && m.goUp.RunDB == nil + } + return m.goDown.RunTx == nil && m.goDown.RunDB == nil + case TypeSQL: + if direction { + return len(m.sql.Up) == 0 + } + return len(m.sql.Down) == 0 + } + return true +} + +// runMigration is a helper function that runs the migration in the given direction. It must only be +// called after the migration has been parsed and initialized. +func runMigration(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { + switch m.Type { + case TypeGo: + return runGo(ctx, db, m, direction) + case TypeSQL: + return runSQL(ctx, db, m, direction) + } + return fmt.Errorf("invalid migration type: %q", m.Type) +} + +// runGo is a helper function that runs the given Go functions in the given direction. It must only +// be called after the migration has been initialized. +func runGo(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { + switch db := db.(type) { + case *sql.Conn: + return fmt.Errorf("go migrations are not supported with *sql.Conn") + case *sql.DB: + if direction && m.goUp.RunDB != nil { + return m.goUp.RunDB(ctx, db) + } + if !direction && m.goDown.RunDB != nil { + return m.goDown.RunDB(ctx, db) + } + return nil + case *sql.Tx: + if direction && m.goUp.RunTx != nil { + return m.goUp.RunTx(ctx, db) + } + if !direction && m.goDown.RunTx != nil { + return m.goDown.RunTx(ctx, db) + } + return nil + } + return fmt.Errorf("invalid database connection type: %T", db) +} + +// runSQL is a helper function that runs the given SQL statements in the given direction. It must +// only be called after the migration has been parsed. +func runSQL(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error { + if !m.sql.Parsed { + return fmt.Errorf("sql migrations must be parsed") + } + var statements []string + if direction { + statements = m.sql.Up + } else { + statements = m.sql.Down + } + for _, stmt := range statements { + if _, err := db.ExecContext(ctx, stmt); err != nil { + return err + } + } + return nil +} From 7ac011e6c3ed9d02220041ab052c387be8288688 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Tue, 7 Nov 2023 08:45:48 -0500 Subject: [PATCH 07/13] wip --- provider_options_test.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/provider_options_test.go b/provider_options_test.go index 0575cf417..a5df0fa5e 100644 --- a/provider_options_test.go +++ b/provider_options_test.go @@ -32,9 +32,6 @@ func TestNewProvider(t *testing.T) { // Nil db not allowed _, err = goose.NewProvider(database.DialectSQLite3, nil, fsys) check.HasError(t, err) - // Nil fsys not allowed - _, err = goose.NewProvider(database.DialectSQLite3, db, nil) - check.HasError(t, err) // Nil store not allowed _, err = goose.NewProvider(database.DialectSQLite3, db, nil, goose.WithStore(nil)) check.HasError(t, err) From f66511b43d948e095c2837a886ae76168159e6a7 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Wed, 8 Nov 2023 08:29:04 -0500 Subject: [PATCH 08/13] checkpoint: cleanup --- database/dialect.go | 4 + database/store.go | 10 +- database/store_test.go | 2 +- provider.go | 209 +++++++++++++++++++++++------------------ provider_errors.go | 3 + provider_options.go | 6 +- provider_run.go | 13 ++- provider_run_test.go | 18 ++-- types.go | 4 +- 9 files changed, 159 insertions(+), 110 deletions(-) diff --git a/database/dialect.go b/database/dialect.go index 8767db01a..2e502f526 100644 --- a/database/dialect.go +++ b/database/dialect.go @@ -2,6 +2,7 @@ package database import ( "context" + "database/sql" "errors" "fmt" @@ -100,6 +101,9 @@ func (s *store) GetMigration( &result.Timestamp, &result.IsApplied, ); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version) + } return nil, fmt.Errorf("failed to get migration %d: %w", version, err) } return &result, nil diff --git a/database/store.go b/database/store.go index fe496bb50..107a1eea9 100644 --- a/database/store.go +++ b/database/store.go @@ -2,9 +2,15 @@ package database import ( "context" + "errors" "time" ) +var ( + // ErrVersionNotFound must be returned by [GetMigration] when a migration version is not found. + ErrVersionNotFound = errors.New("version not found") +) + // Store is an interface that defines methods for managing database migrations and versioning. By // defining a Store interface, we can support multiple databases with consistent functionality. // @@ -24,8 +30,8 @@ type Store interface { // Delete deletes a version id from the version table. Delete(ctx context.Context, db DBTxConn, version int64) error - // GetMigration retrieves a single migration by version id. This method may return the raw sql - // error if the query fails so the caller can assert for errors such as [sql.ErrNoRows]. + // GetMigration retrieves a single migration by version id. If the query succeeds, but the + // version is not found, this method must return [ErrVersionNotFound]. GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error) // ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If diff --git a/database/store_test.go b/database/store_test.go index d8d59ba8f..b4dbf472c 100644 --- a/database/store_test.go +++ b/database/store_test.go @@ -205,7 +205,7 @@ func testStore( err = runConn(ctx, db, func(conn *sql.Conn) error { _, err := store.GetMigration(ctx, conn, 0) check.HasError(t, err) - check.Bool(t, errors.Is(err, sql.ErrNoRows), true) + check.Bool(t, errors.Is(err, database.ErrVersionNotFound), true) return nil }) check.NoError(t, err) diff --git a/provider.go b/provider.go index dd3c21c62..d36b00e36 100644 --- a/provider.go +++ b/provider.go @@ -8,6 +8,8 @@ import ( "io/fs" "math" "sort" + "strconv" + "strings" "sync" "github.com/pressly/goose/v3/database" @@ -15,22 +17,23 @@ import ( "go.uber.org/multierr" ) -// Provider is a goose migration goose. +// Provider is a goose migration provider. type Provider struct { // mu protects all accesses to the provider and must be held when calling operations on the // database. mu sync.Mutex db *sql.DB - fsys fs.FS - cfg config store database.Store + fsys fs.FS + cfg config + // migrations are ordered by version in ascending order. migrations []*Migration } -// NewProvider returns a new goose goose. +// NewProvider returns a new goose provider. // // The caller is responsible for matching the database dialect with the database/sql driver. For // example, if the database dialect is "postgres", the database/sql driver could be @@ -38,12 +41,12 @@ type Provider struct { // constant backed by a default [database.Store] implementation. For more advanced use cases, such // as using a custom table name or supplying a custom store implementation, see [WithStore]. // -// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to -// use [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem. +// fsys is the filesystem used to read migration files, but may be nil. Most users will want to use +// [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem. // However, it is possible to use a different "filesystem", such as [embed.FS] or filter out // migrations using [fs.Sub]. // -// See [ProviderOption] for more information on configuring the goose. +// See [ProviderOption] for more information on configuring the provider. // // Unless otherwise specified, all methods on Provider are safe for concurrent use. // @@ -71,7 +74,7 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi return nil, errors.New("dialect must not be empty") } if dialect != "" && cfg.store != nil { - return nil, errors.New("cannot set both dialect and custom store") + return nil, errors.New("dialect must be empty when using a custom store implementation") } var store database.Store if dialect != "" { @@ -98,29 +101,29 @@ func newProvider( ) (*Provider, error) { // Collect migrations from the filesystem and merge with registered migrations. // - // Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed - // lazily. - // - // TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to - // return an error if there are any SQL parsing errors. This adds a bit overhead to startup - // though, so we should make it optional. + // Note, we don't parse SQL migrations here. They are parsed lazily when required. + + // feat(mf): we could add a flag to parse SQL migrations eagerly. This would allow us to return + // an error if there are any SQL parsing errors. This adds a bit overhead to startup though, so + // we should make it optional. filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions) if err != nil { return nil, err } versionToGoMigration := make(map[int64]*Migration) - // Add user-registered Go migrations. + // Add user-registered Go migrations from the provider. for version, m := range cfg.registered { versionToGoMigration[version] = m } - // Add init() functions. This is a bit ugly because we need to convert from the old Migration - // struct to the new goMigration struct. + // Add globally registered Go migrations. for version, m := range global { if _, ok := versionToGoMigration[version]; ok { - return nil, fmt.Errorf("global go migration with version %d already registered with provider", version) + return nil, fmt.Errorf("global go migration with version %d previously registered with provider", version) } versionToGoMigration[version] = m } + // At this point we have all registered unique Go migrations (if any). We need to merge them + // with SQL migrations from the filesystem. migrations, err := merge(filesystemSources, versionToGoMigration) if err != nil { return nil, err @@ -143,19 +146,20 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { return p.status(ctx) } -// GetDBVersion returns the max version from the database, regardless of the applied order. For -// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been -// applied, it returns 0. +// GetDBVersion returns the highest version recorded in the database, regardless of the order in +// which migrations were applied. For example, if migrations were applied out of order (1,4,2,3), +// this method returns 4. If no migrations have been applied, it returns 0. func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { - return p.getDBVersion(ctx) + return p.getDBMaxVersion(ctx) } -// ListSources returns a list of all available migration sources the provider is aware of, sorted in -// ascending order by version. -func (p *Provider) ListSources() []Source { - sources := make([]Source, 0, len(p.migrations)) +// ListSources returns a list of all migration sources known to the provider, sorted in ascending +// order by version. The path field may be empty for manually registered migrations, such as Go +// migrations registered using the [WithGoMigrations] option. +func (p *Provider) ListSources() []*Source { + sources := make([]*Source, 0, len(p.migrations)) for _, m := range p.migrations { - sources = append(sources, Source{ + sources = append(sources, &Source{ Type: m.Type, Path: m.Source, Version: m.Version, @@ -169,31 +173,43 @@ func (p *Provider) Ping(ctx context.Context) error { return p.db.PingContext(ctx) } -// Close closes the database connection. +// Close closes the database connection initially supplied to the provider. func (p *Provider) Close() error { return p.db.Close() } -// ApplyVersion applies exactly one migration by version. If there is no source for the specified -// version, this method returns [ErrVersionNotFound]. If the migration has been applied already, -// this method returns [ErrAlreadyApplied]. +// ApplyVersion applies exactly one migration for the specified version. If there is no migration +// available for the specified version, this method returns [ErrVersionNotFound]. If the migration +// has already been applied, this method returns [ErrAlreadyApplied]. // -// When direction is true, the up migration is executed, and when direction is false, the down -// migration is executed. +// The direction parameter determines the migration direction: true for up migration and false for +// down migration. func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) { - if version < 1 { - return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version) + res, err := p.apply(ctx, version, direction) + if err != nil { + return nil, err + } + // This should never happen, we must return exactly one result. + if len(res) != 1 { + versions := make([]string, 0, len(res)) + for _, r := range res { + versions = append(versions, strconv.FormatInt(r.Source.Version, 10)) + } + return nil, fmt.Errorf( + "unexpected number of migrations applied running apply, expecting exactly one result: %v", + strings.Join(versions, ","), + ) } - return p.apply(ctx, version, direction) + return res[0], nil } -// Up applies all [StatePending] migrations. If there are no new migrations to apply, this method -// returns empty list and nil error. +// Up applies all pending migrations. If there are no new migrations to apply, this method returns +// empty list and nil error. func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { return p.up(ctx, false, math.MaxInt64) } -// UpByOne applies the next available migration. If there are no migrations to apply, this method +// UpByOne applies the next pending migration. If there is no next migration to apply, this method // returns [ErrNoNextVersion]. The returned list will always have exactly one migration result. func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { res, err := p.up(ctx, true, math.MaxInt64) @@ -203,27 +219,35 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { if len(res) == 0 { return nil, ErrNoNextVersion } - // This should never happen. We should always have exactly one result and test for this. - if len(res) > 1 { - return nil, fmt.Errorf("unexpected number of migrations returned running up-by-one: %d", len(res)) + // This should never happen, we must return exactly one result. + if len(res) != 1 { + versions := make([]string, 0, len(res)) + for _, r := range res { + versions = append(versions, strconv.FormatInt(r.Source.Version, 10)) + } + return nil, fmt.Errorf( + "unexpected number of migrations applied running up-by-one, expecting exactly one result: %v", + strings.Join(versions, ","), + ) } return res[0], nil } -// UpTo applies all available migrations up to, and including, the specified version. If there are -// no migrations to apply, this method returns empty list and nil error. +// UpTo applies all pending migrations up to, and including, the specified version. If there are no +// migrations to apply, this method returns empty list and nil error. // -// For instance, if there are three new migrations (9,10,11) and the current database version is 8 +// For example, if there are three new migrations (9,10,11) and the current database version is 8 // with a requested version of 10, only versions 9,10 will be applied. func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { - if version < 1 { - return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version) - } return p.up(ctx, false, version) } -// Down rolls back the most recently applied migration. If there are no migrations to apply, this +// Down rolls back the most recently applied migration. If there are no migrations to rollback, this // method returns [ErrNoNextVersion]. +// +// Note, migrations are rolled back in the order they were applied. And not in the reverse order of +// the migration version. This only applies in scenarios where migrations are allowed to be applied +// out of order. func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { res, err := p.down(ctx, true, 0) if err != nil { @@ -232,16 +256,28 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { if len(res) == 0 { return nil, ErrNoNextVersion } - if len(res) > 1 { - return nil, fmt.Errorf("unexpected number of migrations returned running down: %d", len(res)) + // This should never happen, we must return exactly one result. + if len(res) != 1 { + versions := make([]string, 0, len(res)) + for _, r := range res { + versions = append(versions, strconv.FormatInt(r.Source.Version, 10)) + } + return nil, fmt.Errorf( + "unexpected number of migrations applied running down, expecting exactly one result: %v", + strings.Join(versions, ","), + ) } return res[0], nil } // DownTo rolls back all migrations down to, but not including, the specified version. // -// For instance, if the current database version is 11,10,9... and the requested version is 9, only +// For example, if the current database version is 11,10,9... and the requested version is 9, only // migrations 11, 10 will be rolled back. +// +// Note, migrations are rolled back in the order they were applied. And not in the reverse order of +// the migration version. This only applies in scenarios where migrations are allowed to be applied +// out of order. func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { if version < 0 { return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version) @@ -253,11 +289,11 @@ func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResul func (p *Provider) up( ctx context.Context, - upByOne bool, + byOne bool, version int64, ) (_ []*MigrationResult, retErr error) { if version < 1 { - return nil, errors.New("version must be greater than zero") + return nil, errInvalidVersion } conn, cleanup, err := p.initialize(ctx) if err != nil { @@ -266,11 +302,15 @@ func (p *Provider) up( defer func() { retErr = multierr.Append(retErr, cleanup()) }() + if len(p.migrations) == 0 { return nil, nil } var apply []*Migration if p.cfg.disableVersioning { + if byOne { + return nil, errors.New("up-by-one not supported when versioning is disabled") + } apply = p.migrations } else { // optimize(mf): Listing all migrations from the database isn't great. This is only required @@ -289,18 +329,12 @@ func (p *Provider) up( return nil, err } } - // feat(mf): this is where can (optionally) group multiple migrations to be run in a single - // transaction. The default is to apply each migration sequentially on its own. - // https://github.com/pressly/goose/issues/222 - // - // Careful, we can't use a single transaction for all migrations because some may have to be run - // in their own transaction. - return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne) + return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, byOne) } func (p *Provider) down( ctx context.Context, - downByOne bool, + byOne bool, version int64, ) (_ []*MigrationResult, retErr error) { conn, cleanup, err := p.initialize(ctx) @@ -310,16 +344,19 @@ func (p *Provider) down( defer func() { retErr = multierr.Append(retErr, cleanup()) }() + if len(p.migrations) == 0 { return nil, nil } if p.cfg.disableVersioning { - downMigrations := p.migrations - if downByOne { + var downMigrations []*Migration + if byOne { last := p.migrations[len(p.migrations)-1] downMigrations = []*Migration{last} + } else { + downMigrations = p.migrations } - return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, byOne) } dbMigrations, err := p.store.ListMigrations(ctx, conn) if err != nil { @@ -328,10 +365,11 @@ func (p *Provider) down( if len(dbMigrations) == 0 { return nil, errMissingZeroVersion } + // We never migrate the zero version down. if dbMigrations[0].Version == 0 { return nil, nil } - var downMigrations []*Migration + var apply []*Migration for _, dbMigration := range dbMigrations { if dbMigration.Version <= version { break @@ -340,21 +378,23 @@ func (p *Provider) down( if err != nil { return nil, err } - downMigrations = append(downMigrations, m) + apply = append(apply, m) } - return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) + return p.runMigrations(ctx, conn, apply, sqlparser.DirectionDown, byOne) } func (p *Provider) apply( ctx context.Context, version int64, direction bool, -) (_ *MigrationResult, retErr error) { +) (_ []*MigrationResult, retErr error) { + if version < 1 { + return nil, errInvalidVersion + } m, err := p.getMigration(version) if err != nil { return nil, err } - conn, cleanup, err := p.initialize(ctx) if err != nil { return nil, err @@ -364,27 +404,19 @@ func (p *Provider) apply( }() result, err := p.store.GetMigration(ctx, conn, version) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, database.ErrVersionNotFound) { return nil, err } - // If the migration has already been applied, return an error, unless the migration is being - // applied in the opposite direction. In that case, we allow the migration to be applied again. + // If the migration has already been applied, return an error. But, if the migration is being + // rolled back, we allow the individual migration to be applied again. if result != nil && direction { return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) } - d := sqlparser.DirectionDown if direction { d = sqlparser.DirectionUp } - results, err := p.runMigrations(ctx, conn, []*Migration{m}, d, true) - if err != nil { - return nil, err - } - if len(results) == 0 { - return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) - } - return results[0], nil + return p.runMigrations(ctx, conn, []*Migration{m}, d, true) } func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { @@ -396,13 +428,10 @@ func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr err retErr = multierr.Append(retErr, cleanup()) }() - // TODO(mf): add support for limit and order. Also would be nice to refactor the list query to - // support limiting the set. - status := make([]*MigrationStatus, 0, len(p.migrations)) for _, m := range p.migrations { migrationStatus := &MigrationStatus{ - Source: Source{ + Source: &Source{ Type: m.Type, Path: m.Source, Version: m.Version, @@ -410,7 +439,7 @@ func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr err State: StatePending, } dbResult, err := p.store.GetMigration(ctx, conn, m.Version) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, database.ErrVersionNotFound) { return nil, err } if dbResult != nil { @@ -423,7 +452,7 @@ func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr err return status, nil } -func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) { +func (p *Provider) getDBMaxVersion(ctx context.Context) (_ int64, retErr error) { conn, cleanup, err := p.initialize(ctx) if err != nil { return 0, err @@ -437,10 +466,12 @@ func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) { return 0, err } if len(res) == 0 { - return 0, nil + return 0, errMissingZeroVersion } + // Sort in descending order. sort.Slice(res, func(i, j int) bool { return res[i].Version > res[j].Version }) + // Return the highest version. return res[0].Version, nil } diff --git a/provider_errors.go b/provider_errors.go index ecaf89801..300464121 100644 --- a/provider_errors.go +++ b/provider_errors.go @@ -15,6 +15,9 @@ var ( // ErrNoMigrations is returned by [NewProvider] when no migrations are found. ErrNoMigrations = errors.New("no migrations found") + + // errInvalidVersion is returned when a migration version is invalid. + errInvalidVersion = errors.New("version must be greater than 0") ) // PartialError is returned when a migration fails, but some migrations already got applied. diff --git a/provider_options.go b/provider_options.go index 03c6f7c62..53edfd340 100644 --- a/provider_options.go +++ b/provider_options.go @@ -10,7 +10,7 @@ import ( const ( // DefaultTablename is the default name of the database table used to track history of applied - // migrations. It can be overridden using the [WithTableName] option when creating a new goose. + // migrations. DefaultTablename = "goose_db_version" ) @@ -142,11 +142,11 @@ func WithAllowedMissing(b bool) ProviderOption { }) } -// WithDisabledVersioning disables versioning. Disabling versioning allows applying migrations +// WithDisableVersioning disables versioning. Disabling versioning allows applying migrations // without tracking the versions in the database schema table. Useful for tests, seeding a database // or running ad-hoc queries. By default, goose will track all versions in the database schema // table. -func WithDisabledVersioning(b bool) ProviderOption { +func WithDisableVersioning(b bool) ProviderOption { return configFunc(func(c *config) error { c.disableVersioning = b return nil diff --git a/provider_run.go b/provider_run.go index d50d19950..28fe5319b 100644 --- a/provider_run.go +++ b/provider_run.go @@ -154,11 +154,14 @@ func (p *Provider) runMigrations( // following issues for more details: // - https://github.com/pressly/goose/issues/485 // - https://github.com/pressly/goose/issues/222 + // + // Be careful, we can't use a single transaction for all migrations because some may be marked + // as not using a transaction. var results []*MigrationResult for _, m := range apply { current := &MigrationResult{ - Source: Source{ + Source: &Source{ Type: m.Type, Path: m.Source, Version: m.Version, @@ -275,18 +278,20 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err if err := l.SessionLock(ctx, conn); err != nil { return nil, nil, multierr.Append(err, cleanup()) } + // A lock was acquired, so we need to unlock the session when we're done. This is done by + // returning a cleanup function that unlocks the session and closes the connection. cleanup = func() error { p.mu.Unlock() // Use a detached context to unlock the session. This is because the context passed to // SessionLock may have been canceled, and we don't want to cancel the unlock. // - // TODO(mf): use [context.WithoutCancel] added in go1.21 + // TODO(mf): use context.WithoutCancel added in go1.21 detachedCtx := context.Background() return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close()) } } // If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't - // need the version table because there is no versioning. + // need the version table because no versions are being recorded. if !p.cfg.disableVersioning { if err := p.ensureVersionTable(ctx, conn); err != nil { return nil, nil, multierr.Append(err, cleanup()) @@ -346,7 +351,7 @@ func checkMissingMigrations( return missing } -// getMigration returns the migration with the given version. If no migration is found, then +// getMigration returns the migration for the given version. If no migration is found, then // ErrVersionNotFound is returned. func (p *Provider) getMigration(version int64) (*Migration, error) { for _, m := range p.migrations { diff --git a/provider_run_test.go b/provider_run_test.go index bbb018d4b..efd03fa8b 100644 --- a/provider_run_test.go +++ b/provider_run_test.go @@ -54,13 +54,13 @@ func TestProviderRun(t *testing.T) { p, _ := newProviderWithDB(t) _, err := p.UpTo(context.Background(), 0) check.HasError(t, err) - check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0") + check.Equal(t, err.Error(), "version must be greater than 0") _, err = p.DownTo(context.Background(), -1) check.HasError(t, err) check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1") _, err = p.ApplyVersion(context.Background(), 0, true) check.HasError(t, err) - check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0") + check.Equal(t, err.Error(), "version must be greater than 0") }) t.Run("up_and_down_all", func(t *testing.T) { ctx := context.Background() @@ -488,7 +488,7 @@ func TestNoVersioning(t *testing.T) { ) p, err := goose.NewProvider(database.DialectSQLite3, db, fsys, goose.WithVerbose(testing.Verbose()), - goose.WithDisabledVersioning(false), // This is the default. + goose.WithDisableVersioning(false), // This is the default. ) check.Number(t, len(p.ListSources()), 3) check.NoError(t, err) @@ -501,7 +501,7 @@ func TestNoVersioning(t *testing.T) { fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed")) p, err := goose.NewProvider(database.DialectSQLite3, db, fsys, goose.WithVerbose(testing.Verbose()), - goose.WithDisabledVersioning(true), // Provider with no versioning. + goose.WithDisableVersioning(true), // Provider with no versioning. ) check.NoError(t, err) check.Number(t, len(p.ListSources()), 2) @@ -1129,14 +1129,14 @@ func getTableNames(db *sql.DB) ([]string, error) { return tables, nil } -func assertStatus(t *testing.T, got *goose.MigrationStatus, state goose.State, source goose.Source, appliedIsZero bool) { +func assertStatus(t *testing.T, got *goose.MigrationStatus, state goose.State, source *goose.Source, appliedIsZero bool) { t.Helper() check.Equal(t, got.State, state) check.Equal(t, got.Source, source) check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero) } -func assertResult(t *testing.T, got *goose.MigrationResult, source goose.Source, direction string, isEmpty bool) { +func assertResult(t *testing.T, got *goose.MigrationResult, source *goose.Source, direction string, isEmpty bool) { t.Helper() check.Bool(t, got != nil, true) check.Equal(t, got.Source, source) @@ -1146,15 +1146,15 @@ func assertResult(t *testing.T, got *goose.MigrationResult, source goose.Source, check.Bool(t, got.Duration > 0, true) } -func assertSource(t *testing.T, got goose.Source, typ goose.MigrationType, name string, version int64) { +func assertSource(t *testing.T, got *goose.Source, typ goose.MigrationType, name string, version int64) { t.Helper() check.Equal(t, got.Type, typ) check.Equal(t, got.Path, name) check.Equal(t, got.Version, version) } -func newSource(t goose.MigrationType, fullpath string, version int64) goose.Source { - return goose.Source{ +func newSource(t goose.MigrationType, fullpath string, version int64) *goose.Source { + return &goose.Source{ Type: t, Path: fullpath, Version: version, diff --git a/types.go b/types.go index 616c7d808..1b4983b2b 100644 --- a/types.go +++ b/types.go @@ -22,7 +22,7 @@ type Source struct { // MigrationResult is the result of a single migration operation. type MigrationResult struct { - Source Source + Source *Source Duration time.Duration Direction string // Empty indicates no action was taken during the migration, but it was still versioned. For @@ -49,7 +49,7 @@ const ( // MigrationStatus represents the status of a single migration. type MigrationStatus struct { - Source Source + Source *Source State State AppliedAt time.Time } From dbb7448aa2b14be0d7d00e5d59f5ec406bb8d31c Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 9 Nov 2023 08:23:50 -0500 Subject: [PATCH 09/13] wip --- provider_options.go | 12 ++++++------ provider_run_test.go | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/provider_options.go b/provider_options.go index 53edfd340..e13c6a32a 100644 --- a/provider_options.go +++ b/provider_options.go @@ -128,14 +128,14 @@ func WithGoMigrations(migrations ...*Migration) ProviderOption { }) } -// WithAllowedMissing allows the provider to apply missing (out-of-order) migrations. By default, +// WithAllowOutofOrder allows the provider to apply missing (out-of-order) migrations. By default, // goose will raise an error if it encounters a missing migration. // -// Example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is true, -// then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order of -// applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first, followed -// by new migrations. -func WithAllowedMissing(b bool) ProviderOption { +// For example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is +// true, then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order +// of applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first, +// followed by new migrations. +func WithAllowOutofOrder(b bool) ProviderOption { return configFunc(func(c *config) error { c.allowMissing = b return nil diff --git a/provider_run_test.go b/provider_run_test.go index efd03fa8b..e51e85586 100644 --- a/provider_run_test.go +++ b/provider_run_test.go @@ -553,7 +553,7 @@ func TestAllowMissing(t *testing.T) { t.Run("missing_now_allowed", func(t *testing.T) { db := newDB(t) p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(), - goose.WithAllowedMissing(false), + goose.WithAllowOutofOrder(false), ) check.NoError(t, err) @@ -608,7 +608,7 @@ func TestAllowMissing(t *testing.T) { t.Run("missing_allowed", func(t *testing.T) { db := newDB(t) p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(), - goose.WithAllowedMissing(true), + goose.WithAllowOutofOrder(true), ) check.NoError(t, err) From 25ab1097bae9c1707035ac06159388492a76606e Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 9 Nov 2023 09:00:06 -0500 Subject: [PATCH 10/13] improve session locker --- lock/postgres.go | 30 ++++++++++------ lock/postgres_test.go | 19 +++++----- lock/session_locker_options.go | 63 ++++++++++++++++++++++++++-------- provider_run_test.go | 13 +++++-- 4 files changed, 88 insertions(+), 37 deletions(-) diff --git a/lock/postgres.go b/lock/postgres.go index 3583162e2..25405b2f1 100644 --- a/lock/postgres.go +++ b/lock/postgres.go @@ -13,17 +13,25 @@ import ( // NewPostgresSessionLocker returns a SessionLocker that utilizes PostgreSQL's exclusive // session-level advisory lock mechanism. // -// This function creates a SessionLocker that can be used to acquire and release locks for +// This function creates a SessionLocker that can be used to acquire and release a lock for // synchronization purposes. The lock acquisition is retried until it is successfully acquired or -// until the maximum duration is reached. The default lock duration is set to 60 minutes, and the +// until the failure threshold is reached. The default lock duration is set to 5 minutes, and the // default unlock duration is set to 1 minute. // +// If you have long running migrations, you may want to increase the lock duration. +// // See [SessionLockerOption] for options that can be used to configure the SessionLocker. func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error) { cfg := sessionLockerConfig{ - lockID: DefaultLockID, - lockTimeout: DefaultLockTimeout, - unlockTimeout: DefaultUnlockTimeout, + lockID: DefaultLockID, + lockProbe: probe{ + periodSeconds: 5 * time.Second, + failureThreshold: 60, + }, + unlockProbe: probe{ + periodSeconds: 2 * time.Second, + failureThreshold: 30, + }, } for _, opt := range opts { if err := opt.apply(&cfg); err != nil { @@ -32,13 +40,13 @@ func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error } return &postgresSessionLocker{ lockID: cfg.lockID, - retryLock: retry.WithMaxDuration( - cfg.lockTimeout, - retry.NewConstant(2*time.Second), + retryLock: retry.WithMaxRetries( + cfg.lockProbe.failureThreshold, + retry.NewConstant(cfg.lockProbe.periodSeconds), ), - retryUnlock: retry.WithMaxDuration( - cfg.unlockTimeout, - retry.NewConstant(2*time.Second), + retryUnlock: retry.WithMaxRetries( + cfg.unlockProbe.failureThreshold, + retry.NewConstant(cfg.unlockProbe.periodSeconds), ), }, nil } diff --git a/lock/postgres_test.go b/lock/postgres_test.go index 2622d5cb6..d57728cf4 100644 --- a/lock/postgres_test.go +++ b/lock/postgres_test.go @@ -6,7 +6,6 @@ import ( "errors" "sync" "testing" - "time" "github.com/pressly/goose/v3/internal/check" "github.com/pressly/goose/v3/internal/testdb" @@ -30,8 +29,8 @@ func TestPostgresSessionLocker(t *testing.T) { ) locker, err := lock.NewPostgresSessionLocker( lock.WithLockID(lockID), - lock.WithLockTimeout(4*time.Second), - lock.WithUnlockTimeout(4*time.Second), + lock.WithLockTimeout(1, 4), // 4 second timeout + lock.WithUnlockTimeout(1, 4), // 4 second timeout ) check.NoError(t, err) ctx := context.Background() @@ -60,8 +59,8 @@ func TestPostgresSessionLocker(t *testing.T) { }) t.Run("lock_close_conn_unlock", func(t *testing.T) { locker, err := lock.NewPostgresSessionLocker( - lock.WithLockTimeout(4*time.Second), - lock.WithUnlockTimeout(4*time.Second), + lock.WithLockTimeout(1, 4), // 4 second timeout + lock.WithUnlockTimeout(1, 4), // 4 second timeout ) check.NoError(t, err) ctx := context.Background() @@ -103,10 +102,12 @@ func TestPostgresSessionLocker(t *testing.T) { // Exactly one connection should acquire the lock. While the other connections // should fail to acquire the lock and timeout. locker, err := lock.NewPostgresSessionLocker( - lock.WithLockTimeout(4*time.Second), - lock.WithUnlockTimeout(4*time.Second), + lock.WithLockTimeout(1, 4), // 4 second timeout + lock.WithUnlockTimeout(1, 4), // 4 second timeout ) check.NoError(t, err) + // NOTE, we are not unlocking the lock, because we want to test that the lock is + // released when the connection is closed. ch <- locker.SessionLock(ctx, conn) }() } @@ -138,8 +139,8 @@ func TestPostgresSessionLocker(t *testing.T) { ) locker, err := lock.NewPostgresSessionLocker( lock.WithLockID(lockID), - lock.WithLockTimeout(4*time.Second), - lock.WithUnlockTimeout(4*time.Second), + lock.WithLockTimeout(1, 4), // 4 second timeout + lock.WithUnlockTimeout(1, 4), // 4 second timeout ) check.NoError(t, err) diff --git a/lock/session_locker_options.go b/lock/session_locker_options.go index c3e42151c..4f1efe8ff 100644 --- a/lock/session_locker_options.go +++ b/lock/session_locker_options.go @@ -1,6 +1,7 @@ package lock import ( + "errors" "time" ) @@ -10,11 +11,6 @@ const ( // // crc64.Checksum([]byte("goose"), crc64.MakeTable(crc64.ECMA)) DefaultLockID int64 = 5887940537704921958 - - // Default values for the lock (time to wait for the lock to be acquired) and unlock (time to - // wait for the lock to be released) wait durations. - DefaultLockTimeout time.Duration = 60 * time.Minute - DefaultUnlockTimeout time.Duration = 1 * time.Minute ) // SessionLockerOption is used to configure a SessionLocker. @@ -32,26 +28,65 @@ func WithLockID(lockID int64) SessionLockerOption { }) } -// WithLockTimeout sets the max duration to wait for the lock to be acquired. -func WithLockTimeout(duration time.Duration) SessionLockerOption { +// WithLockTimeout sets the max duration to wait for the lock to be acquired. The total duration +// will be the period times the failure threshold. +// +// By default, the lock timeout is 300s (5min), where the lock is retried every 5 seconds (period) +// up to 60 times (failure threshold). +// +// The minimum period is 1 second, and the minimum failure threshold is 1. +func WithLockTimeout(period, failureThreshold uint64) SessionLockerOption { return sessionLockerConfigFunc(func(c *sessionLockerConfig) error { - c.lockTimeout = duration + if period < 1 { + return errors.New("period must be greater than 0, minimum is 1") + } + if failureThreshold < 1 { + return errors.New("failure threshold must be greater than 0, minimum is 1") + } + c.lockProbe = probe{ + periodSeconds: time.Duration(period) * time.Second, + failureThreshold: failureThreshold, + } return nil }) } -// WithUnlockTimeout sets the max duration to wait for the lock to be released. -func WithUnlockTimeout(duration time.Duration) SessionLockerOption { +// WithUnlockTimeout sets the max duration to wait for the lock to be released. The total duration +// will be the period times the failure threshold. +// +// By default, the lock timeout is 60s, where the lock is retried every 2 seconds (period) up to 30 +// times (failure threshold). +// +// The minimum period is 1 second, and the minimum failure threshold is 1. +func WithUnlockTimeout(period, failureThreshold uint64) SessionLockerOption { return sessionLockerConfigFunc(func(c *sessionLockerConfig) error { - c.unlockTimeout = duration + if period < 1 { + return errors.New("period must be greater than 0, minimum is 1") + } + if failureThreshold < 1 { + return errors.New("failure threshold must be greater than 0, minimum is 1") + } + c.unlockProbe = probe{ + periodSeconds: time.Duration(period) * time.Second, + failureThreshold: failureThreshold, + } return nil }) } type sessionLockerConfig struct { - lockID int64 - lockTimeout time.Duration - unlockTimeout time.Duration + lockID int64 + lockProbe probe + unlockProbe probe +} + +// probe is used to configure how often and how many times to retry a lock or unlock operation. The +// total timeout will be the period times the failure threshold. +type probe struct { + // How often (in seconds) to perform the probe. + periodSeconds time.Duration + // Number of times to retry the probe. + failureThreshold uint64 } var _ SessionLockerOption = (sessionLockerConfigFunc)(nil) diff --git a/provider_run_test.go b/provider_run_test.go index e51e85586..6e00ce9dc 100644 --- a/provider_run_test.go +++ b/provider_run_test.go @@ -830,15 +830,22 @@ func TestLockModeAdvisorySession(t *testing.T) { t.Cleanup(cleanup) newProvider := func() *goose.Provider { - sessionLocker, err := lock.NewPostgresSessionLocker() + + sessionLocker, err := lock.NewPostgresSessionLocker( + lock.WithLockTimeout(5, 60), // Timeout 5min. Try every 5s up to 60 times. + ) check.NoError(t, err) - p, err := goose.NewProvider(database.DialectPostgres, db, os.DirFS("testdata/migrations"), + p, err := goose.NewProvider( + database.DialectPostgres, + db, + os.DirFS("testdata/migrations"), goose.WithSessionLocker(sessionLocker), // Use advisory session lock mode. - goose.WithVerbose(testing.Verbose()), ) check.NoError(t, err) + return p } + provider1 := newProvider() provider2 := newProvider() From fa909d9334f435f9a4b056ee46db9e0b8fc60975 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 9 Nov 2023 09:15:26 -0500 Subject: [PATCH 11/13] wip --- lock/postgres_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/lock/postgres_test.go b/lock/postgres_test.go index d57728cf4..96a621758 100644 --- a/lock/postgres_test.go +++ b/lock/postgres_test.go @@ -180,6 +180,7 @@ func queryPgLocks(ctx context.Context, db *sql.DB) ([]pgLock, error) { if err != nil { return nil, err } + defer rows.Close() var pgLocks []pgLock for rows.Next() { var p pgLock From a3318c311fc4c0e01ce2f8520ae11eb15f89ca46 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 9 Nov 2023 09:16:53 -0500 Subject: [PATCH 12/13] wip --- types.go => provider_types.go | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename types.go => provider_types.go (100%) diff --git a/types.go b/provider_types.go similarity index 100% rename from types.go rename to provider_types.go From eaa69c2ed5379a607c4fa4ea4d2f5ba56fcea175 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 9 Nov 2023 09:23:17 -0500 Subject: [PATCH 13/13] wip --- provider_options.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/provider_options.go b/provider_options.go index e13c6a32a..8e3fb2e8a 100644 --- a/provider_options.go +++ b/provider_options.go @@ -100,6 +100,9 @@ func WithExcludeNames(excludes []string) ProviderOption { func WithExcludeVersions(versions []int64) ProviderOption { return configFunc(func(c *config) error { for _, version := range versions { + if version < 1 { + return errInvalidVersion + } if _, ok := c.excludeVersions[version]; ok { return fmt.Errorf("duplicate excludes version: %d", version) }