Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add goose provider #635

Merged
merged 13 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions database/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package database

import (
"context"
"database/sql"
"errors"
"fmt"

Expand Down Expand Up @@ -100,6 +101,9 @@ func (s *store) GetMigration(
&result.Timestamp,
&result.IsApplied,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version)
}
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
}
return &result, nil
Expand Down
10 changes: 8 additions & 2 deletions database/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@ package database

import (
"context"
"errors"
"time"
)

var (
// ErrVersionNotFound must be returned by [GetMigration] when a migration version is not found.
ErrVersionNotFound = errors.New("version not found")
)

// Store is an interface that defines methods for managing database migrations and versioning. By
// defining a Store interface, we can support multiple databases with consistent functionality.
//
Expand All @@ -24,8 +30,8 @@ type Store interface {
// Delete deletes a version id from the version table.
Delete(ctx context.Context, db DBTxConn, version int64) error

// GetMigration retrieves a single migration by version id. This method may return the raw sql
// error if the query fails so the caller can assert for errors such as [sql.ErrNoRows].
// GetMigration retrieves a single migration by version id. If the query succeeds, but the
// version is not found, this method must return [ErrVersionNotFound].
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)

// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
Expand Down
2 changes: 1 addition & 1 deletion database/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func testStore(
err = runConn(ctx, db, func(conn *sql.Conn) error {
_, err := store.GetMigration(ctx, conn, 0)
check.HasError(t, err)
check.Bool(t, errors.Is(err, sql.ErrNoRows), true)
check.Bool(t, errors.Is(err, database.ErrVersionNotFound), true)
return nil
})
check.NoError(t, err)
Expand Down
60 changes: 17 additions & 43 deletions globals.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,20 @@ func ResetGlobalMigrations() {
// [NewGoMigration] function.
//
// Not safe for concurrent use.
func SetGlobalMigrations(migrations ...Migration) error {
for _, migration := range migrations {
m := &migration
func SetGlobalMigrations(migrations ...*Migration) error {
for _, m := range migrations {
if _, ok := registeredGoMigrations[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version)
}
if err := checkMigration(m); err != nil {
if err := checkGoMigration(m); err != nil {
return fmt.Errorf("invalid go migration: %w", err)
}
registeredGoMigrations[m.Version] = m
}
return nil
}

func checkMigration(m *Migration) error {
func checkGoMigration(m *Migration) error {
if !m.construct {
return errors.New("must use NewGoMigration to construct migrations")
}
Expand All @@ -63,10 +62,10 @@ func checkMigration(m *Migration) error {
return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source)
}
}
if err := setGoFunc(m.goUp); err != nil {
if err := checkGoFunc(m.goUp); err != nil {
return fmt.Errorf("up function: %w", err)
}
if err := setGoFunc(m.goDown); err != nil {
if err := checkGoFunc(m.goDown); err != nil {
return fmt.Errorf("down function: %w", err)
}
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
Expand All @@ -84,47 +83,22 @@ func checkMigration(m *Migration) error {
return nil
}

func setGoFunc(f *GoFunc) error {
if f == nil {
f = &GoFunc{Mode: TransactionEnabled}
return nil
}
func checkGoFunc(f *GoFunc) error {
if f.RunTx != nil && f.RunDB != nil {
return errors.New("must specify exactly one of RunTx or RunDB")
}
if f.RunTx == nil && f.RunDB == nil {
switch f.Mode {
case 0:
// Default to TransactionEnabled ONLY if mode is not set explicitly.
f.Mode = TransactionEnabled
case TransactionEnabled, TransactionDisabled:
// No functions but mode is set. This is not an error. It means the user wants to record
// a version with the given mode but not run any functions.
default:
return fmt.Errorf("invalid mode: %d", f.Mode)
}
return nil
}
if f.RunDB != nil {
switch f.Mode {
case 0, TransactionDisabled:
f.Mode = TransactionDisabled
default:
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
}
switch f.Mode {
case TransactionEnabled, TransactionDisabled:
// No functions, but mode is set. This is not an error. It means the user wants to
// record a version with the given mode but not run any functions.
default:
return fmt.Errorf("invalid mode: %d", f.Mode)
}
if f.RunTx != nil {
switch f.Mode {
case 0, TransactionEnabled:
f.Mode = TransactionEnabled
default:
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
}
if f.RunDB != nil && f.Mode != TransactionDisabled {
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
}
// This is a defensive check. If the mode is still 0, it means we failed to infer the mode from
// the functions or return an error. This should never happen.
if f.Mode == 0 {
return errors.New("failed to infer transaction mode")
if f.RunTx != nil && f.Mode != TransactionEnabled {
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
}
return nil
}
63 changes: 41 additions & 22 deletions globals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,15 @@ func TestTransactionMode(t *testing.T) {
// reset so we can check the default is set
migration2.goUp.Mode, migration2.goDown.Mode = 0, 0
err = SetGlobalMigrations(migration2)
check.NoError(t, err)
check.Number(t, len(registeredGoMigrations), 2)
registered = registeredGoMigrations[2]
check.Bool(t, registered.goUp != nil, true)
check.Bool(t, registered.goDown != nil, true)
check.Equal(t, registered.goUp.Mode, TransactionEnabled)
check.Equal(t, registered.goDown.Mode, TransactionEnabled)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: up function: invalid mode: 0")

migration3 := NewGoMigration(3, nil, nil)
// reset so we can check the default is set
migration3.goDown.Mode = 0
err = SetGlobalMigrations(migration3)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: down function: invalid mode: 0")
})
t.Run("unknown_mode", func(t *testing.T) {
m := NewGoMigration(1, nil, nil)
Expand Down Expand Up @@ -192,7 +194,7 @@ func TestGlobalRegister(t *testing.T) {
runTx := func(context.Context, *sql.Tx) error { return nil }

// Success.
err := SetGlobalMigrations([]Migration{}...)
err := SetGlobalMigrations([]*Migration{}...)
check.NoError(t, err)
err = SetGlobalMigrations(
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
Expand All @@ -204,62 +206,79 @@ func TestGlobalRegister(t *testing.T) {
)
check.HasError(t, err)
check.Contains(t, err.Error(), "go migration with version 1 already registered")
err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo})
err = SetGlobalMigrations(&Migration{Registered: true, Version: 2, Type: TypeGo})
check.HasError(t, err)
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
}

func TestCheckMigration(t *testing.T) {
// Success.
err := checkGoMigration(NewGoMigration(1, nil, nil))
check.NoError(t, err)
// Failures.
err := checkMigration(&Migration{})
err = checkGoMigration(&Migration{})
check.HasError(t, err)
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
err = checkMigration(&Migration{construct: true})
err = checkGoMigration(&Migration{construct: true})
check.HasError(t, err)
check.Contains(t, err.Error(), "must be registered")
err = checkMigration(&Migration{construct: true, Registered: true})
err = checkGoMigration(&Migration{construct: true, Registered: true})
check.HasError(t, err)
check.Contains(t, err.Error(), `type must be "go"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
check.HasError(t, err)
check.Contains(t, err.Error(), "version must be greater than zero")
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{}, goDown: &GoFunc{}})
check.HasError(t, err)
check.Contains(t, err.Error(), "up function: invalid mode: 0")
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{}})
check.HasError(t, err)
check.Contains(t, err.Error(), "down function: invalid mode: 0")
// Success.
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}})
check.NoError(t, err)
// Failures.
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
check.HasError(t, err)
check.Contains(t, err.Error(), `source must have .go extension: "foo"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
check.HasError(t, err)
check.Contains(t, err.Error(), `no filename separator '_' found`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
check.HasError(t, err)
check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
check.HasError(t, err)
check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
UpFn: func(*sql.Tx) error { return nil },
UpFnNoTx: func(*sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
DownFn: func(*sql.Tx) error { return nil },
DownFnNoTx: func(*sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")
Expand Down
Loading
Loading