From 77dec20af7cc943bcb193d329056cd90f46c4a67 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Sat, 16 Sep 2023 08:38:56 -0400 Subject: [PATCH] add sqladapter package with tests --- internal/sqladapter/sqladapter.go | 56 ++++++++++ internal/sqladapter/store.go | 111 ++++++++++++++++++++ internal/sqladapter/store_test.go | 167 ++++++++++++++++++++++++++++++ provider.go | 18 ++-- 4 files changed, 346 insertions(+), 6 deletions(-) create mode 100644 internal/sqladapter/sqladapter.go create mode 100644 internal/sqladapter/store.go create mode 100644 internal/sqladapter/store_test.go diff --git a/internal/sqladapter/sqladapter.go b/internal/sqladapter/sqladapter.go new file mode 100644 index 000000000..1ef1c42b7 --- /dev/null +++ b/internal/sqladapter/sqladapter.go @@ -0,0 +1,56 @@ +// Package sqladapter provides an interface for interacting with a SQL database. +// +// All supported database dialects must implement the Store interface. +package sqladapter + +import ( + "context" + "database/sql" + "time" +) + +// DBTxConn is an interface that is satisfied by *sql.DB, *sql.Tx and *sql.Conn. +// +// There is a long outstanding issue to formalize a std lib interface, but alas... +// See: https://github.com/golang/go/issues/14468 +type DBTxConn interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// Store is the interface that wraps the basic methods for a database dialect. +// +// A dialect is a set of SQL statements that are specific to a database. +// +// By defining a store interface, we can support multiple databases with a single codebase. +// +// The underlying implementation does not modify the error. It is the callers responsibility to +// assert for the correct error, such as sql.ErrNoRows. +type Store interface { + // CreateVersionTable creates the version table within a transaction. This table is used to + // record applied migrations. + CreateVersionTable(ctx context.Context, tx *sql.Tx, tablename string) error + + // InsertOrDelete inserts or deletes a version id from the version table. + InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error + + // GetMigration retrieves a single migration by version id. + // + // Returns the raw sql error if the query fails. It is the callers responsibility + // to assert for the correct error, such as sql.ErrNoRows. + GetMigrationConn(ctx context.Context, conn *sql.Conn, version int64) (*GetMigrationResult, error) + + // ListMigrations retrieves all migrations sorted in descending order by id. + // + // If there are no migrations, an empty slice is returned with no error. + ListMigrationsConn(ctx context.Context, conn *sql.Conn) ([]*ListMigrationsResult, error) +} + +type GetMigrationResult struct { + IsApplied bool + Timestamp time.Time +} + +type ListMigrationsResult struct { + Version int64 + IsApplied bool +} diff --git a/internal/sqladapter/store.go b/internal/sqladapter/store.go new file mode 100644 index 000000000..a61e5aa60 --- /dev/null +++ b/internal/sqladapter/store.go @@ -0,0 +1,111 @@ +package sqladapter + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/pressly/goose/v3/internal/dialect/dialectquery" +) + +var _ Store = (*store)(nil) + +type store struct { + tablename string + querier dialectquery.Querier +} + +// NewStore returns a new Store backed by the given dialect. +// +// The dialect must match one of the supported dialects defined in dialect.go. +func NewStore(dialect string, table string) (Store, error) { + if table == "" { + return nil, errors.New("table must not be empty") + } + if dialect == "" { + return nil, errors.New("dialect must not be empty") + } + var querier dialectquery.Querier + switch dialect { + case "clickhouse": + querier = &dialectquery.Clickhouse{} + case "mssql": + querier = &dialectquery.Sqlserver{} + case "mysql": + querier = &dialectquery.Mysql{} + case "postgres": + querier = &dialectquery.Postgres{} + case "redshift": + querier = &dialectquery.Redshift{} + case "sqlite3": + querier = &dialectquery.Sqlite3{} + case "tidb": + querier = &dialectquery.Tidb{} + case "vertica": + querier = &dialectquery.Vertica{} + default: + return nil, fmt.Errorf("unknown dialect: %q", dialect) + } + return &store{ + tablename: table, + querier: querier, + }, nil +} + +func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx, tablename string) error { + q := s.querier.CreateTable(s.tablename) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("failed to create version table %q: %w", tablename, err) + } + return nil +} + +func (s *store) InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error { + if direction { + q := s.querier.InsertVersion(s.tablename) + if _, err := db.ExecContext(ctx, q, version, true); err != nil { + return fmt.Errorf("failed to insert version %d: %w", version, err) + } + return nil + } + q := s.querier.DeleteVersion(s.tablename) + if _, err := db.ExecContext(ctx, q, version); err != nil { + return fmt.Errorf("failed to delete version %d: %w", version, err) + } + return nil +} + +func (s *store) GetMigrationConn(ctx context.Context, conn *sql.Conn, version int64) (*GetMigrationResult, error) { + q := s.querier.GetMigrationByVersion(s.tablename) + result := new(GetMigrationResult) + if err := conn.QueryRowContext(ctx, q, version).Scan( + &result.Timestamp, + &result.IsApplied, + ); err != nil { + return nil, fmt.Errorf("failed to get migration %d: %w", version, err) + } + return result, nil +} + +func (s *store) ListMigrationsConn(ctx context.Context, conn *sql.Conn) ([]*ListMigrationsResult, error) { + q := s.querier.ListMigrations(s.tablename) + rows, err := conn.QueryContext(ctx, q) + if err != nil { + return nil, fmt.Errorf("failed to list migrations: %w", err) + } + defer rows.Close() + + var migrations []*ListMigrationsResult + for rows.Next() { + result := new(ListMigrationsResult) + if err := rows.Scan(&result.Version, &result.IsApplied); err != nil { + return nil, fmt.Errorf("failed to scan list migrations result: %w", err) + } + migrations = append(migrations, result) + } + if err := rows.Err(); err != nil { + return nil, err + } + return migrations, nil +} diff --git a/internal/sqladapter/store_test.go b/internal/sqladapter/store_test.go new file mode 100644 index 000000000..8c69e405b --- /dev/null +++ b/internal/sqladapter/store_test.go @@ -0,0 +1,167 @@ +package sqladapter_test + +import ( + "context" + "database/sql" + "errors" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/internal/testdb" +) + +// The goal of this test is to verify the sqladapter package works as expected. This test is not +// meant to be exhaustive or test every possible database dialect. It is meant to verify that the +// Store interface works against a real database. + +func TestStore_Postgres(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("skip long-running test") + } + ctx := context.Background() + const ( + tablename = "test_goose_db_version" + ) + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + store, err := sqladapter.NewStore(string(goose.DialectPostgres), tablename) + check.NoError(t, err) + // Create the version table. + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.CreateVersionTable(ctx, tx, tablename) + }) + // Create the version table again. This should fail. + check.NoError(t, err) + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.CreateVersionTable(ctx, tx, tablename) + }) + check.HasError(t, err) + var pgErr *pgconn.PgError + ok := errors.As(err, &pgErr) + check.Bool(t, ok, true) + check.Equal(t, pgErr.Code, "42P07") // duplicate_table + // List migrations. There should be none. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrationsConn(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 0) + return nil + }) + check.NoError(t, err) + // Insert 5 migrations in addition to the zero migration. + for i := 0; i < 6; i++ { + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, true, int64(i)) + }) + check.NoError(t, err) + } + // List migrations. There should be 6. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrationsConn(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 6) + // Check versions are in descending order. + for i := 0; i < 6; i++ { + check.Number(t, res[i].Version, 5-i) + } + return nil + }) + check.NoError(t, err) + // Delete 3 migrations backwards + for i := 5; i >= 3; i-- { + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, false, int64(i)) + }) + check.NoError(t, err) + } + // List migrations. There should be 3. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrationsConn(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 3) + // Check that the remaining versions are in descending order. + for i := 0; i < 3; i++ { + check.Number(t, res[i].Version, 2-i) + } + return nil + }) + check.NoError(t, err) + // Get remaining migrations one by one. + for i := 0; i < 3; i++ { + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.GetMigrationConn(ctx, conn, int64(i)) + check.NoError(t, err) + check.Equal(t, res.IsApplied, true) + check.Equal(t, res.Timestamp.IsZero(), false) + return nil + }) + check.NoError(t, err) + } + // Delete remaining migrations one by one and use all 3 connection types: + // *sql.DB + // *sql.Tx + // *sql.Conn. + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.InsertOrDelete(ctx, tx, false, 2) // *sql.Tx + }) + check.NoError(t, err) + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, false, 1) // *sql.Conn + }) + check.NoError(t, err) + err = store.InsertOrDelete(ctx, db, false, 0) // *sql.DB + check.NoError(t, err) + // List migrations. There should be none. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrationsConn(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 0) + return nil + }) + check.NoError(t, err) + // Try to get a migration that does not exist. + err = runConn(ctx, db, func(conn *sql.Conn) error { + _, err := store.GetMigrationConn(ctx, conn, 0) + check.HasError(t, err) + check.Bool(t, errors.Is(err, sql.ErrNoRows), true) + return nil + }) + check.NoError(t, err) +} + +func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = errors.Join(retErr, tx.Rollback()) + } + }() + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func runConn(ctx context.Context, db *sql.DB, fn func(*sql.Conn) error) (retErr error) { + conn, err := db.Conn(ctx) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = errors.Join(retErr, conn.Close()) + } + }() + if err := fn(conn); err != nil { + return err + } + return conn.Close() +} diff --git a/provider.go b/provider.go index c8d30112f..4798895a0 100644 --- a/provider.go +++ b/provider.go @@ -5,12 +5,15 @@ import ( "database/sql" "errors" "time" + + "github.com/pressly/goose/v3/internal/sqladapter" ) // Provider is a goose migration provider. type Provider struct { - db *sql.DB - opt *ProviderOptions + db *sql.DB + store sqladapter.Store + opt *ProviderOptions } // NewProvider returns a new goose Provider. @@ -35,14 +38,17 @@ func NewProvider(dialect Dialect, db *sql.DB, opts *ProviderOptions) (*Provider, if err := validateOptions(opts); err != nil { return nil, err } - // + store, err := sqladapter.NewStore(string(dialect), opts.Tablename) + if err != nil { + return nil, err + } // TODO(mf): implement the rest of this function - // - db / dialect store // - collect sources // - merge sources into migrations return &Provider{ - db: db, - opt: opts, + db: db, + store: store, + opt: opts, }, nil }