Skip to content

Commit

Permalink
Refactored code to have a Provider.
Browse files Browse the repository at this point in the history
All migration code is now in a Provider that provides migrations
functionality. This helps isolate various parts of the migrations
and prevents multiple migrations from stepping on top of each other.

This refactor is completely backward-compatible and does not break
the current Public API. This is accomplished by aliasing the Public
functions to a `defaultProvider` that is initialized to the default
settings as it is currently.

Those who don't wish to use the new functionality can do so without
any changes. To use the new system, one will need
to Initialize a new Provider via the `goose.NewProvider()` method which
will return a provider that will be initialized to the default settings,
unless otherwise overwritten.

Note: No new tests were added, and the tests were purposely left* alone
so to ensure that the API did not break.  * except for runMigrationSQL
as that is a private API.

Since behavior did not change, this means all the race conditions that
existed before still exists, just now isolated to the `defailtProvider`.

See:
* pressly#351
* pressly#114
* pressly#114 (comment)
  • Loading branch information
gdey committed Jun 9, 2022
1 parent 2681bbd commit 39df0eb
Show file tree
Hide file tree
Showing 18 changed files with 1,152 additions and 836 deletions.
84 changes: 14 additions & 70 deletions create.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,89 +2,29 @@ package goose

import (
"database/sql"
"fmt"
"os"
"path/filepath"
"text/template"
"time"

"github.com/pkg/errors"
)

type tmplVars struct {
Version string
CamelName string
Version string
CamelName string
PackageName string
ProviderVar string
}

var (
sequential = false
)

// SetSequential set whether to use sequential versioning instead of timestamp based versioning
func SetSequential(s bool) {
sequential = s
defaultProvider.SetSequential(s)
}

// Create writes a new blank migration file.
// CreateWithTemplate writes a new blank migration file.
func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, migrationType string) error {
var version string
if sequential {
// always use DirFS here because it's modifying operation
migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion)
if err != nil {
return err
}

vMigrations, err := migrations.versioned()
if err != nil {
return err
}

if last, err := vMigrations.Last(); err == nil {
version = fmt.Sprintf(seqVersionTemplate, last.Version+1)
} else {
version = fmt.Sprintf(seqVersionTemplate, int64(1))
}
} else {
version = time.Now().Format(timestampFormat)
}

filename := fmt.Sprintf("%v_%v.%v", version, snakeCase(name), migrationType)

if tmpl == nil {
if migrationType == "go" {
tmpl = goSQLMigrationTemplate
} else {
tmpl = sqlMigrationTemplate
}
}

path := filepath.Join(dir, filename)
if _, err := os.Stat(path); !os.IsNotExist(err) {
return errors.Wrap(err, "failed to create migration file")
}

f, err := os.Create(path)
if err != nil {
return errors.Wrap(err, "failed to create migration file")
}
defer f.Close()

vars := tmplVars{
Version: version,
CamelName: camelCase(name),
}
if err := tmpl.Execute(f, vars); err != nil {
return errors.Wrap(err, "failed to execute tmpl")
}

log.Printf("Created new file: %s\n", f.Name())
return nil
return defaultProvider.CreateWithTemplate(db, dir, tmpl, name, migrationType)
}

// Create writes a new blank migration file.
func Create(db *sql.DB, dir, name, migrationType string) error {
return CreateWithTemplate(db, dir, nil, name, migrationType)
return defaultProvider.Create(db, dir, name, migrationType)
}

var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(`-- +goose Up
Expand All @@ -98,15 +38,19 @@ SELECT 'down SQL query';
-- +goose StatementEnd
`))

var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package migrations
var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package {{.PackageName}}
import (
"database/sql"
"github.com/pressly/goose/v3"
{{if eq .ProviderVar ""}} "github.com/pressly/goose/v3" {{end}}
)
func init() {
{{- if eq .ProviderVar "" }}
goose.AddMigration(up{{.CamelName}}, down{{.CamelName}})
{{- else }}
{{.ProviderVar}}.AddMigration(up{{.CamelName}}, down{{.CamelName}})
{{end }}
}
func up{{.CamelName}}(tx *sql.Tx) error {
Expand Down
21 changes: 1 addition & 20 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,10 @@ package goose

import (
"database/sql"
"fmt"
)

// OpenDBWithDriver creates a connection to a database, and modifies goose
// internals to be compatible with the supplied driver by calling SetDialect.
func OpenDBWithDriver(driver string, dbstring string) (*sql.DB, error) {
if err := SetDialect(driver); err != nil {
return nil, err
}

switch driver {
case "mssql":
driver = "sqlserver"
case "redshift":
driver = "postgres"
case "tidb":
driver = "mysql"
}

switch driver {
case "postgres", "pgx", "sqlite3", "sqlite", "mysql", "sqlserver", "clickhouse":
return sql.Open(driver, dbstring)
default:
return nil, fmt.Errorf("unsupported driver %s", driver)
}
return defaultProvider.OpenDBWithDriver(driver, dbstring)
}
Loading

0 comments on commit 39df0eb

Please sign in to comment.