Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: defer http server startup after minimal schema upgrade #564

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 12 additions & 26 deletions cmd/buckets_upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,22 @@ func NewBucketUpgrade() *cobra.Command {
Args: cobra.ExactArgs(1),
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
connectionOptions, err := bunconnect.ConnectionOptionsFromFlags(cmd)
if err != nil {
return err
}
logger := logging.NewDefaultLogger(cmd.OutOrStdout(), service.IsDebug(cmd), false, false)
cmd.SetContext(logging.ContextWithLogger(cmd.Context(), logger))

db, err := bunconnect.OpenSQLDB(cmd.Context(), *connectionOptions)
driver, err := getDriver(cmd)
if err != nil {
return err
}
defer func() {
_ = db.Close()
_ = driver.GetDB().Close()
}()

driver := driver.New(db)
if err := driver.Initialize(cmd.Context()); err != nil {
return err
}

if args[0] == "*" {
return upgradeAll(cmd)
return driver.UpgradeAllBuckets(cmd.Context(), make(chan struct{}))
}

logger := logging.NewDefaultLogger(cmd.OutOrStdout(), service.IsDebug(cmd), false, false)

return driver.UpgradeBucket(logging.ContextWithLogger(cmd.Context(), logger), args[0])
return driver.UpgradeBucket(cmd.Context(), args[0])
},
}

Expand All @@ -48,27 +39,22 @@ func NewBucketUpgrade() *cobra.Command {
return cmd
}

