Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Expose setter for global Go registry #625

Merged
merged 2 commits into from
Oct 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions globals.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package goose

import (
"errors"
"fmt"
)

var (
registeredGoMigrations = make(map[int64]*Migration)
)

// ResetGlobalMigrations resets the global go migrations registry.
//
// Not safe for concurrent use.
func ResetGlobalMigrations() {
registeredGoMigrations = make(map[int64]*Migration)
}

// SetGlobalMigrations registers go migrations globally. It returns an error if a migration with the
// same version has already been registered.
//
// Source may be empty, but if it is set, it must be a path with a numeric component that matches
// the version. Do not register legacy non-context functions: UpFn, DownFn, UpFnNoTx, DownFnNoTx.
//
// Not safe for concurrent use.
func SetGlobalMigrations(migrations ...Migration) error {
for _, m := range migrations {
// make a copy of the migration so we can modify it without affecting the original.
if err := validGoMigration(&m); err != nil {
return fmt.Errorf("invalid go migration: %w", err)
}
if _, ok := registeredGoMigrations[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version)
}
m.Next, m.Previous = -1, -1 // Do not allow these to be set by the user.
registeredGoMigrations[m.Version] = &m
}
return nil
}

func validGoMigration(m *Migration) error {
if m == nil {
return errors.New("must not be nil")
}
if !m.Registered {
return errors.New("must be registered")
}
if m.Type != TypeGo {
return fmt.Errorf("type must be %q", TypeGo)
}
if m.Version < 1 {
return errors.New("version must be greater than zero")
}
if m.Source != "" {
// If the source is set, expect it to be a path with a numeric component that matches the
// version. This field is not intended to be used for descriptive purposes.
version, err := NumericComponent(m.Source)
if err != nil {
return err
}
if version != m.Version {
return fmt.Errorf("numeric component [%d] in go migration does not match version in source %q", m.Version, m.Source)
}
}
// It's valid for all of these funcs to be nil. Which means version the go migration but do not
// run anything.
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
}
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
}
// Do not allow legacy functions to be set.
if m.UpFn != nil {
return errors.New("must not specify UpFn")
}
if m.DownFn != nil {
return errors.New("must not specify DownFn")
}
if m.UpFnNoTx != nil {
return errors.New("must not specify UpFnNoTx")
}
if m.DownFnNoTx != nil {
return errors.New("must not specify DownFnNoTx")
}
return nil
}
113 changes: 113 additions & 0 deletions globals_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package goose_test

import (
"context"
"database/sql"
"testing"

"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/check"
)

func TestGlobalRegister(t *testing.T) {
// Avoid polluting other tests and do not run in parallel.
t.Cleanup(func() {
goose.ResetGlobalMigrations()
})
fnNoTx := func(context.Context, *sql.DB) error { return nil }
fn := func(context.Context, *sql.Tx) error { return nil }

// Success.
err := goose.SetGlobalMigrations(
[]goose.Migration{}...,
)
check.NoError(t, err)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo, UpFnContext: fn},
)
check.NoError(t, err)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "go migration with version 1 already registered")
err = goose.SetGlobalMigrations(
goose.Migration{
Registered: true,
Version: 2,
Source: "00002_foo.sql",
Type: goose.TypeGo,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
},
)
check.NoError(t, err)
// Reset.
{
goose.ResetGlobalMigrations()
}
// Failure.
err = goose.SetGlobalMigrations(
goose.Migration{},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must be registered")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: type must be "go"`)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Type: goose.TypeSQL},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: type must be "go"`)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 0, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: version must be greater than zero")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Source: "2_foo.sql", Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: numeric component [1] in go migration does not match version in source "2_foo.sql"`)
// Legacy functions.
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, UpFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify UpFn")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, DownFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify DownFn")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, UpFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify UpFnNoTx")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, DownFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify DownFnNoTx")
// Context-aware functions.
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, UpFnContext: fn, UpFnNoTxContext: fnNoTx, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of UpFnContext or UpFnNoTxContext")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, DownFnContext: fn, DownFnNoTxContext: fnNoTx, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of DownFnContext or DownFnNoTxContext")
// Source and version mismatch.
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Source: "invalid_numeric.sql", Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: failed to parse version from migration file: invalid_numeric.sql`)
}
2 changes: 0 additions & 2 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ var (
ErrNoNextVersion = errors.New("no next version found")
// MaxVersion is the maximum allowed version.
MaxVersion int64 = math.MaxInt64

registeredGoMigrations = map[int64]*Migration{}
)

// Migrations slice.
Expand Down
34 changes: 18 additions & 16 deletions migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,27 @@ type MigrationRecord struct {
IsApplied bool // was this a result of up() or down()
}

// Migration struct.
// Migration struct represents either a SQL or Go migration.
type Migration struct {
Version int64
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
Source string // path to .sql script or go file
Registered bool
UseTx bool

// These are deprecated and will be removed in the future.
// For backwards compatibility we still save the non-context versions in the struct in case someone is using them.
// Goose does not use these internally anymore and instead uses the context versions.
Type MigrationType
Version int64
Source string // path to .sql script or .go file
Registered bool
UpFnContext, DownFnContext GoMigrationContext
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext

// These fields will be removed in a future major version. They are here for backwards
// compatibility and are an implementation detail.
UseTx bool
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none

// We still save the non-context versions in the struct in case someone is using them. Goose
// does not use these internally anymore in favor of the context-aware versions.
UpFn, DownFn GoMigration
UpFnNoTx, DownFnNoTx GoMigrationNoTx

// New functions with context
UpFnContext, DownFnContext GoMigrationContext
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
noVersioning bool
noVersioning bool
}

func (m *Migration) String() string {
Expand Down Expand Up @@ -233,7 +235,7 @@ func NumericComponent(filename string) (int64, error) {
}
n, err := strconv.ParseInt(base[:idx], 10, 64)
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to parse version from migration file: %s: %w", base, err)
}
if n < 1 {
return 0, errors.New("migration version must be greater than zero")
Expand Down
17 changes: 17 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package goose

// MigrationType is the type of migration.
type MigrationType string

const (
TypeGo MigrationType = "go"
TypeSQL MigrationType = "sql"
)

func (t MigrationType) String() string {
// This should never happen.
if t == "" {
return "unknown migration type"
}
return string(t)
}
Loading