diff --git a/globals.go b/globals.go new file mode 100644 index 000000000..e68bb0c4b --- /dev/null +++ b/globals.go @@ -0,0 +1,87 @@ +package goose + +import ( + "errors" + "fmt" +) + +var ( + registeredGoMigrations = make(map[int64]*Migration) +) + +// ResetGlobalMigrations resets the global go migrations registry. +// +// Not safe for concurrent use. +func ResetGlobalMigrations() { + registeredGoMigrations = make(map[int64]*Migration) +} + +// SetGlobalMigrations registers go migrations globally. It returns an error if a migration with the +// same version has already been registered. +// +// Source may be empty, but if it is set, it must be a path with a numeric component that matches +// the version. Do not register legacy non-context functions: UpFn, DownFn, UpFnNoTx, DownFnNoTx. +// +// Not safe for concurrent use. +func SetGlobalMigrations(migrations ...Migration) error { + for _, m := range migrations { + // make a copy of the migration so we can modify it without affecting the original. + if err := validGoMigration(&m); err != nil { + return fmt.Errorf("invalid go migration: %w", err) + } + if _, ok := registeredGoMigrations[m.Version]; ok { + return fmt.Errorf("go migration with version %d already registered", m.Version) + } + m.Next, m.Previous = -1, -1 // Do not allow these to be set by the user. + registeredGoMigrations[m.Version] = &m + } + return nil +} + +func validGoMigration(m *Migration) error { + if m == nil { + return errors.New("must not be nil") + } + if !m.Registered { + return errors.New("must be registered") + } + if m.Type != TypeGo { + return fmt.Errorf("type must be %q", TypeGo) + } + if m.Version < 1 { + return errors.New("version must be greater than zero") + } + if m.Source != "" { + // If the source is set, expect it to be a path with a numeric component that matches the + // version. This field is not intended to be used for descriptive purposes. + version, err := NumericComponent(m.Source) + if err != nil { + return err + } + if version != m.Version { + return fmt.Errorf("numeric component [%d] in go migration does not match version in source %q", m.Version, m.Source) + } + } + // It's valid for all of these funcs to be nil. Which means version the go migration but do not + // run anything. + if m.UpFnContext != nil && m.UpFnNoTxContext != nil { + return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext") + } + if m.DownFnContext != nil && m.DownFnNoTxContext != nil { + return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext") + } + // Do not allow legacy functions to be set. + if m.UpFn != nil { + return errors.New("must not specify UpFn") + } + if m.DownFn != nil { + return errors.New("must not specify DownFn") + } + if m.UpFnNoTx != nil { + return errors.New("must not specify UpFnNoTx") + } + if m.DownFnNoTx != nil { + return errors.New("must not specify DownFnNoTx") + } + return nil +} diff --git a/globals_test.go b/globals_test.go new file mode 100644 index 000000000..03febfaf4 --- /dev/null +++ b/globals_test.go @@ -0,0 +1,113 @@ +package goose_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/check" +) + +func TestGlobalRegister(t *testing.T) { + // Avoid polluting other tests and do not run in parallel. + t.Cleanup(func() { + goose.ResetGlobalMigrations() + }) + fnNoTx := func(context.Context, *sql.DB) error { return nil } + fn := func(context.Context, *sql.Tx) error { return nil } + + // Success. + err := goose.SetGlobalMigrations( + []goose.Migration{}..., + ) + check.NoError(t, err) + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo, UpFnContext: fn}, + ) + check.NoError(t, err) + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "go migration with version 1 already registered") + err = goose.SetGlobalMigrations( + goose.Migration{ + Registered: true, + Version: 2, + Source: "00002_foo.sql", + Type: goose.TypeGo, + UpFnContext: func(context.Context, *sql.Tx) error { return nil }, + DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil }, + }, + ) + check.NoError(t, err) + // Reset. + { + goose.ResetGlobalMigrations() + } + // Failure. + err = goose.SetGlobalMigrations( + goose.Migration{}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "invalid go migration: must be registered") + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), `invalid go migration: type must be "go"`) + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 1, Type: goose.TypeSQL}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), `invalid go migration: type must be "go"`) + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 0, Type: goose.TypeGo}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "invalid go migration: version must be greater than zero") + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 1, Source: "2_foo.sql", Type: goose.TypeGo}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), `invalid go migration: numeric component [1] in go migration does not match version in source "2_foo.sql"`) + // Legacy functions. + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 1, UpFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "invalid go migration: must not specify UpFn") + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 1, DownFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "invalid go migration: must not specify DownFn") + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 1, UpFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "invalid go migration: must not specify UpFnNoTx") + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 1, DownFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "invalid go migration: must not specify DownFnNoTx") + // Context-aware functions. + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 1, UpFnContext: fn, UpFnNoTxContext: fnNoTx, Type: goose.TypeGo}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of UpFnContext or UpFnNoTxContext") + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 1, DownFnContext: fn, DownFnNoTxContext: fnNoTx, Type: goose.TypeGo}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of DownFnContext or DownFnNoTxContext") + // Source and version mismatch. + err = goose.SetGlobalMigrations( + goose.Migration{Registered: true, Version: 1, Source: "invalid_numeric.sql", Type: goose.TypeGo}, + ) + check.HasError(t, err) + check.Contains(t, err.Error(), `invalid go migration: failed to parse version from migration file: invalid_numeric.sql`) +} diff --git a/migrate.go b/migrate.go index dc9a1a01e..599810ce8 100644 --- a/migrate.go +++ b/migrate.go @@ -23,8 +23,6 @@ var ( ErrNoNextVersion = errors.New("no next version found") // MaxVersion is the maximum allowed version. MaxVersion int64 = math.MaxInt64 - - registeredGoMigrations = map[int64]*Migration{} ) // Migrations slice. diff --git a/migration.go b/migration.go index 619e934d0..d81e5893e 100644 --- a/migration.go +++ b/migration.go @@ -20,25 +20,27 @@ type MigrationRecord struct { IsApplied bool // was this a result of up() or down() } -// Migration struct. +// Migration struct represents either a SQL or Go migration. type Migration struct { - Version int64 - Next int64 // next version, or -1 if none - Previous int64 // previous version, -1 if none - Source string // path to .sql script or go file - Registered bool - UseTx bool - - // These are deprecated and will be removed in the future. - // For backwards compatibility we still save the non-context versions in the struct in case someone is using them. - // Goose does not use these internally anymore and instead uses the context versions. + Type MigrationType + Version int64 + Source string // path to .sql script or .go file + Registered bool + UpFnContext, DownFnContext GoMigrationContext + UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext + + // These fields will be removed in a future major version. They are here for backwards + // compatibility and are an implementation detail. + UseTx bool + Next int64 // next version, or -1 if none + Previous int64 // previous version, -1 if none + + // We still save the non-context versions in the struct in case someone is using them. Goose + // does not use these internally anymore in favor of the context-aware versions. UpFn, DownFn GoMigration UpFnNoTx, DownFnNoTx GoMigrationNoTx - // New functions with context - UpFnContext, DownFnContext GoMigrationContext - UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext - noVersioning bool + noVersioning bool } func (m *Migration) String() string { @@ -233,7 +235,7 @@ func NumericComponent(filename string) (int64, error) { } n, err := strconv.ParseInt(base[:idx], 10, 64) if err != nil { - return 0, err + return 0, fmt.Errorf("failed to parse version from migration file: %s: %w", base, err) } if n < 1 { return 0, errors.New("migration version must be greater than zero") diff --git a/types.go b/types.go new file mode 100644 index 000000000..f0008d09e --- /dev/null +++ b/types.go @@ -0,0 +1,17 @@ +package goose + +// MigrationType is the type of migration. +type MigrationType string + +const ( + TypeGo MigrationType = "go" + TypeSQL MigrationType = "sql" +) + +func (t MigrationType) String() string { + // This should never happen. + if t == "" { + return "unknown migration type" + } + return string(t) +}