Skip to content

Commit

Permalink
add sqladapter package with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed Sep 16, 2023
1 parent 62e1b6a commit 77dec20
Show file tree
Hide file tree
Showing 4 changed files with 346 additions and 6 deletions.
56 changes: 56 additions & 0 deletions internal/sqladapter/sqladapter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Package sqladapter provides an interface for interacting with a SQL database.
//
// All supported database dialects must implement the Store interface.
package sqladapter

import (
"context"
"database/sql"
"time"
)

// DBTxConn is an interface that is satisfied by *sql.DB, *sql.Tx and *sql.Conn.
//
// There is a long outstanding issue to formalize a std lib interface, but alas...
// See: https://github.com/golang/go/issues/14468
type DBTxConn interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

// Store is the interface that wraps the basic methods for a database dialect.
//
// A dialect is a set of SQL statements that are specific to a database.
//
// By defining a store interface, we can support multiple databases with a single codebase.
//
// The underlying implementation does not modify the error. It is the callers responsibility to
// assert for the correct error, such as sql.ErrNoRows.
type Store interface {
// CreateVersionTable creates the version table within a transaction. This table is used to
// record applied migrations.
CreateVersionTable(ctx context.Context, tx *sql.Tx, tablename string) error

// InsertOrDelete inserts or deletes a version id from the version table.
InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error

// GetMigration retrieves a single migration by version id.
//
// Returns the raw sql error if the query fails. It is the callers responsibility
// to assert for the correct error, such as sql.ErrNoRows.
GetMigrationConn(ctx context.Context, conn *sql.Conn, version int64) (*GetMigrationResult, error)

// ListMigrations retrieves all migrations sorted in descending order by id.
//
// If there are no migrations, an empty slice is returned with no error.
ListMigrationsConn(ctx context.Context, conn *sql.Conn) ([]*ListMigrationsResult, error)
}

type GetMigrationResult struct {
IsApplied bool
Timestamp time.Time
}

type ListMigrationsResult struct {
Version int64
IsApplied bool
}
111 changes: 111 additions & 0 deletions internal/sqladapter/store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package sqladapter

import (
"context"
"database/sql"
"errors"
"fmt"

"github.com/pressly/goose/v3/internal/dialect/dialectquery"
)

var _ Store = (*store)(nil)

type store struct {
tablename string
querier dialectquery.Querier
}

// NewStore returns a new Store backed by the given dialect.
//
// The dialect must match one of the supported dialects defined in dialect.go.
func NewStore(dialect string, table string) (Store, error) {
if table == "" {
return nil, errors.New("table must not be empty")
}
if dialect == "" {
return nil, errors.New("dialect must not be empty")
}
var querier dialectquery.Querier
switch dialect {
case "clickhouse":
querier = &dialectquery.Clickhouse{}
case "mssql":
querier = &dialectquery.Sqlserver{}
case "mysql":
querier = &dialectquery.Mysql{}
case "postgres":
querier = &dialectquery.Postgres{}
case "redshift":
querier = &dialectquery.Redshift{}
case "sqlite3":
querier = &dialectquery.Sqlite3{}
case "tidb":
querier = &dialectquery.Tidb{}
case "vertica":
querier = &dialectquery.Vertica{}
default:
return nil, fmt.Errorf("unknown dialect: %q", dialect)
}
return &store{
tablename: table,
querier: querier,
}, nil
}

func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx, tablename string) error {
q := s.querier.CreateTable(s.tablename)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("failed to create version table %q: %w", tablename, err)
}
return nil
}

func (s *store) InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error {
if direction {
q := s.querier.InsertVersion(s.tablename)
if _, err := db.ExecContext(ctx, q, version, true); err != nil {
return fmt.Errorf("failed to insert version %d: %w", version, err)
}
return nil
}
q := s.querier.DeleteVersion(s.tablename)
if _, err := db.ExecContext(ctx, q, version); err != nil {
return fmt.Errorf("failed to delete version %d: %w", version, err)
}
return nil
}

func (s *store) GetMigrationConn(ctx context.Context, conn *sql.Conn, version int64) (*GetMigrationResult, error) {
q := s.querier.GetMigrationByVersion(s.tablename)
result := new(GetMigrationResult)
if err := conn.QueryRowContext(ctx, q, version).Scan(
&result.Timestamp,
&result.IsApplied,
); err != nil {
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
}
return result, nil
}

func (s *store) ListMigrationsConn(ctx context.Context, conn *sql.Conn) ([]*ListMigrationsResult, error) {
q := s.querier.ListMigrations(s.tablename)
rows, err := conn.QueryContext(ctx, q)
if err != nil {
return nil, fmt.Errorf("failed to list migrations: %w", err)
}
defer rows.Close()

var migrations []*ListMigrationsResult
for rows.Next() {
result := new(ListMigrationsResult)
if err := rows.Scan(&result.Version, &result.IsApplied); err != nil {
return nil, fmt.Errorf("failed to scan list migrations result: %w", err)
}
migrations = append(migrations, result)
}
if err := rows.Err(); err != nil {
return nil, err
}
return migrations, nil
}
167 changes: 167 additions & 0 deletions internal/sqladapter/store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package sqladapter_test

import (
"context"
"database/sql"
"errors"
"testing"

"github.com/jackc/pgx/v5/pgconn"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/sqladapter"
"github.com/pressly/goose/v3/internal/testdb"
)

// The goal of this test is to verify the sqladapter package works as expected. This test is not
// meant to be exhaustive or test every possible database dialect. It is meant to verify that the
// Store interface works against a real database.