func upgradeAll(cmd *cobra.Command) error {
logger := logging.NewDefaultLogger(cmd.OutOrStdout(), service.IsDebug(cmd), false, false)
ctx := logging.ContextWithLogger(cmd.Context(), logger)
func getDriver(cmd *cobra.Command) (*driver.Driver, error) {

connectionOptions, err := bunconnect.ConnectionOptionsFromFlags(cmd)
if err != nil {
return err
return nil, err
}

db, err := bunconnect.OpenSQLDB(cmd.Context(), *connectionOptions)
if err != nil {
return err
return nil, err
}
defer func() {
_ = db.Close()
}()

driver := driver.New(db)
if err := driver.Initialize(ctx); err != nil {
return err
if err := driver.Initialize(cmd.Context()); err != nil {
return nil, err
}

return driver.UpgradeAllBuckets(ctx)
return driver, nil
}
10 changes: 9 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,15 @@ func NewRootCommand() *cobra.Command {
root.AddCommand(version)
root.AddCommand(bunmigrate.NewDefaultCommand(func(cmd *cobra.Command, _ []string, _ *bun.DB) error {
// todo: use provided db ...
return upgradeAll(cmd)
driver, err := getDriver(cmd)
if err != nil {
return err
}
defer func() {
_ = driver.GetDB().Close()
}()

return driver.UpgradeAllBuckets(cmd.Context(), make(chan struct{}))
}))
root.AddCommand(NewDocsCommand())

Expand Down
4 changes: 2 additions & 2 deletions internal/storage/bucket/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ type Bucket struct {
db *bun.DB
}

func (b *Bucket) Migrate(ctx context.Context, tracer trace.Tracer) error {
return migrate(ctx, tracer, b.db, b.name)
func (b *Bucket) Migrate(ctx context.Context, tracer trace.Tracer, minimalVersionReached chan struct{}) error {
return migrate(ctx, tracer, b.db, b.name, minimalVersionReached)
}

func (b *Bucket) HasMinimalVersion(ctx context.Context) (bool, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/storage/bucket/bucket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ func TestBuckets(t *testing.T) {
require.NoError(t, driver.Migrate(ctx, db))

b := bucket.New(db, name)
require.NoError(t, b.Migrate(ctx, noop.Tracer{}))
require.NoError(t, b.Migrate(ctx, noop.Tracer{}, make(chan struct{})))
}
35 changes: 33 additions & 2 deletions internal/storage/bucket/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package bucket
import (
"context"
"embed"
"errors"
"github.com/formancehq/go-libs/v2/migrations"
"github.com/uptrace/bun"
"go.opentelemetry.io/otel/trace"
Expand All @@ -22,9 +23,39 @@ func GetMigrator(db *bun.DB, name string) *migrations.Migrator {
return migrator
}

func migrate(ctx context.Context, tracer trace.Tracer, db *bun.DB, name string) error {
func migrate(ctx context.Context, tracer trace.Tracer, db *bun.DB, name string, minimalVersionReached chan struct{}) error {
ctx, span := tracer.Start(ctx, "Migrate bucket")
defer span.End()

return GetMigrator(db, name).Up(ctx)
migrator := GetMigrator(db, name)
version, err := migrator.GetLastVersion(ctx)
if err != nil {
if !errors.Is(err, migrations.ErrMissingVersionTable) {
return err
}
}

if version >= MinimalSchemaVersion {
close(minimalVersionReached)
}

for {
err := migrator.UpByOne(ctx)
if err != nil {
if errors.Is(err, migrations.ErrAlreadyUpToDate) {
return nil
}
return err
}
version++

if version >= MinimalSchemaVersion {
select {
case <-minimalVersionReached:
// already closed
default:
close(minimalVersionReached)
}
}
}
}
50 changes: 41 additions & 9 deletions internal/storage/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
noopmetrics "go.opentelemetry.io/otel/metric/noop"
"go.opentelemetry.io/otel/trace"
nooptracer "go.opentelemetry.io/otel/trace/noop"
"golang.org/x/sync/errgroup"

ledgercontroller "github.com/formancehq/ledger/internal/controller/ledger"

Expand Down Expand Up @@ -40,7 +41,7 @@ func (d *Driver) CreateLedger(ctx context.Context, l *ledger.Ledger) (*ledgersto
}

b := bucket.New(d.db, l.Bucket)
if err := b.Migrate(ctx, d.tracer); err != nil {
if err := b.Migrate(ctx, d.tracer, make(chan struct{})); err != nil {
return nil, fmt.Errorf("migrating bucket: %w", err)
}

Expand Down Expand Up @@ -188,10 +189,13 @@ func (d *Driver) GetLedger(ctx context.Context, name string) (*ledger.Ledger, er
}

func (d *Driver) UpgradeBucket(ctx context.Context, name string) error {
return bucket.New(d.db, name).Migrate(ctx, d.tracer)
if err := bucket.New(d.db, name).Migrate(ctx, d.tracer, make(chan struct{})); err != nil {
return fmt.Errorf("migrating bucket '%s': %w", name, err)
}
return nil
}

func (d *Driver) UpgradeAllBuckets(ctx context.Context) error {
func (d *Driver) UpgradeAllBuckets(ctx context.Context, minimalVersionReached chan struct{}) error {

var buckets []string
err := d.db.NewSelect().
Expand All @@ -203,17 +207,45 @@ func (d *Driver) UpgradeAllBuckets(ctx context.Context) error {
return fmt.Errorf("getting buckets: %w", err)
}

sem := make(chan struct{}, len(buckets))

grp, ctx := errgroup.WithContext(ctx)
for _, bucketName := range buckets {
b := bucket.New(d.db, bucketName)
grp.Go(func() error {
b := bucket.New(d.db, bucketName)

minimalVersionReached := make(chan struct{})

go func() {
select {
case <-ctx.Done():
return
case <-minimalVersionReached:
sem <- struct{}{}
}
}()

logging.FromContext(ctx).Infof("Upgrading bucket '%s'", bucketName)
if err := b.Migrate(ctx, d.tracer, minimalVersionReached); err != nil {
return err
}
logging.FromContext(ctx).Infof("Bucket '%s' up to date", bucketName)

return nil
})
}

logging.FromContext(ctx).Infof("Upgrading bucket '%s'", bucketName)
if err := b.Migrate(ctx, d.tracer); err != nil {
return err
for i := 0; i < len(buckets); i++ {
select {
case <-ctx.Done():
return ctx.Err()
case <-sem:
}
logging.FromContext(ctx).Infof("Bucket '%s' up to date", bucketName)
}

return nil
close(minimalVersionReached)

return grp.Wait()
}

func (d *Driver) GetDB() *bun.DB {
Expand Down
2 changes: 1 addition & 1 deletion internal/storage/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestUpgradeAllLedgers(t *testing.T) {
require.NoError(t, err)
}

require.NoError(t, d.UpgradeAllBuckets(ctx))
require.NoError(t, d.UpgradeAllBuckets(ctx, make(chan struct{})))
}

func TestLedgersCreate(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion internal/storage/ledger/legacy/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func newLedgerStore(t T) *testStore {
l.Bucket = ledgerName

b := bucket.New(db, ledgerName)
require.NoError(t, b.Migrate(ctx, noop.Tracer{}))
require.NoError(t, b.Migrate(ctx, noop.Tracer{}, make(chan struct{})))
require.NoError(t, b.AddLedger(ctx, l, db))

return &testStore{
Expand Down
54 changes: 32 additions & 22 deletions internal/storage/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func NewFXModule(autoUpgrade bool) fx.Option {
upgradeContext context.Context
cancelContext func()
upgradeStopped = make(chan struct{})
minimalVersionReached = make(chan struct{})
)
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
Expand All @@ -29,29 +30,15 @@ func NewFXModule(autoUpgrade bool) fx.Option {
go func() {
defer close(upgradeStopped)

for {
select {
case <-ctx.Done():
return
default:
logging.FromContext(ctx).Infof("Upgrading buckets...")
if err := driver.UpgradeAllBuckets(upgradeContext); err != nil {
// Long migrations can be cancelled (app rescheduled for example)
// before fully terminated, handle this gracefully, don't panic,
// the next start will try again.
if errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) {
return
}
logging.FromContext(ctx).Errorf("Upgrading buckets: %s", err)
continue
}
return
}
}

migrate(upgradeContext, driver, minimalVersionReached)
}()
return nil

select {
case <-ctx.Done():
return ctx.Err()
case <-minimalVersionReached:
return nil
}
},
OnStop: func(ctx context.Context) error {
cancelContext()
Expand All @@ -68,3 +55,26 @@ func NewFXModule(autoUpgrade bool) fx.Option {
}
return fx.Options(ret...)
}

func migrate(ctx context.Context, driver *driver.Driver, minimalVersionReached chan struct{}) {
for {
select {
case <-ctx.Done():
return
default:
logging.FromContext(ctx).Infof("Upgrading buckets...")
if err := driver.UpgradeAllBuckets(ctx, minimalVersionReached); err != nil {
// Long migrations can be cancelled (app rescheduled for example)
// before fully terminated, handle this gracefully, don't panic,
// the next start will try again.
if errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) {
return
}
logging.FromContext(ctx).Errorf("Upgrading buckets: %s", err)
continue
}
return
}
}
}
2 changes: 1 addition & 1 deletion test/migrations/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestMigrations(t *testing.T) {
// Migrate database
driver := driver.New(db)
require.NoError(t, driver.Initialize(ctx))
require.NoError(t, driver.UpgradeAllBuckets(ctx))
require.NoError(t, driver.UpgradeAllBuckets(ctx, make(chan struct{})))
}

func copyDatabase(t *testing.T, dockerPool *docker.Pool, source, destination string) {
Expand Down
Loading