diff --git a/create.go b/create.go index d9ec002f3..db47a2e4c 100644 --- a/create.go +++ b/create.go @@ -30,7 +30,7 @@ func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, m if sequential { // always use DirFS here because it's modifying operation - migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion) + migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion, registeredGoMigrations) if err != nil && !errors.Is(err, ErrNoMigrationFiles) { return err } diff --git a/fix.go b/fix.go index 7bc7ed5d6..a498358ee 100644 --- a/fix.go +++ b/fix.go @@ -11,7 +11,7 @@ const seqVersionTemplate = "%05v" func Fix(dir string) error { // always use osFS here because it's modifying operation - migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion) + migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion, registeredGoMigrations) if err != nil { return err } diff --git a/goose_cli_test.go b/goose_cli_test.go new file mode 100644 index 000000000..020e48023 --- /dev/null +++ b/goose_cli_test.go @@ -0,0 +1,214 @@ +package goose_test + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "testing" + "time" + + "github.com/pressly/goose/v3/internal/check" + _ "modernc.org/sqlite" +) + +const ( + // gooseTestBinaryVersion is utilized in conjunction with a linker variable to set the version + // of a binary created solely for testing purposes. It is used to test the --version flag. + gooseTestBinaryVersion = "v0.0.0" +) + +func TestFullBinary(t *testing.T) { + t.Parallel() + cli := buildGooseCLI(t, false) + out, err := cli.run("--version") + check.NoError(t, err) + check.Equal(t, out, "goose version: "+gooseTestBinaryVersion+"\n") +} + +func TestLiteBinary(t *testing.T) { + t.Parallel() + cli := buildGooseCLI(t, true) + + t.Run("binary_version", func(t *testing.T) { + t.Parallel() + out, err := cli.run("--version") + check.NoError(t, err) + check.Equal(t, out, "goose version: "+gooseTestBinaryVersion+"\n") + }) + t.Run("default_binary", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + total := countSQLFiles(t, "testdata/migrations") + commands := []struct { + cmd string + out string + }{ + {"up", "goose: successfully migrated database to version: " + strconv.Itoa(total)}, + {"version", "goose: version " + strconv.Itoa(total)}, + {"down", "OK"}, + {"version", "goose: version " + strconv.Itoa(total-1)}, + {"status", ""}, + {"reset", "OK"}, + {"version", "goose: version 0"}, + } + for _, c := range commands { + out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), c.cmd) + check.NoError(t, err) + check.Contains(t, out, c.out) + } + }) + t.Run("gh_issue_532", func(t *testing.T) { + // https://github.com/pressly/goose/issues/532 + t.Parallel() + dir := t.TempDir() + total := countSQLFiles(t, "testdata/migrations") + _, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "up") + check.NoError(t, err) + out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "up") + check.NoError(t, err) + check.Contains(t, out, "goose: no migrations to run. current version: "+strconv.Itoa(total)) + out, err = cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "version") + check.NoError(t, err) + check.Contains(t, out, "goose: version "+strconv.Itoa(total)) + }) + t.Run("gh_issue_293", func(t *testing.T) { + // https://github.com/pressly/goose/issues/293 + t.Parallel() + dir := t.TempDir() + total := countSQLFiles(t, "testdata/migrations") + commands := []struct { + cmd string + out string + }{ + {"up", "goose: successfully migrated database to version: " + strconv.Itoa(total)}, + {"version", "goose: version " + strconv.Itoa(total)}, + {"down", "OK"}, + {"down", "OK"}, + {"version", "goose: version " + strconv.Itoa(total-2)}, + {"up", "goose: successfully migrated database to version: " + strconv.Itoa(total)}, + {"status", ""}, + } + for _, c := range commands { + out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), c.cmd) + check.NoError(t, err) + check.Contains(t, out, c.out) + } + }) + t.Run("gh_issue_336", func(t *testing.T) { + // https://github.com/pressly/goose/issues/336 + t.Parallel() + dir := t.TempDir() + _, err := cli.run("-dir="+dir, "sqlite3", filepath.Join(dir, "sql.db"), "up") + check.HasError(t, err) + check.Contains(t, err.Error(), "goose run: no migration files found") + }) + t.Run("create_and_fix", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + createEmptyFile(t, dir, "00001_alpha.sql") + createEmptyFile(t, dir, "00003_bravo.sql") + createEmptyFile(t, dir, "20230826163141_charlie.sql") + createEmptyFile(t, dir, "20230826163151_delta.go") + total, err := os.ReadDir(dir) + check.NoError(t, err) + check.Number(t, len(total), 4) + migrationFiles := []struct { + name string + fileType string + }{ + {"echo", "sql"}, + {"foxtrot", "go"}, + {"golf", ""}, + } + for i, f := range migrationFiles { + args := []string{"-dir=" + dir, "create", f.name} + if f.fileType != "" { + args = append(args, f.fileType) + } + out, err := cli.run(args...) + check.NoError(t, err) + check.Contains(t, out, "Created new file") + // ensure different timestamps, granularity is 1 second + if i < len(migrationFiles)-1 { + time.Sleep(1100 * time.Millisecond) + } + } + total, err = os.ReadDir(dir) + check.NoError(t, err) + check.Number(t, len(total), 7) + out, err := cli.run("-dir="+dir, "fix") + check.NoError(t, err) + check.Contains(t, out, "RENAMED") + files, err := os.ReadDir(dir) + check.NoError(t, err) + check.Number(t, len(files), 7) + expected := []string{ + "00001_alpha.sql", + "00003_bravo.sql", + "00004_charlie.sql", + "00005_delta.go", + "00006_echo.sql", + "00007_foxtrot.go", + "00008_golf.go", + } + for i, f := range files { + check.Equal(t, f.Name(), expected[i]) + } + }) +} + +type gooseBinary struct { + binaryPath string +} + +func (g gooseBinary) run(params ...string) (string, error) { + cmd := exec.Command(g.binaryPath, params...) + out, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("failed to run goose command: %v\nout: %v", err, string(out)) + } + return string(out), nil +} + +// buildGooseCLI builds goose test binary, which is used for testing goose CLI. It is built with all +// drivers enabled, unless lite is true, in which case all drivers are disabled except sqlite3 +func buildGooseCLI(t *testing.T, lite bool) gooseBinary { + binName := "goose-test" + dir := t.TempDir() + output := filepath.Join(dir, binName) + // usage: go build [-o output] [build flags] [packages] + args := []string{ + "build", + "-o", output, + "-ldflags=-s -w -X main.version=" + gooseTestBinaryVersion, + } + if lite { + args = append(args, "-tags=no_clickhouse no_mssql no_mysql no_vertica no_postgres") + } + args = append(args, "./cmd/goose") + build := exec.Command("go", args...) + out, err := build.CombinedOutput() + if err != nil { + t.Fatalf("failed to build %s binary: %v: %s", binName, err, string(out)) + } + return gooseBinary{ + binaryPath: output, + } +} + +func countSQLFiles(t *testing.T, dir string) int { + t.Helper() + files, err := filepath.Glob(filepath.Join(dir, "*.sql")) + check.NoError(t, err) + return len(files) +} + +func createEmptyFile(t *testing.T, dir, name string) { + t.Helper() + path := filepath.Join(dir, name) + f, err := os.Create(path) + check.NoError(t, err) + defer f.Close() +} diff --git a/goose_embed_test.go b/goose_embed_test.go new file mode 100644 index 000000000..a09f87992 --- /dev/null +++ b/goose_embed_test.go @@ -0,0 +1,62 @@ +package goose_test + +import ( + "database/sql" + "embed" + "io/fs" + "os" + "path/filepath" + "testing" + + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/check" + _ "modernc.org/sqlite" +) + +//go:embed testdata/migrations/*.sql +var embedMigrations embed.FS + +func TestEmbeddedMigrations(t *testing.T) { + dir := t.TempDir() + // not using t.Parallel here to avoid races + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + + db.SetMaxOpenConns(1) + + migrationFiles, err := fs.ReadDir(embedMigrations, "testdata/migrations") + check.NoError(t, err) + total := len(migrationFiles) + + // decouple from existing structure + fsys, err := fs.Sub(embedMigrations, "testdata/migrations") + check.NoError(t, err) + + goose.SetBaseFS(fsys) + t.Cleanup(func() { goose.SetBaseFS(nil) }) + check.NoError(t, goose.SetDialect("sqlite3")) + + t.Run("migration_cycle", func(t *testing.T) { + err := goose.Up(db, ".") + check.NoError(t, err) + ver, err := goose.GetDBVersion(db) + check.NoError(t, err) + check.Number(t, ver, total) + err = goose.Reset(db, ".") + check.NoError(t, err) + ver, err = goose.GetDBVersion(db) + check.NoError(t, err) + check.Number(t, ver, 0) + }) + t.Run("create_uses_os_fs", func(t *testing.T) { + dir := t.TempDir() + err := goose.Create(db, dir, "test", "sql") + check.NoError(t, err) + paths, _ := filepath.Glob(filepath.Join(dir, "*test.sql")) + check.NumberNotZero(t, len(paths)) + err = goose.Fix(dir) + check.NoError(t, err) + _, err = os.Stat(filepath.Join(dir, "00001_test.sql")) + check.NoError(t, err) + }) +} diff --git a/goose_test.go b/goose_test.go deleted file mode 100644 index 3952af9d4..000000000 --- a/goose_test.go +++ /dev/null @@ -1,330 +0,0 @@ -package goose - -import ( - "database/sql" - "embed" - "fmt" - "io/fs" - "os" - "os/exec" - "path/filepath" - "runtime" - "strings" - "testing" - - "github.com/pressly/goose/v3/internal/check" - _ "modernc.org/sqlite" -) - -const ( - binName = "goose-test" -) - -func TestMain(m *testing.M) { - if runtime.GOOS == "windows" { - log.Fatalf("this test is not supported on Windows") - } - dir, err := os.Getwd() - if err != nil { - log.Fatalf("%v", err) - } - args := []string{ - "build", - "-ldflags=-s -w", - // disable all drivers except sqlite3 - "-tags=no_clickhouse no_mssql no_mysql no_vertica no_postgres", - "-o", binName, - "./cmd/goose", - } - build := exec.Command("go", args...) - out, err := build.CombinedOutput() - if err != nil { - log.Fatalf("failed to build %s binary: %v: %s", binName, err, string(out)) - } - result := m.Run() - defer func() { os.Exit(result) }() - if err := os.Remove(filepath.Join(dir, binName)); err != nil { - log.Printf("failed to remove binary: %v", err) - } -} - -func TestDefaultBinary(t *testing.T) { - t.Parallel() - - commands := []string{ - "go build -o ./bin/goose ./cmd/goose", - "./bin/goose -dir=examples/sql-migrations sqlite3 sql.db up", - "./bin/goose -dir=examples/sql-migrations sqlite3 sql.db version", - "./bin/goose -dir=examples/sql-migrations sqlite3 sql.db down", - "./bin/goose -dir=examples/sql-migrations sqlite3 sql.db status", - "./bin/goose --version", - } - t.Cleanup(func() { - if err := os.Remove("./bin/goose"); err != nil { - t.Logf("failed to remove %s resources: %v", t.Name(), err) - } - if err := os.Remove("./sql.db"); err != nil { - t.Logf("failed to remove %s resources: %v", t.Name(), err) - } - }) - - for _, cmd := range commands { - args := strings.Split(cmd, " ") - command := args[0] - var params []string - if len(args) > 1 { - params = args[1:] - } - - cmd := exec.Command(command, params...) - cmd.Env = os.Environ() - out, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("%s:\n%v\n\n%s", err, cmd, out) - } - } -} - -func TestIssue532(t *testing.T) { - t.Parallel() - - migrationsDir := filepath.Join("examples", "sql-migrations") - count := countSQLFiles(t, migrationsDir) - check.Number(t, count, 3) - - tempDir := t.TempDir() - dirFlag := "--dir=" + migrationsDir - - tt := []struct { - command string - output string - }{ - {"up", ""}, - {"up", "goose: no migrations to run. current version: 3"}, - {"version", "goose: version 3"}, - } - for _, tc := range tt { - params := []string{dirFlag, "sqlite3", filepath.Join(tempDir, "sql.db"), tc.command} - got, err := runGoose(params...) - check.NoError(t, err) - if tc.output == "" { - continue - } - if !strings.Contains(strings.TrimSpace(got), tc.output) { - t.Logf("output mismatch for command: %q", tc.command) - t.Logf("got\n%s", strings.TrimSpace(got)) - t.Log("====") - t.Logf("want\n%s", tc.output) - t.FailNow() - } - } -} - -func TestIssue293(t *testing.T) { - t.Parallel() - // https://github.com/pressly/goose/issues/293 - commands := []string{ - "go build -o ./bin/goose293 ./cmd/goose", - "./bin/goose293 -dir=examples/sql-migrations sqlite3 issue_293.db up", - "./bin/goose293 -dir=examples/sql-migrations sqlite3 issue_293.db version", - "./bin/goose293 -dir=examples/sql-migrations sqlite3 issue_293.db down", - "./bin/goose293 -dir=examples/sql-migrations sqlite3 issue_293.db down", - "./bin/goose293 -dir=examples/sql-migrations sqlite3 issue_293.db version", - "./bin/goose293 -dir=examples/sql-migrations sqlite3 issue_293.db up", - "./bin/goose293 -dir=examples/sql-migrations sqlite3 issue_293.db status", - } - t.Cleanup(func() { - if err := os.Remove("./bin/goose293"); err != nil { - t.Logf("failed to remove %s resources: %v", t.Name(), err) - } - if err := os.Remove("./issue_293.db"); err != nil { - t.Logf("failed to remove %s resources: %v", t.Name(), err) - } - }) - for _, cmd := range commands { - args := strings.Split(cmd, " ") - command := args[0] - var params []string - if len(args) > 1 { - params = args[1:] - } - - cmd := exec.Command(command, params...) - cmd.Env = os.Environ() - out, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("%s:\n%v\n\n%s", err, cmd, out) - } - } -} - -func TestIssue336(t *testing.T) { - t.Parallel() - // error when no migrations are found - // https://github.com/pressly/goose/issues/336 - - tempDir := t.TempDir() - params := []string{"--dir=" + tempDir, "sqlite3", filepath.Join(tempDir, "sql.db"), "up"} - - _, err := runGoose(params...) - check.HasError(t, err) - check.Contains(t, err.Error(), "no migration files found") -} - -func TestLiteBinary(t *testing.T) { - t.Parallel() - - dir := t.TempDir() - t.Cleanup(func() { - if err := os.Remove("./bin/lite-goose"); err != nil { - t.Logf("failed to remove %s resources: %v", t.Name(), err) - } - }) - - // this has to be done outside of the loop - // since go only supports space separated tags list. - cmd := exec.Command("go", "build", "-tags='no_postgres no_mysql no_sqlite3'", "-o", "./bin/lite-goose", "./cmd/goose") - out, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("%s:\n%v\n\n%s", err, cmd, out) - } - - commands := []string{ - fmt.Sprintf("./bin/lite-goose -dir=%s create user_indices sql", dir), - fmt.Sprintf("./bin/lite-goose -dir=%s fix", dir), - } - - for _, cmd := range commands { - args := strings.Split(cmd, " ") - cmd := exec.Command(args[0], args[1:]...) - cmd.Env = os.Environ() - out, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("%s:\n%v\n\n%s", err, cmd, out) - } - } -} - -func TestCustomBinary(t *testing.T) { - t.Parallel() - - commands := []string{ - "go build -o ./bin/custom-goose ./examples/go-migrations", - "./bin/custom-goose -dir=examples/go-migrations sqlite3 go.db up", - "./bin/custom-goose -dir=examples/go-migrations sqlite3 go.db version", - "./bin/custom-goose -dir=examples/go-migrations sqlite3 go.db down", - "./bin/custom-goose -dir=examples/go-migrations sqlite3 go.db status", - } - t.Cleanup(func() { - if err := os.Remove("./go.db"); err != nil { - t.Logf("failed to remove %s resources: %v", t.Name(), err) - } - }) - - for _, cmd := range commands { - args := strings.Split(cmd, " ") - out, err := exec.Command(args[0], args[1:]...).CombinedOutput() - if err != nil { - t.Fatalf("%s:\n%v\n\n%s", err, cmd, out) - } - } -} - -//go:embed examples/sql-migrations/*.sql -var migrations embed.FS - -func TestEmbeddedMigrations(t *testing.T) { - // not using t.Parallel here to avoid races - db, err := sql.Open("sqlite", "sql_embed.db") - if err != nil { - t.Fatalf("Database open failed: %s", err) - } - t.Cleanup(func() { - if err := os.Remove("./sql_embed.db"); err != nil { - t.Logf("failed to remove %s resources: %v", t.Name(), err) - } - }) - - db.SetMaxOpenConns(1) - - // decouple from existing structure - fsys, err := fs.Sub(migrations, "examples/sql-migrations") - if err != nil { - t.Fatalf("SubFS make failed: %s", err) - } - - SetBaseFS(fsys) - check.NoError(t, SetDialect("sqlite3")) - t.Cleanup(func() { SetBaseFS(nil) }) - - t.Run("Migration cycle", func(t *testing.T) { - if err := Up(db, "."); err != nil { - t.Errorf("Failed to run 'up' migrations: %s", err) - } - - ver, err := GetDBVersion(db) - if err != nil { - t.Fatalf("Failed to get migrations version: %s", err) - } - - if ver != 3 { - t.Errorf("Expected version 3 after 'up', got %d", ver) - } - - if err := Reset(db, "."); err != nil { - t.Errorf("Failed to run 'down' migrations: %s", err) - } - - ver, err = GetDBVersion(db) - if err != nil { - t.Fatalf("Failed to get migrations version: %s", err) - } - - if ver != 0 { - t.Errorf("Expected version 0 after 'reset', got %d", ver) - } - }) - - t.Run("Create uses os fs", func(t *testing.T) { - tmpDir := t.TempDir() - - if err := Create(db, tmpDir, "test", "sql"); err != nil { - t.Errorf("Failed to create migration: %s", err) - } - - paths, _ := filepath.Glob(filepath.Join(tmpDir, "*test.sql")) - if len(paths) == 0 { - t.Errorf("Failed to find created migration") - } - - if err := Fix(tmpDir); err != nil { - t.Errorf("Failed to 'fix' migrations: %s", err) - } - - _, err = os.Stat(filepath.Join(tmpDir, "00001_test.sql")) - if err != nil { - t.Errorf("Failed to locate fixed migration: %s", err) - } - }) -} - -func runGoose(params ...string) (string, error) { - dir, err := os.Getwd() - if err != nil { - return "", err - } - cmdPath := filepath.Join(dir, binName) - cmd := exec.Command(cmdPath, params...) - out, err := cmd.CombinedOutput() - if err != nil { - return "", fmt.Errorf("%v\n%v", err, string(out)) - } - return string(out), nil -} - -func countSQLFiles(t *testing.T, dir string) int { - t.Helper() - files, err := filepath.Glob(filepath.Join(dir, "*.sql")) - check.NoError(t, err) - return len(files) -} diff --git a/internal/check/check.go b/internal/check/check.go index 96f242933..76dfac7d6 100644 --- a/internal/check/check.go +++ b/internal/check/check.go @@ -73,7 +73,7 @@ func Bool(t *testing.T, got, want bool) { func Contains(t *testing.T, got, want string) { t.Helper() if !strings.Contains(got, want) { - t.Errorf("failed to find substring %q in string value %q", got, want) + t.Errorf("failed to find substring:\n%s\n\nin string value:\n%s", got, want) } } diff --git a/migrate.go b/migrate.go index b7130a4a0..dc9a1a01e 100644 --- a/migrate.go +++ b/migrate.go @@ -236,17 +236,19 @@ func register( return nil } -func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Migrations, error) { +func collectMigrationsFS( + fsys fs.FS, + dirpath string, + current, target int64, + registered map[int64]*Migration, +) (Migrations, error) { if _, err := fs.Stat(fsys, dirpath); err != nil { if errors.Is(err, fs.ErrNotExist) { return nil, fmt.Errorf("%s directory does not exist", dirpath) } - return nil, err } - var migrations Migrations - // SQL migration files. sqlMigrationFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.sql")) if err != nil { @@ -258,68 +260,30 @@ func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Mig return nil, fmt.Errorf("could not parse SQL migration file %q: %w", file, err) } if versionFilter(v, current, target) { - migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file} - migrations = append(migrations, migration) + migrations = append(migrations, &Migration{ + Version: v, + Next: -1, + Previous: -1, + Source: file, + }) } } - - // Go migration files - fsGoMigrations := map[int64]*Migration{} - goMigrationFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.go")) + // Go migration files. + goMigrations, err := collectGoMigrations(fsys, dirpath, registered, current, target) if err != nil { return nil, err } - for _, file := range goMigrationFiles { - v, err := NumericComponent(file) - if err != nil { - continue // Skip any files that don't have version prefix. - } - - if strings.HasSuffix(file, "_test.go") { - continue // Skip Go test files. - } - - if versionFilter(v, current, target) { - migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file, Registered: false} - fsGoMigrations[v] = migration - } - } - - // Go migrations registered via goose.AddMigration(). - for _, migration := range registeredGoMigrations { - v, err := NumericComponent(migration.Source) - if err != nil { - return nil, fmt.Errorf("could not parse go migration file %q: %w", migration.Source, err) - } - if !versionFilter(v, current, target) { - continue - } - if _, ok := fsGoMigrations[v]; ok { - migrations = append(migrations, migration) - } - } - - for _, fsMigration := range fsGoMigrations { - // Skip migrations already existing migrations registered via goose.AddMigration(). - if _, ok := registeredGoMigrations[fsMigration.Version]; ok { - continue - } - migrations = append(migrations, fsMigration) - } - + migrations = append(migrations, goMigrations...) if len(migrations) == 0 { return nil, ErrNoMigrationFiles } - - migrations = sortAndConnectMigrations(migrations) - - return migrations, nil + return sortAndConnectMigrations(migrations), nil } // CollectMigrations returns all the valid looking migration scripts in the // migrations folder and go func registry, and key them by version. func CollectMigrations(dirpath string, current, target int64) (Migrations, error) { - return collectMigrationsFS(baseFS, dirpath, current, target) + return collectMigrationsFS(baseFS, dirpath, current, target, registeredGoMigrations) } func sortAndConnectMigrations(migrations Migrations) Migrations { @@ -340,15 +304,12 @@ func sortAndConnectMigrations(migrations Migrations) Migrations { } func versionFilter(v, current, target int64) bool { - if target > current { return v > current && v <= target } - if target < current { return v <= current && v > target } - return false } @@ -451,3 +412,97 @@ func withoutContext[T any](fn func(context.Context, T) error) func(T) error { return fn(context.Background(), t) } } + +// collectGoMigrations collects Go migrations from the filesystem and merges them with registered +// migrations. +// +// If Go migrations have been registered globally, with [goose.AddNamedMigration...], but there are +// no corresponding .go files in the filesystem, add them to the migrations slice. +// +// If Go migrations have been registered, and there are .go files in the filesystem dirpath, ONLY +// include those in the migrations slices. +// +// Lastly, if there are .go files in the filesystem but they have not been registered, raise an +// error. This is to prevent users from accidentally adding valid looking Go files to the migrations +// folder without registering them. +func collectGoMigrations( + fsys fs.FS, + dirpath string, + registeredGoMigrations map[int64]*Migration, + current, target int64, +) (Migrations, error) { + // Sanity check registered migrations have the correct version prefix. + for _, m := range registeredGoMigrations { + if _, err := NumericComponent(m.Source); err != nil { + return nil, fmt.Errorf("could not parse go migration file %s: %w", m.Source, err) + } + } + goFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.go")) + if err != nil { + return nil, err + } + // If there are no Go files in the filesystem and no registered Go migrations, return early. + if len(goFiles) == 0 && len(registeredGoMigrations) == 0 { + return nil, nil + } + type source struct { + fullpath string + version int64 + } + // Find all Go files that have a version prefix and are within the requested range. + var sources []source + for _, fullpath := range goFiles { + v, err := NumericComponent(fullpath) + if err != nil { + continue // Skip any files that don't have version prefix. + } + if strings.HasSuffix(fullpath, "_test.go") { + continue // Skip Go test files. + } + if versionFilter(v, current, target) { + sources = append(sources, source{ + fullpath: fullpath, + version: v, + }) + } + } + var ( + migrations Migrations + ) + if len(sources) > 0 { + for _, s := range sources { + migration, ok := registeredGoMigrations[s.version] + if ok { + migrations = append(migrations, migration) + } else { + // TODO(mf): something that bothers me about this implementation is it will be + // lazily evaluated and the error will only be raised if the user tries to run the + // migration. It would be better to raise an error much earlier in the process. + migrations = append(migrations, &Migration{ + Version: s.version, + Next: -1, + Previous: -1, + Source: s.fullpath, + Registered: false, + }) + } + } + } else { + // Some users may register Go migrations manually via AddNamedMigration_ functions but not + // provide the corresponding .go files in the filesystem. In this case, we include them + // wholesale in the migrations slice. + // + // This is a valid use case because users may want to build a custom binary that only embeds + // the SQL migration files and some other mechanism for registering Go migrations. + for _, migration := range registeredGoMigrations { + v, err := NumericComponent(migration.Source) + if err != nil { + return nil, fmt.Errorf("could not parse go migration file %s: %w", migration.Source, err) + } + if versionFilter(v, current, target) { + migrations = append(migrations, migration) + } + } + } + return migrations, nil +} diff --git a/migrate_test.go b/migrate_test.go index ac30b4873..9158829ea 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -1,7 +1,13 @@ package goose import ( + "io/fs" + "math" + "os" + "path/filepath" "testing" + + "github.com/pressly/goose/v3/internal/check" ) func TestMigrationSort(t *testing.T) { @@ -56,3 +62,204 @@ func validateMigrationSort(t *testing.T, ms Migrations, sorted []int64) { t.Log(ms) } + +func TestCollectMigrations(t *testing.T) { + // Not safe to run in parallel + t.Run("no_migration_files_found", func(t *testing.T) { + tmp := t.TempDir() + err := os.MkdirAll(filepath.Join(tmp, "migrations-test"), 0755) + check.NoError(t, err) + _, err = collectMigrationsFS(os.DirFS(tmp), "migrations-test", 0, math.MaxInt64, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "no migration files found") + }) + t.Run("filesystem_registered_with_single_dirpath", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + file1, file2 := "09081_a.go", "09082_b.go" + file3, file4 := "19081_a.go", "19082_b.go" + AddNamedMigrationContext(file1, nil, nil) + AddNamedMigrationContext(file2, nil, nil) + check.Number(t, len(registeredGoMigrations), 2) + tmp := t.TempDir() + dir := filepath.Join(tmp, "migrations", "dir1") + err := os.MkdirAll(dir, 0755) + check.NoError(t, err) + createEmptyFile(t, dir, file1) + createEmptyFile(t, dir, file2) + createEmptyFile(t, dir, file3) + createEmptyFile(t, dir, file4) + fsys := os.DirFS(tmp) + files, err := fs.ReadDir(fsys, "migrations/dir1") + check.NoError(t, err) + check.Number(t, len(files), 4) + all, err := collectMigrationsFS(fsys, "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 4) + check.Number(t, all[0].Version, 9081) + check.Number(t, all[1].Version, 9082) + check.Number(t, all[2].Version, 19081) + check.Number(t, all[3].Version, 19082) + }) + t.Run("filesystem_registered_with_multiple_dirpath", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + file1, file2, file3 := "00001_a.go", "00002_b.go", "01111_c.go" + AddNamedMigrationContext(file1, nil, nil) + AddNamedMigrationContext(file2, nil, nil) + AddNamedMigrationContext(file3, nil, nil) + check.Number(t, len(registeredGoMigrations), 3) + tmp := t.TempDir() + dir1 := filepath.Join(tmp, "migrations", "dir1") + dir2 := filepath.Join(tmp, "migrations", "dir2") + err := os.MkdirAll(dir1, 0755) + check.NoError(t, err) + err = os.MkdirAll(dir2, 0755) + check.NoError(t, err) + createEmptyFile(t, dir1, file1) + createEmptyFile(t, dir1, file2) + createEmptyFile(t, dir2, file3) + fsys := os.DirFS(tmp) + // Validate if dirpath 1 is specified we get the two Go migrations in migrations/dir1 folder + // even though 3 Go migrations have been registered. + { + all, err := collectMigrationsFS(fsys, "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 2) + check.Number(t, all[0].Version, 1) + check.Number(t, all[1].Version, 2) + } + // Validate if dirpath 2 is specified we only get the one Go migration in migrations/dir2 folder + // even though 3 Go migrations have been registered. + { + all, err := collectMigrationsFS(fsys, "migrations/dir2", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 1) + check.Number(t, all[0].Version, 1111) + } + }) + t.Run("empty_filesystem_registered_manually", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + AddNamedMigrationContext("00101_a.go", nil, nil) + AddNamedMigrationContext("00102_b.go", nil, nil) + check.Number(t, len(registeredGoMigrations), 2) + tmp := t.TempDir() + err := os.MkdirAll(filepath.Join(tmp, "migrations"), 0755) + check.NoError(t, err) + all, err := collectMigrationsFS(os.DirFS(tmp), "migrations", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 2) + check.Number(t, all[0].Version, 101) + check.Number(t, all[1].Version, 102) + }) + t.Run("unregistered_go_migrations", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + file1, file2, file3 := "00001_a.go", "00998_b.go", "00999_c.go" + // Only register file1 and file3, somehow user forgot to init in the + // valid looking file2 Go migration + AddNamedMigrationContext(file1, nil, nil) + AddNamedMigrationContext(file3, nil, nil) + check.Number(t, len(registeredGoMigrations), 2) + tmp := t.TempDir() + dir1 := filepath.Join(tmp, "migrations", "dir1") + err := os.MkdirAll(dir1, 0755) + check.NoError(t, err) + // Include the valid file2 with file1, file3. But remember, it has NOT been + // registered. + createEmptyFile(t, dir1, file1) + createEmptyFile(t, dir1, file2) + createEmptyFile(t, dir1, file3) + all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 3) + check.Number(t, all[0].Version, 1) + check.Bool(t, all[0].Registered, true) + check.Number(t, all[1].Version, 998) + // This migrations is marked unregistered and will lazily raise an error if/when this + // migration is run + check.Bool(t, all[1].Registered, false) + check.Number(t, all[2].Version, 999) + check.Bool(t, all[2].Registered, true) + }) + t.Run("with_skipped_go_files", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + file1, file2, file3, file4 := "00001_a.go", "00002_b.sql", "00999_c_test.go", "embed.go" + AddNamedMigrationContext(file1, nil, nil) + check.Number(t, len(registeredGoMigrations), 1) + tmp := t.TempDir() + dir1 := filepath.Join(tmp, "migrations", "dir1") + err := os.MkdirAll(dir1, 0755) + check.NoError(t, err) + createEmptyFile(t, dir1, file1) + createEmptyFile(t, dir1, file2) + createEmptyFile(t, dir1, file3) + createEmptyFile(t, dir1, file4) + all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 2) + check.Number(t, all[0].Version, 1) + check.Bool(t, all[0].Registered, true) + check.Number(t, all[1].Version, 2) + check.Bool(t, all[1].Registered, false) + }) + t.Run("current_and_target", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + file1, file2, file3 := "01001_a.go", "01002_b.sql", "01003_c.go" + AddNamedMigrationContext(file1, nil, nil) + AddNamedMigrationContext(file3, nil, nil) + check.Number(t, len(registeredGoMigrations), 2) + tmp := t.TempDir() + dir1 := filepath.Join(tmp, "migrations", "dir1") + err := os.MkdirAll(dir1, 0755) + check.NoError(t, err) + createEmptyFile(t, dir1, file1) + createEmptyFile(t, dir1, file2) + createEmptyFile(t, dir1, file3) + all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 1001, 1003, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 2) + check.Number(t, all[0].Version, 1002) + check.Number(t, all[1].Version, 1003) + }) +} + +func TestVersionFilter(t *testing.T) { + t.Parallel() + + tests := []struct { + v int64 + current int64 + target int64 + want bool + }{ + {2, 1, 3, true}, // v is within the range + {4, 1, 3, false}, // v is outside the range + {2, 3, 1, true}, // v is within the reversed range + {4, 3, 1, false}, // v is outside the reversed range + {3, 1, 3, true}, // v is equal to target + {1, 1, 3, false}, // v is equal to current, not within the range + {1, 3, 1, false}, // v is equal to current, not within the reversed range + // Always return false if current equal target + {1, 2, 2, false}, + {2, 2, 2, false}, + {3, 2, 2, false}, + } + for _, tc := range tests { + t.Run("", func(t *testing.T) { + got := versionFilter(tc.v, tc.current, tc.target) + if got != tc.want { + t.Errorf("versionFilter(%d, %d, %d) = %v, want %v", tc.v, tc.current, tc.target, got, tc.want) + } + }) + } +} +func createEmptyFile(t *testing.T, dir, name string) { + path := filepath.Join(dir, name) + f, err := os.Create(path) + check.NoError(t, err) + defer f.Close() +} + +func clearMap(m map[int64]*Migration) { + for k := range m { + delete(m, k) + } +} diff --git a/testdata/migrations/00001_users_table.sql b/testdata/migrations/00001_users_table.sql new file mode 100644 index 000000000..3d74ed4af --- /dev/null +++ b/testdata/migrations/00001_users_table.sql @@ -0,0 +1,10 @@ +-- +goose Up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL, + email TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- +goose Down +DROP TABLE users; diff --git a/testdata/migrations/00002_posts_table.sql b/testdata/migrations/00002_posts_table.sql new file mode 100644 index 000000000..25648ed42 --- /dev/null +++ b/testdata/migrations/00002_posts_table.sql @@ -0,0 +1,12 @@ +-- +goose Up +CREATE TABLE posts ( + id INTEGER PRIMARY KEY, + title TEXT NOT NULL, + content TEXT NOT NULL, + author_id INTEGER NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (author_id) REFERENCES users(id) +); + +-- +goose Down +DROP TABLE posts; diff --git a/testdata/migrations/00003_comments_table.sql b/testdata/migrations/00003_comments_table.sql new file mode 100644 index 000000000..319776395 --- /dev/null +++ b/testdata/migrations/00003_comments_table.sql @@ -0,0 +1,13 @@ +-- +goose Up +CREATE TABLE comments ( + id INTEGER PRIMARY KEY, + post_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (post_id) REFERENCES posts(id), + FOREIGN KEY (user_id) REFERENCES users(id) +); + +-- +goose Down +DROP TABLE comments; diff --git a/testdata/migrations/00004_insert_data.sql b/testdata/migrations/00004_insert_data.sql new file mode 100644 index 000000000..262dd04b2 --- /dev/null +++ b/testdata/migrations/00004_insert_data.sql @@ -0,0 +1,23 @@ +-- +goose Up +INSERT INTO users (id, username, email) +VALUES + (1, 'john_doe', 'john@example.com'), + (2, 'jane_smith', 'jane@example.com'), + (3, 'alice_wonderland', 'alice@example.com'); + +INSERT INTO posts (id, title, content, author_id) +VALUES + (1, 'Introduction to SQL', 'SQL is a powerful language for managing databases...', 1), + (2, 'Data Modeling Techniques', 'Choosing the right data model is crucial...', 2), + (3, 'Advanced Query Optimization', 'Optimizing queries can greatly improve...', 1); + +INSERT INTO comments (id, post_id, user_id, content) +VALUES + (1, 1, 3, 'Great introduction! Looking forward to more.'), + (2, 1, 2, 'SQL can be a bit tricky at first, but practice helps.'), + (3, 2, 1, 'You covered normalization really well in this post.'); + +-- +goose Down +DELETE FROM comments; +DELETE FROM posts; +DELETE FROM users; diff --git a/testdata/migrations/00005_posts_view.sql b/testdata/migrations/00005_posts_view.sql new file mode 100644 index 000000000..a3763c948 --- /dev/null +++ b/testdata/migrations/00005_posts_view.sql @@ -0,0 +1,15 @@ +-- +goose NO TRANSACTION + +-- +goose Up +CREATE VIEW posts_view AS + SELECT + p.id, + p.title, + p.content, + p.created_at, + u.username AS author + FROM posts p + JOIN users u ON p.author_id = u.id; + +-- +goose Down +DROP VIEW posts_view;