func TestStore_Postgres(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip("skip long-running test")
}
ctx := context.Background()
const (
tablename = "test_goose_db_version"
)
db, cleanup, err := testdb.NewPostgres()
check.NoError(t, err)
t.Cleanup(cleanup)
store, err := sqladapter.NewStore(string(goose.DialectPostgres), tablename)
check.NoError(t, err)
// Create the version table.
err = runTx(ctx, db, func(tx *sql.Tx) error {
return store.CreateVersionTable(ctx, tx, tablename)
})
// Create the version table again. This should fail.
check.NoError(t, err)
err = runTx(ctx, db, func(tx *sql.Tx) error {
return store.CreateVersionTable(ctx, tx, tablename)
})
check.HasError(t, err)
var pgErr *pgconn.PgError
ok := errors.As(err, &pgErr)
check.Bool(t, ok, true)
check.Equal(t, pgErr.Code, "42P07") // duplicate_table
// List migrations. There should be none.
err = runConn(ctx, db, func(conn *sql.Conn) error {
res, err := store.ListMigrationsConn(ctx, conn)
check.NoError(t, err)
check.Number(t, len(res), 0)
return nil
})
check.NoError(t, err)
// Insert 5 migrations in addition to the zero migration.
for i := 0; i < 6; i++ {
err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.InsertOrDelete(ctx, conn, true, int64(i))
})
check.NoError(t, err)
}
// List migrations. There should be 6.
err = runConn(ctx, db, func(conn *sql.Conn) error {
res, err := store.ListMigrationsConn(ctx, conn)
check.NoError(t, err)
check.Number(t, len(res), 6)
// Check versions are in descending order.
for i := 0; i < 6; i++ {
check.Number(t, res[i].Version, 5-i)
}
return nil
})
check.NoError(t, err)
// Delete 3 migrations backwards
for i := 5; i >= 3; i-- {
err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.InsertOrDelete(ctx, conn, false, int64(i))
})
check.NoError(t, err)
}
// List migrations. There should be 3.
err = runConn(ctx, db, func(conn *sql.Conn) error {
res, err := store.ListMigrationsConn(ctx, conn)
check.NoError(t, err)
check.Number(t, len(res), 3)
// Check that the remaining versions are in descending order.
for i := 0; i < 3; i++ {
check.Number(t, res[i].Version, 2-i)
}
return nil
})
check.NoError(t, err)
// Get remaining migrations one by one.
for i := 0; i < 3; i++ {
err = runConn(ctx, db, func(conn *sql.Conn) error {
res, err := store.GetMigrationConn(ctx, conn, int64(i))
check.NoError(t, err)
check.Equal(t, res.IsApplied, true)
check.Equal(t, res.Timestamp.IsZero(), false)
return nil
})
check.NoError(t, err)
}
// Delete remaining migrations one by one and use all 3 connection types:
// *sql.DB
// *sql.Tx
// *sql.Conn.
err = runTx(ctx, db, func(tx *sql.Tx) error {
return store.InsertOrDelete(ctx, tx, false, 2) // *sql.Tx
})
check.NoError(t, err)
err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.InsertOrDelete(ctx, conn, false, 1) // *sql.Conn
})
check.NoError(t, err)
err = store.InsertOrDelete(ctx, db, false, 0) // *sql.DB
check.NoError(t, err)
// List migrations. There should be none.
err = runConn(ctx, db, func(conn *sql.Conn) error {
res, err := store.ListMigrationsConn(ctx, conn)
check.NoError(t, err)
check.Number(t, len(res), 0)
return nil
})
check.NoError(t, err)
// Try to get a migration that does not exist.
err = runConn(ctx, db, func(conn *sql.Conn) error {
_, err := store.GetMigrationConn(ctx, conn, 0)
check.HasError(t, err)
check.Bool(t, errors.Is(err, sql.ErrNoRows), true)
return nil
})
check.NoError(t, err)
}

func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() {
if retErr != nil {
retErr = errors.Join(retErr, tx.Rollback())
}
}()
if err := fn(tx); err != nil {
return err
}
return tx.Commit()
}

func runConn(ctx context.Context, db *sql.DB, fn func(*sql.Conn) error) (retErr error) {
conn, err := db.Conn(ctx)
if err != nil {
return err
}
defer func() {
if retErr != nil {
retErr = errors.Join(retErr, conn.Close())
}
}()
if err := fn(conn); err != nil {
return err
}
return conn.Close()
}
18 changes: 12 additions & 6 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ import (
"database/sql"
"errors"
"time"

"github.com/pressly/goose/v3/internal/sqladapter"
)

// Provider is a goose migration provider.
type Provider struct {
db *sql.DB
opt *ProviderOptions
db *sql.DB
store sqladapter.Store
opt *ProviderOptions
}

// NewProvider returns a new goose Provider.
Expand All @@ -35,14 +38,17 @@ func NewProvider(dialect Dialect, db *sql.DB, opts *ProviderOptions) (*Provider,
if err := validateOptions(opts); err != nil {
return nil, err
}
//
store, err := sqladapter.NewStore(string(dialect), opts.Tablename)
if err != nil {
return nil, err
}
// TODO(mf): implement the rest of this function
// - db / dialect store
// - collect sources
// - merge sources into migrations
return &Provider{
db: db,
opt: opts,
db: db,
store: store,
opt: opts,
}, nil
}

Expand Down

0 comments on commit 77dec20

Please sign in to comment.