From 60fb891567d0363dbd71ec9659e108af94d3762d Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Fri, 25 Aug 2023 20:11:15 -0400 Subject: [PATCH 1/6] fix: collect go migrations --- create.go | 2 +- fix.go | 2 +- migrate.go | 164 ++++++++++++++++++++++++++++++++---------------- migrate_test.go | 122 +++++++++++++++++++++++++++++++++++ 4 files changed, 235 insertions(+), 55 deletions(-) diff --git a/create.go b/create.go index d9ec002f3..db47a2e4c 100644 --- a/create.go +++ b/create.go @@ -30,7 +30,7 @@ func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, m if sequential { // always use DirFS here because it's modifying operation - migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion) + migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion, registeredGoMigrations) if err != nil && !errors.Is(err, ErrNoMigrationFiles) { return err } diff --git a/fix.go b/fix.go index 7bc7ed5d6..a498358ee 100644 --- a/fix.go +++ b/fix.go @@ -11,7 +11,7 @@ const seqVersionTemplate = "%05v" func Fix(dir string) error { // always use osFS here because it's modifying operation - migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion) + migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion, registeredGoMigrations) if err != nil { return err } diff --git a/migrate.go b/migrate.go index b7130a4a0..b890dc0ae 100644 --- a/migrate.go +++ b/migrate.go @@ -236,17 +236,19 @@ func register( return nil } -func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Migrations, error) { +func collectMigrationsFS( + fsys fs.FS, + dirpath string, + current, target int64, + registered map[int64]*Migration, +) (Migrations, error) { if _, err := fs.Stat(fsys, dirpath); err != nil { if errors.Is(err, fs.ErrNotExist) { return nil, fmt.Errorf("%s directory does not exist", dirpath) } - return nil, err } - var migrations Migrations - // SQL migration files. sqlMigrationFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.sql")) if err != nil { @@ -258,68 +260,30 @@ func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Mig return nil, fmt.Errorf("could not parse SQL migration file %q: %w", file, err) } if versionFilter(v, current, target) { - migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file} - migrations = append(migrations, migration) + migrations = append(migrations, &Migration{ + Version: v, + Next: -1, + Previous: -1, + Source: file, + }) } } - - // Go migration files - fsGoMigrations := map[int64]*Migration{} - goMigrationFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.go")) + // Go migration files. + goMigrations, err := collectGoMigrations(fsys, dirpath, registered, current, target) if err != nil { return nil, err } - for _, file := range goMigrationFiles { - v, err := NumericComponent(file) - if err != nil { - continue // Skip any files that don't have version prefix. - } - - if strings.HasSuffix(file, "_test.go") { - continue // Skip Go test files. - } - - if versionFilter(v, current, target) { - migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file, Registered: false} - fsGoMigrations[v] = migration - } - } - - // Go migrations registered via goose.AddMigration(). - for _, migration := range registeredGoMigrations { - v, err := NumericComponent(migration.Source) - if err != nil { - return nil, fmt.Errorf("could not parse go migration file %q: %w", migration.Source, err) - } - if !versionFilter(v, current, target) { - continue - } - if _, ok := fsGoMigrations[v]; ok { - migrations = append(migrations, migration) - } - } - - for _, fsMigration := range fsGoMigrations { - // Skip migrations already existing migrations registered via goose.AddMigration(). - if _, ok := registeredGoMigrations[fsMigration.Version]; ok { - continue - } - migrations = append(migrations, fsMigration) - } - + migrations = append(migrations, goMigrations...) if len(migrations) == 0 { return nil, ErrNoMigrationFiles } - - migrations = sortAndConnectMigrations(migrations) - - return migrations, nil + return sortAndConnectMigrations(migrations), nil } // CollectMigrations returns all the valid looking migration scripts in the // migrations folder and go func registry, and key them by version. func CollectMigrations(dirpath string, current, target int64) (Migrations, error) { - return collectMigrationsFS(baseFS, dirpath, current, target) + return collectMigrationsFS(baseFS, dirpath, current, target, registeredGoMigrations) } func sortAndConnectMigrations(migrations Migrations) Migrations { @@ -451,3 +415,97 @@ func withoutContext[T any](fn func(context.Context, T) error) func(T) error { return fn(context.Background(), t) } } + +// collectGoMigrations collects Go migrations from the filesystem and merges them with registered +// migrations. +// +// If Go migrations have been registered globally, with [goose.AddNamedMigration...], but there are +// no corresponding .go files in the filesystem, add them to the migrations slice. +// +// If Go migrations have been registered, and there are .go files in the filesystem dirpath, ONLY +// include those in the migrations slices. +// +// Lastly, if there are .go files in the filesystem but they have not been registered, raise an +// error. This is to prevent users from accidentally adding valid looking Go files to the migrations +// folder without registering them. +func collectGoMigrations( + fsys fs.FS, + dirpath string, + registeredGoMigrations map[int64]*Migration, + current, target int64, +) (Migrations, error) { + // Sanity check registered migrations have the correct version prefix. + for _, m := range registeredGoMigrations { + if _, err := NumericComponent(m.Source); err != nil { + return nil, fmt.Errorf("could not parse go migration file %s: %w", m.Source, err) + } + } + goFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.go")) + if err != nil { + return nil, err + } + // If there are no Go files in the filesystem and no registered Go migrations, return early. + if len(goFiles) == 0 && len(registeredGoMigrations) == 0 { + return nil, nil + } + type source struct { + fullpath string + version int64 + } + // Find all Go files that have a version prefix and are within the requested range. + var sources []source + for _, fullpath := range goFiles { + v, err := NumericComponent(fullpath) + if err != nil { + continue // Skip any files that don't have version prefix. + } + if strings.HasSuffix(fullpath, "_test.go") { + continue // Skip Go test files. + } + if versionFilter(v, current, target) { + sources = append(sources, source{ + fullpath: fullpath, + version: v, + }) + } + } + var ( + migrations Migrations + ) + if len(sources) > 0 { + for _, s := range sources { + migration, ok := registeredGoMigrations[s.version] + if ok { + migrations = append(migrations, migration) + } else { + // TODO(mf): something that bothers me about this implementation is it will be + // lazily evaluated and the error will only be raised if the user tries to run the + // migration. It would be better to raise an error much earlier in the process. + migrations = append(migrations, &Migration{ + Version: s.version, + Next: -1, + Previous: -1, + Source: s.fullpath, + Registered: false, + }) + } + } + } else { + // Some users may register Go migrations manually via AddNamedMigration_ functions but not + // provide the corresponding .go files in the filesystem. In this case, we include them + // wholesale in the migrations slice. + // + // This is a valid use case because users may want to build a custom binary that only embeds + // the SQL migration files and some other mechanism for registering Go migrations. + for _, migration := range registeredGoMigrations { + v, err := NumericComponent(migration.Source) + if err != nil { + return nil, fmt.Errorf("could not parse go migration file %s: %w", migration.Source, err) + } + if versionFilter(v, current, target) { + migrations = append(migrations, migration) + } + } + } + return migrations, nil +} diff --git a/migrate_test.go b/migrate_test.go index ac30b4873..eaa3acaa6 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -1,7 +1,12 @@ package goose import ( + "math" + "os" + "path/filepath" "testing" + + "github.com/pressly/goose/v3/internal/check" ) func TestMigrationSort(t *testing.T) { @@ -56,3 +61,120 @@ func validateMigrationSort(t *testing.T, ms Migrations, sorted []int64) { t.Log(ms) } + +func TestCollectMigrations(t *testing.T) { + // Not safe to run in parallel + t.Run("no_migration_files_found", func(t *testing.T) { + tmp := t.TempDir() + err := os.MkdirAll(filepath.Join(tmp, "migrations-test"), 0755) + check.NoError(t, err) + _, err = collectMigrationsFS(os.DirFS(tmp), "migrations-test", 0, math.MaxInt64, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "no migration files found") + }) + t.Run("filesystem_registered_with_single_dirpath", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + file1, file2 := "09081_a.go", "09082_b.go" + AddNamedMigrationContext(file1, nil, nil) + AddNamedMigrationContext(file2, nil, nil) + check.Number(t, len(registeredGoMigrations), 2) + tmp := t.TempDir() + dir := filepath.Join(tmp, "migrations", "dir1") + err := os.MkdirAll(dir, 0755) + check.NoError(t, err) + createEmptyFile(t, dir, file1) + createEmptyFile(t, dir, file2) + fsys := os.DirFS(tmp) + all, err := collectMigrationsFS(fsys, "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 2) + check.Number(t, all[0].Version, 9081) + check.Number(t, all[1].Version, 9082) + }) + t.Run("filesystem_registered_with_multiple_dirpath", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + file1, file2, file3 := "00001_a.go", "00002_b.go", "01111_c.go" + AddNamedMigrationContext(file1, nil, nil) + AddNamedMigrationContext(file2, nil, nil) + AddNamedMigrationContext(file3, nil, nil) + check.Number(t, len(registeredGoMigrations), 3) + tmp := t.TempDir() + dir1 := filepath.Join(tmp, "migrations", "dir1") + dir2 := filepath.Join(tmp, "migrations", "dir2") + err := os.MkdirAll(dir1, 0755) + check.NoError(t, err) + err = os.MkdirAll(dir2, 0755) + check.NoError(t, err) + createEmptyFile(t, dir1, file1) + createEmptyFile(t, dir1, file2) + createEmptyFile(t, dir2, file3) + fsys := os.DirFS(tmp) + // Validate if dirpath 1 is specified we get the two Go migrations in migrations/dir1 folder + // even though 3 Go migrations have been registered. + { + all, err := collectMigrationsFS(fsys, "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 2) + check.Number(t, all[0].Version, 1) + check.Number(t, all[1].Version, 2) + } + // Validate if dirpath 2 is specified we only get the one Go migration in migrations/dir2 folder + // even though 3 Go migrations have been registered. + { + all, err := collectMigrationsFS(fsys, "migrations/dir2", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 1) + check.Number(t, all[0].Version, 1111) + } + }) + t.Run("empty_filesystem_registered_manually", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + AddNamedMigrationContext("00101_a.go", nil, nil) + AddNamedMigrationContext("00102_b.go", nil, nil) + check.Number(t, len(registeredGoMigrations), 2) + tmp := t.TempDir() + err := os.MkdirAll(filepath.Join(tmp, "migrations"), 0755) + check.NoError(t, err) + all, err := collectMigrationsFS(os.DirFS(tmp), "migrations", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 2) + check.Number(t, all[0].Version, 101) + check.Number(t, all[1].Version, 102) + }) + t.Run("unregistered_go_migrations", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + file1, file2, file3 := "00001_a.go", "00998_b.go", "00999_c.go" + // Only register file1 and file3, somehow user forgot to init in the + // valid looking file2 Go migration + AddNamedMigrationContext(file1, nil, nil) + AddNamedMigrationContext(file3, nil, nil) + check.Number(t, len(registeredGoMigrations), 2) + tmp := t.TempDir() + dir1 := filepath.Join(tmp, "migrations", "dir1") + err := os.MkdirAll(dir1, 0755) + check.NoError(t, err) + // Include the valid file2 with file1, file3. But remember, it has NOT been + // registered. + createEmptyFile(t, dir1, file1) + createEmptyFile(t, dir1, file2) + createEmptyFile(t, dir1, file3) + _, err = collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) + check.HasError(t, err) + check.Contains(t, err.Error(), "detected 1 unregistered Go file") + check.Contains(t, err.Error(), "migrations/dir1/00998_b.go") + check.Contains(t, err.Error(), "go functions must be registered and built into a custom binary") + }) +} + +func createEmptyFile(t *testing.T, dir, name string) { + path := filepath.Join(dir, name) + f, err := os.Create(path) + check.NoError(t, err) + defer f.Close() +} + +func clearMap(m map[int64]*Migration) { + for k := range m { + delete(m, k) + } +} From 54669b01cd54766f8c5f07f64e64173d3a9257c7 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Fri, 25 Aug 2023 20:25:31 -0400 Subject: [PATCH 2/6] update unregistered test --- migrate_test.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/migrate_test.go b/migrate_test.go index eaa3acaa6..9270832ea 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -158,11 +158,15 @@ func TestCollectMigrations(t *testing.T) { createEmptyFile(t, dir1, file1) createEmptyFile(t, dir1, file2) createEmptyFile(t, dir1, file3) - _, err = collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) - check.HasError(t, err) - check.Contains(t, err.Error(), "detected 1 unregistered Go file") - check.Contains(t, err.Error(), "migrations/dir1/00998_b.go") - check.Contains(t, err.Error(), "go functions must be registered and built into a custom binary") + all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 3) + check.Number(t, all[0].Version, 1) + check.Bool(t, all[0].Registered, true) + check.Number(t, all[1].Version, 998) + check.Bool(t, all[1].Registered, false) + check.Number(t, all[2].Version, 999) + check.Bool(t, all[2].Registered, true) }) } From d57d3ad386798dd35e0e99f9c1e10e8b16251c60 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Fri, 25 Aug 2023 20:40:40 -0400 Subject: [PATCH 3/6] additional tests --- migrate_test.go | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/migrate_test.go b/migrate_test.go index 9270832ea..0e0f93d5c 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -1,6 +1,7 @@ package goose import ( + "io/fs" "math" "os" "path/filepath" @@ -75,6 +76,7 @@ func TestCollectMigrations(t *testing.T) { t.Run("filesystem_registered_with_single_dirpath", func(t *testing.T) { t.Cleanup(func() { clearMap(registeredGoMigrations) }) file1, file2 := "09081_a.go", "09082_b.go" + file3, file4 := "19081_a.go", "19082_b.go" AddNamedMigrationContext(file1, nil, nil) AddNamedMigrationContext(file2, nil, nil) check.Number(t, len(registeredGoMigrations), 2) @@ -84,12 +86,19 @@ func TestCollectMigrations(t *testing.T) { check.NoError(t, err) createEmptyFile(t, dir, file1) createEmptyFile(t, dir, file2) + createEmptyFile(t, dir, file3) + createEmptyFile(t, dir, file4) fsys := os.DirFS(tmp) + files, err := fs.ReadDir(fsys, "migrations/dir1") + check.NoError(t, err) + check.Number(t, len(files), 4) all, err := collectMigrationsFS(fsys, "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) check.NoError(t, err) - check.Number(t, len(all), 2) + check.Number(t, len(all), 4) check.Number(t, all[0].Version, 9081) check.Number(t, all[1].Version, 9082) + check.Number(t, all[2].Version, 19081) + check.Number(t, all[3].Version, 19082) }) t.Run("filesystem_registered_with_multiple_dirpath", func(t *testing.T) { t.Cleanup(func() { clearMap(registeredGoMigrations) }) @@ -168,6 +177,27 @@ func TestCollectMigrations(t *testing.T) { check.Number(t, all[2].Version, 999) check.Bool(t, all[2].Registered, true) }) + t.Run("with_skipped_go_files", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + file1, file2, file3, file4 := "00001_a.go", "00002_b.sql", "00999_c_test.go", "embed.go" + AddNamedMigrationContext(file1, nil, nil) + check.Number(t, len(registeredGoMigrations), 1) + tmp := t.TempDir() + dir1 := filepath.Join(tmp, "migrations", "dir1") + err := os.MkdirAll(dir1, 0755) + check.NoError(t, err) + createEmptyFile(t, dir1, file1) + createEmptyFile(t, dir1, file2) + createEmptyFile(t, dir1, file3) + createEmptyFile(t, dir1, file4) + all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 2) + check.Number(t, all[0].Version, 1) + check.Bool(t, all[0].Registered, true) + check.Number(t, all[1].Version, 2) + check.Bool(t, all[1].Registered, false) + }) } func createEmptyFile(t *testing.T, dir, name string) { From f64d26f5248f6e271c4900f7e81e1374c26da4e4 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Sat, 26 Aug 2023 12:19:49 -0400 Subject: [PATCH 4/6] add version filter tests --- migrate.go | 3 --- migrate_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/migrate.go b/migrate.go index b890dc0ae..dc9a1a01e 100644 --- a/migrate.go +++ b/migrate.go @@ -304,15 +304,12 @@ func sortAndConnectMigrations(migrations Migrations) Migrations { } func versionFilter(v, current, target int64) bool { - if target > current { return v > current && v <= target } - if target < current { return v <= current && v > target } - return false } diff --git a/migrate_test.go b/migrate_test.go index 0e0f93d5c..3196c60c0 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -173,6 +173,8 @@ func TestCollectMigrations(t *testing.T) { check.Number(t, all[0].Version, 1) check.Bool(t, all[0].Registered, true) check.Number(t, all[1].Version, 998) + // This migrations is marked unregistered and will lazily raise an error if/when this + // migration is run check.Bool(t, all[1].Registered, false) check.Number(t, all[2].Version, 999) check.Bool(t, all[2].Registered, true) @@ -200,6 +202,34 @@ func TestCollectMigrations(t *testing.T) { }) } +func TestVersionFilter(t *testing.T) { + tests := []struct { + v int64 + current int64 + target int64 + want bool + }{ + {2, 1, 3, true}, // v is within the range + {4, 1, 3, false}, // v is outside the range + {2, 3, 1, true}, // v is within the reversed range + {4, 3, 1, false}, // v is outside the reversed range + {3, 1, 3, true}, // v is equal to target + {1, 1, 3, false}, // v is equal to current, not within the range + {1, 3, 1, false}, // v is equal to current, not within the reversed range + // Always return false if current equal target + {1, 2, 2, false}, + {2, 2, 2, false}, + {3, 2, 2, false}, + } + for _, tc := range tests { + t.Run("", func(t *testing.T) { + got := versionFilter(tc.v, tc.current, tc.target) + if got != tc.want { + t.Errorf("versionFilter(%d, %d, %d) = %v, want %v", tc.v, tc.current, tc.target, got, tc.want) + } + }) + } +} func createEmptyFile(t *testing.T, dir, name string) { path := filepath.Join(dir, name) f, err := os.Create(path) From 0148a2797120ea707bf72f6a9a064a89c3783329 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Sat, 26 Aug 2023 12:25:17 -0400 Subject: [PATCH 5/6] safe to run parallel --- migrate_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/migrate_test.go b/migrate_test.go index 3196c60c0..8ebfa5bbc 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -203,6 +203,8 @@ func TestCollectMigrations(t *testing.T) { } func TestVersionFilter(t *testing.T) { + t.Parallel() + tests := []struct { v int64 current int64 From dd85eb5c678c6a6ade27fe84b8912b5846e5122b Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Sun, 27 Aug 2023 11:17:18 -0400 Subject: [PATCH 6/6] add test for current and target --- migrate_test.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/migrate_test.go b/migrate_test.go index 8ebfa5bbc..9158829ea 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -200,6 +200,25 @@ func TestCollectMigrations(t *testing.T) { check.Number(t, all[1].Version, 2) check.Bool(t, all[1].Registered, false) }) + t.Run("current_and_target", func(t *testing.T) { + t.Cleanup(func() { clearMap(registeredGoMigrations) }) + file1, file2, file3 := "01001_a.go", "01002_b.sql", "01003_c.go" + AddNamedMigrationContext(file1, nil, nil) + AddNamedMigrationContext(file3, nil, nil) + check.Number(t, len(registeredGoMigrations), 2) + tmp := t.TempDir() + dir1 := filepath.Join(tmp, "migrations", "dir1") + err := os.MkdirAll(dir1, 0755) + check.NoError(t, err) + createEmptyFile(t, dir1, file1) + createEmptyFile(t, dir1, file2) + createEmptyFile(t, dir1, file3) + all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 1001, 1003, registeredGoMigrations) + check.NoError(t, err) + check.Number(t, len(all), 2) + check.Number(t, all[0].Version, 1002) + check.Number(t, all[1].Version, 1003) + }) } func TestVersionFilter(t *testing.T) {