-
Notifications
You must be signed in to change notification settings - Fork 517
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
346 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// Package sqladapter provides an interface for interacting with a SQL database. | ||
// | ||
// All supported database dialects must implement the Store interface. | ||
package sqladapter | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"time" | ||
) | ||
|
||
// DBTxConn is an interface that is satisfied by *sql.DB, *sql.Tx and *sql.Conn. | ||
// | ||
// There is a long outstanding issue to formalize a std lib interface, but alas... | ||
// See: https://github.com/golang/go/issues/14468 | ||
type DBTxConn interface { | ||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) | ||
} | ||
|
||
// Store is the interface that wraps the basic methods for a database dialect. | ||
// | ||
// A dialect is a set of SQL statements that are specific to a database. | ||
// | ||
// By defining a store interface, we can support multiple databases with a single codebase. | ||
// | ||
// The underlying implementation does not modify the error. It is the callers responsibility to | ||
// assert for the correct error, such as sql.ErrNoRows. | ||
type Store interface { | ||
// CreateVersionTable creates the version table within a transaction. This table is used to | ||
// record applied migrations. | ||
CreateVersionTable(ctx context.Context, tx *sql.Tx, tablename string) error | ||
|
||
// InsertOrDelete inserts or deletes a version id from the version table. | ||
InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error | ||
|
||
// GetMigration retrieves a single migration by version id. | ||
// | ||
// Returns the raw sql error if the query fails. It is the callers responsibility | ||
// to assert for the correct error, such as sql.ErrNoRows. | ||
GetMigrationConn(ctx context.Context, conn *sql.Conn, version int64) (*GetMigrationResult, error) | ||
|
||
// ListMigrations retrieves all migrations sorted in descending order by id. | ||
// | ||
// If there are no migrations, an empty slice is returned with no error. | ||
ListMigrationsConn(ctx context.Context, conn *sql.Conn) ([]*ListMigrationsResult, error) | ||
} | ||
|
||
type GetMigrationResult struct { | ||
IsApplied bool | ||
Timestamp time.Time | ||
} | ||
|
||
type ListMigrationsResult struct { | ||
Version int64 | ||
IsApplied bool | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
package sqladapter | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"errors" | ||
"fmt" | ||
|
||
"github.com/pressly/goose/v3/internal/dialect/dialectquery" | ||
) | ||
|
||
var _ Store = (*store)(nil) | ||
|
||
type store struct { | ||
tablename string | ||
querier dialectquery.Querier | ||
} | ||
|
||
// NewStore returns a new Store backed by the given dialect. | ||
// | ||
// The dialect must match one of the supported dialects defined in dialect.go. | ||
func NewStore(dialect string, table string) (Store, error) { | ||
if table == "" { | ||
return nil, errors.New("table must not be empty") | ||
} | ||
if dialect == "" { | ||
return nil, errors.New("dialect must not be empty") | ||
} | ||
var querier dialectquery.Querier | ||
switch dialect { | ||
case "clickhouse": | ||
querier = &dialectquery.Clickhouse{} | ||
case "mssql": | ||
querier = &dialectquery.Sqlserver{} | ||
case "mysql": | ||
querier = &dialectquery.Mysql{} | ||
case "postgres": | ||
querier = &dialectquery.Postgres{} | ||
case "redshift": | ||
querier = &dialectquery.Redshift{} | ||
case "sqlite3": | ||
querier = &dialectquery.Sqlite3{} | ||
case "tidb": | ||
querier = &dialectquery.Tidb{} | ||
case "vertica": | ||
querier = &dialectquery.Vertica{} | ||
default: | ||
return nil, fmt.Errorf("unknown dialect: %q", dialect) | ||
} | ||
return &store{ | ||
tablename: table, | ||
querier: querier, | ||
}, nil | ||
} | ||
|
||
func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx, tablename string) error { | ||
q := s.querier.CreateTable(s.tablename) | ||
if _, err := tx.ExecContext(ctx, q); err != nil { | ||
return fmt.Errorf("failed to create version table %q: %w", tablename, err) | ||
} | ||
return nil | ||
} | ||
|
||
func (s *store) InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error { | ||
if direction { | ||
q := s.querier.InsertVersion(s.tablename) | ||
if _, err := db.ExecContext(ctx, q, version, true); err != nil { | ||
return fmt.Errorf("failed to insert version %d: %w", version, err) | ||
} | ||
return nil | ||
} | ||
q := s.querier.DeleteVersion(s.tablename) | ||
if _, err := db.ExecContext(ctx, q, version); err != nil { | ||
return fmt.Errorf("failed to delete version %d: %w", version, err) | ||
} | ||
return nil | ||
} | ||
|
||
func (s *store) GetMigrationConn(ctx context.Context, conn *sql.Conn, version int64) (*GetMigrationResult, error) { | ||
q := s.querier.GetMigrationByVersion(s.tablename) | ||
result := new(GetMigrationResult) | ||
if err := conn.QueryRowContext(ctx, q, version).Scan( | ||
&result.Timestamp, | ||
&result.IsApplied, | ||
); err != nil { | ||
return nil, fmt.Errorf("failed to get migration %d: %w", version, err) | ||
} | ||
return result, nil | ||
} | ||
|
||
func (s *store) ListMigrationsConn(ctx context.Context, conn *sql.Conn) ([]*ListMigrationsResult, error) { | ||
q := s.querier.ListMigrations(s.tablename) | ||
rows, err := conn.QueryContext(ctx, q) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to list migrations: %w", err) | ||
} | ||
defer rows.Close() | ||
|
||
var migrations []*ListMigrationsResult | ||
for rows.Next() { | ||
result := new(ListMigrationsResult) | ||
if err := rows.Scan(&result.Version, &result.IsApplied); err != nil { | ||
return nil, fmt.Errorf("failed to scan list migrations result: %w", err) | ||
} | ||
migrations = append(migrations, result) | ||
} | ||
if err := rows.Err(); err != nil { | ||
return nil, err | ||
} | ||
return migrations, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
package sqladapter_test | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"errors" | ||
"testing" | ||
|
||
"github.com/jackc/pgx/v5/pgconn" | ||
"github.com/pressly/goose/v3" | ||
"github.com/pressly/goose/v3/internal/check" | ||
"github.com/pressly/goose/v3/internal/sqladapter" | ||
"github.com/pressly/goose/v3/internal/testdb" | ||
) | ||
|
||
// The goal of this test is to verify the sqladapter package works as expected. This test is not | ||
// meant to be exhaustive or test every possible database dialect. It is meant to verify that the | ||
// Store interface works against a real database. | ||
|
||
func TestStore_Postgres(t *testing.T) { | ||
t.Parallel() | ||
if testing.Short() { | ||
t.Skip("skip long-running test") | ||
} | ||
ctx := context.Background() | ||
const ( | ||
tablename = "test_goose_db_version" | ||
) | ||
db, cleanup, err := testdb.NewPostgres() | ||
check.NoError(t, err) | ||
t.Cleanup(cleanup) | ||
store, err := sqladapter.NewStore(string(goose.DialectPostgres), tablename) | ||
check.NoError(t, err) | ||
// Create the version table. | ||
err = runTx(ctx, db, func(tx *sql.Tx) error { | ||
return store.CreateVersionTable(ctx, tx, tablename) | ||
}) | ||
// Create the version table again. This should fail. | ||
check.NoError(t, err) | ||
err = runTx(ctx, db, func(tx *sql.Tx) error { | ||
return store.CreateVersionTable(ctx, tx, tablename) | ||
}) | ||
check.HasError(t, err) | ||
var pgErr *pgconn.PgError | ||
ok := errors.As(err, &pgErr) | ||
check.Bool(t, ok, true) | ||
check.Equal(t, pgErr.Code, "42P07") // duplicate_table | ||
// List migrations. There should be none. | ||
err = runConn(ctx, db, func(conn *sql.Conn) error { | ||
res, err := store.ListMigrationsConn(ctx, conn) | ||
check.NoError(t, err) | ||
check.Number(t, len(res), 0) | ||
return nil | ||
}) | ||
check.NoError(t, err) | ||
// Insert 5 migrations in addition to the zero migration. | ||
for i := 0; i < 6; i++ { | ||
err = runConn(ctx, db, func(conn *sql.Conn) error { | ||
return store.InsertOrDelete(ctx, conn, true, int64(i)) | ||
}) | ||
check.NoError(t, err) | ||
} | ||
// List migrations. There should be 6. | ||
err = runConn(ctx, db, func(conn *sql.Conn) error { | ||
res, err := store.ListMigrationsConn(ctx, conn) | ||
check.NoError(t, err) | ||
check.Number(t, len(res), 6) | ||
// Check versions are in descending order. | ||
for i := 0; i < 6; i++ { | ||
check.Number(t, res[i].Version, 5-i) | ||
} | ||
return nil | ||
}) | ||
check.NoError(t, err) | ||
// Delete 3 migrations backwards | ||
for i := 5; i >= 3; i-- { | ||
err = runConn(ctx, db, func(conn *sql.Conn) error { | ||
return store.InsertOrDelete(ctx, conn, false, int64(i)) | ||
}) | ||
check.NoError(t, err) | ||
} | ||
// List migrations. There should be 3. | ||
err = runConn(ctx, db, func(conn *sql.Conn) error { | ||
res, err := store.ListMigrationsConn(ctx, conn) | ||
check.NoError(t, err) | ||
check.Number(t, len(res), 3) | ||
// Check that the remaining versions are in descending order. | ||
for i := 0; i < 3; i++ { | ||
check.Number(t, res[i].Version, 2-i) | ||
} | ||
return nil | ||
}) | ||
check.NoError(t, err) | ||
// Get remaining migrations one by one. | ||
for i := 0; i < 3; i++ { | ||
err = runConn(ctx, db, func(conn *sql.Conn) error { | ||
res, err := store.GetMigrationConn(ctx, conn, int64(i)) | ||
check.NoError(t, err) | ||
check.Equal(t, res.IsApplied, true) | ||
check.Equal(t, res.Timestamp.IsZero(), false) | ||
return nil | ||
}) | ||
check.NoError(t, err) | ||
} | ||
// Delete remaining migrations one by one and use all 3 connection types: | ||
// *sql.DB | ||
// *sql.Tx | ||
// *sql.Conn. | ||
err = runTx(ctx, db, func(tx *sql.Tx) error { | ||
return store.InsertOrDelete(ctx, tx, false, 2) // *sql.Tx | ||
}) | ||
check.NoError(t, err) | ||
err = runConn(ctx, db, func(conn *sql.Conn) error { | ||
return store.InsertOrDelete(ctx, conn, false, 1) // *sql.Conn | ||
}) | ||
check.NoError(t, err) | ||
err = store.InsertOrDelete(ctx, db, false, 0) // *sql.DB | ||
check.NoError(t, err) | ||
// List migrations. There should be none. | ||
err = runConn(ctx, db, func(conn *sql.Conn) error { | ||
res, err := store.ListMigrationsConn(ctx, conn) | ||
check.NoError(t, err) | ||
check.Number(t, len(res), 0) | ||
return nil | ||
}) | ||
check.NoError(t, err) | ||
// Try to get a migration that does not exist. | ||
err = runConn(ctx, db, func(conn *sql.Conn) error { | ||
_, err := store.GetMigrationConn(ctx, conn, 0) | ||
check.HasError(t, err) | ||
check.Bool(t, errors.Is(err, sql.ErrNoRows), true) | ||
return nil | ||
}) | ||
check.NoError(t, err) | ||
} | ||
|
||
func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) { | ||
tx, err := db.BeginTx(ctx, nil) | ||
if err != nil { | ||
return err | ||
} | ||
defer func() { | ||
if retErr != nil { | ||
retErr = errors.Join(retErr, tx.Rollback()) | ||
} | ||
}() | ||
if err := fn(tx); err != nil { | ||
return err | ||
} | ||
return tx.Commit() | ||
} | ||
|
||
func runConn(ctx context.Context, db *sql.DB, fn func(*sql.Conn) error) (retErr error) { | ||
conn, err := db.Conn(ctx) | ||
if err != nil { | ||
return err | ||
} | ||
defer func() { | ||
if retErr != nil { | ||
retErr = errors.Join(retErr, conn.Close()) | ||
} | ||
}() | ||
if err := fn(conn); err != nil { | ||
return err | ||
} | ||
return conn.Close() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters