Skip to content

Commit

Permalink
Refactored code to have a Provider.
Browse files Browse the repository at this point in the history
All migration code is now in a Provider that provides migrations
functionality. This helps isolate various parts of the migrations
and prevents multiple migrations from stepping on top of each other.

This refactor is completely backward-compatible and does not break
the current Public API. This is accomplished by aliasing the Public
functions to a `defaultProvider` that is initialized to the default
settings as it is currently.

Those who don't wish to use the new functionality can do so without
any changes. To use the new system, one will need
to Initialize a new Provider via the `goose.NewProvider()` method which
will return a provider that will be initialized to the default settings,
unless otherwise overwritten.

Note: No new tests were added, and the tests were purposely left* alone
so to ensure that the API did not break.  * except for runMigrationSQL
as that is a private API.

Since behavior did not change, this means all the race conditions that
existed before still exists, just now isolated to the `defailtProvider`.

See:
* pressly#351
* pressly#114
* pressly#114 (comment)
  • Loading branch information
gdey committed Jun 25, 2022
1 parent c740b81 commit 92057b1
Show file tree
Hide file tree
Showing 19 changed files with 640 additions and 300 deletions.
60 changes: 41 additions & 19 deletions create.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,38 @@ import (
)

type tmplVars struct {
Version string
CamelName string
Version string
CamelName string
PackageName string
ProviderVar string
}

var (
sequential = false
)

// SetSequential set whether to use sequential versioning instead of timestamp based versioning
func SetSequential(s bool) {
sequential = s
defaultProvider.SetSequential(s)
}

// Create writes a new blank migration file.
// SetSequential set's whether to use sequential versioning instead of timestamp based versioning
func (p *Provider) SetSequential(s bool) { p.sequential = s }

// CreateWithTemplate writes a new blank migration file.
func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, migrationType string) error {
var version string
if sequential {
return defaultProvider.CreateWithTemplate(db, dir, tmpl, name, migrationType)
}

// CreateWithTemplate writes a new blank migration file.
func (p *Provider) CreateWithTemplate(_ *sql.DB, dir string, tmpl *template.Template, name, migrationType string) error {
timefn := p.timeFn
if p.timeFn == nil {
timefn = time.Now
}
version := timefn().Format(p.timestampFormat)
if p.baseDir != "" && (dir == "" || dir == ".") {
dir = p.baseDir
}
if p.sequential {
// always use DirFS here because it's modifying operation
migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion)
migrations, err := p.collectMigrationsFS(osFS{}, dir, minVersion, maxVersion)
if err != nil {
return err
}
Expand All @@ -43,8 +56,6 @@ func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, m
} else {
version = fmt.Sprintf(seqVersionTemplate, int64(1))
}
} else {
version = time.Now().Format(timestampFormat)
}

filename := fmt.Sprintf("%v_%v.%v", version, snakeCase(name), migrationType)
Expand All @@ -69,20 +80,27 @@ func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, m
defer f.Close()

vars := tmplVars{
Version: version,
CamelName: camelCase(name),
PackageName: p.packageName,
ProviderVar: p.providerVarName,
Version: version,
CamelName: camelCase(name),
}
if err := tmpl.Execute(f, vars); err != nil {
return fmt.Errorf("failed to execute tmpl: %w", err)
}

log.Printf("Created new file: %s\n", f.Name())
p.log.Printf("Created new file: %s\n", f.Name())
return nil
}

// Create writes a new blank migration file.
func Create(db *sql.DB, dir, name, migrationType string) error {
return CreateWithTemplate(db, dir, nil, name, migrationType)
return defaultProvider.Create(db, dir, name, migrationType)
}

// Create writes a new blank migration file.
func (p *Provider) Create(db *sql.DB, dir, name, migrationType string) error {
return p.CreateWithTemplate(db, dir, nil, name, migrationType)
}

var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(`-- +goose Up
Expand All @@ -96,15 +114,19 @@ SELECT 'down SQL query';
-- +goose StatementEnd
`))

var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package migrations
var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package {{.PackageName}}
import (
"database/sql"
"github.com/pressly/goose/v3"
{{if eq .ProviderVar ""}} "github.com/pressly/goose/v3" {{end}}
)
func init() {
{{- if eq .ProviderVar "" }}
goose.AddMigration(up{{.CamelName}}, down{{.CamelName}})
{{- else }}
{{.ProviderVar}}.AddMigration(up{{.CamelName}}, down{{.CamelName}})
{{end }}
}
func up{{.CamelName}}(tx *sql.Tx) error {
Expand Down
8 changes: 7 additions & 1 deletion db.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@ import (
// OpenDBWithDriver creates a connection to a database, and modifies goose
// internals to be compatible with the supplied driver by calling SetDialect.
func OpenDBWithDriver(driver string, dbstring string) (*sql.DB, error) {
if err := SetDialect(driver); err != nil {
return defaultProvider.OpenDBWithDriver(driver, dbstring)
}

// OpenDBWithDriver creates a connection to a database, and modifies goose
// internals to be compatible with the supplied driver by calling SetDialect.
func (p *Provider) OpenDBWithDriver(driver string, dbstring string) (*sql.DB, error) {
if err := p.SetDialect(driver); err != nil {
return nil, err
}

Expand Down
Loading

0 comments on commit 92057b1

Please sign in to comment.