From 7a854713afe96d6247c23ec1ac9a38eab928a1fd Mon Sep 17 00:00:00 2001 From: Bill Moran Date: Fri, 10 May 2024 08:01:22 -0400 Subject: [PATCH] Refactor global DBSteward variable to passed variable --- lib/dbsteward.go | 60 +++-- lib/dbsteward_main_test.go | 24 -- lib/format/interface.go | 44 ---- lib/format/lookup.go | 10 - lib/format/pgsql8/constraint.go | 10 +- lib/format/pgsql8/diff.go | 216 +++++++++--------- lib/format/pgsql8/diff_constraints.go | 40 ++-- lib/format/pgsql8/diff_constraints_test.go | 28 ++- lib/format/pgsql8/diff_functions.go | 11 +- lib/format/pgsql8/diff_languages.go | 22 +- lib/format/pgsql8/diff_sequences.go | 7 +- lib/format/pgsql8/diff_tables.go | 112 +++++---- .../pgsql8/diff_tables_escape_char_test.go | 22 +- lib/format/pgsql8/diff_tables_test.go | 84 ++++--- lib/format/pgsql8/diff_types.go | 8 +- lib/format/pgsql8/diff_types_domains_test.go | 14 +- lib/format/pgsql8/diff_types_test.go | 11 +- lib/format/pgsql8/diff_views.go | 13 +- lib/format/pgsql8/diff_views_test.go | 9 +- lib/format/pgsql8/function.go | 9 +- lib/format/pgsql8/language.go | 6 +- lib/format/pgsql8/oneeighty_test.go | 7 +- lib/format/pgsql8/operations.go | 93 ++++---- .../operations_column_value_default_test.go | 10 +- .../pgsql8/operations_extract_schema_test.go | 33 ++- lib/format/pgsql8/pgsql8.go | 5 +- lib/format/pgsql8/pgsql8_main_test.go | 22 +- lib/format/pgsql8/role.go | 5 +- lib/format/pgsql8/schema.go | 16 +- lib/format/pgsql8/sequence.go | 11 +- lib/format/pgsql8/table.go | 17 +- lib/format/pgsql8/table_test.go | 4 +- lib/format/pgsql8/type.go | 5 +- lib/format/pgsql8/view.go | 9 +- lib/format/pgsql8/xml_parser.go | 20 +- .../pgsql8/xml_parser_partition_modulo.go | 6 +- lib/format/pgsql8/xml_parser_test.go | 10 +- lib/format/sql99/operations.go | 18 -- main.go | 7 +- xmlpostgresintegration_test.go | 11 +- 40 files changed, 511 insertions(+), 558 deletions(-) delete mode 100644 lib/dbsteward_main_test.go delete mode 100644 lib/format/interface.go delete mode 100644 lib/format/lookup.go delete mode 100644 lib/format/sql99/operations.go diff --git a/lib/dbsteward.go b/lib/dbsteward.go index d2547dd..5e8c830 100644 --- a/lib/dbsteward.go +++ b/lib/dbsteward.go @@ -10,8 +10,8 @@ import ( "github.com/dbsteward/dbsteward/lib/config" "github.com/dbsteward/dbsteward/lib/encoding/xml" - "github.com/dbsteward/dbsteward/lib/format" "github.com/dbsteward/dbsteward/lib/ir" + "github.com/dbsteward/dbsteward/lib/output" "github.com/dbsteward/dbsteward/lib/util" "github.com/hashicorp/go-multierror" @@ -19,19 +19,47 @@ import ( "github.com/rs/zerolog" ) +type LookupMap map[ir.SqlFormat]*Lookup + +type Lookup struct { + Schema Schema + OperationsConstructor func(*DBSteward) Operations +} + +type Operations interface { + Build(outputPrefix string, dbDoc *ir.Definition) error + BuildUpgrade( + oldOutputPrefix, oldCompositeFile string, oldDbDoc *ir.Definition, oldFiles []string, + newOutputPrefix, newCompositeFile string, newDbDoc *ir.Definition, newFiles []string, + ) error + ExtractSchema(host string, port uint, name, user, pass string) (*ir.Definition, error) + CompareDbData(dbDoc *ir.Definition, host string, port uint, name, user, pass string) (*ir.Definition, error) + SqlDiff(old, new []string, outputFile string) + + GetQuoter() output.Quoter + //SetConfig(*config.Args) +} + +type Schema interface { + GetCreationSql(*DBSteward, *ir.Schema) ([]output.ToSql, error) + GetDropSql(*ir.Schema) []output.ToSql +} + +type SlonyOperations interface { + SlonyCompare(file string) + SlonyDiff(oldFile, newFile string) +} + // NOTE: 2.0.0 is the intended golang release. 3.0.0 is the intended refactor/modernization -var Version = "2.0.0" +const Version = "2.0.0" // NOTE: we're attempting to maintain "api" compat with legacy dbsteward for now -var ApiVersion = "1.4" - -// TODO(go,3) no globals -var GlobalDBSteward *DBSteward +const ApiVersion = "1.4" type DBSteward struct { logger zerolog.Logger slogLogger *slog.Logger - lookupMap format.LookupMap + lookupMap LookupMap SqlFormat ir.SqlFormat @@ -67,7 +95,7 @@ type DBSteward struct { NewDatabase *ir.Definition } -func NewDBSteward(lookupMap format.LookupMap) *DBSteward { +func NewDBSteward(lookupMap LookupMap) *DBSteward { dbsteward := &DBSteward{ logger: zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}).With().Timestamp().Logger(), lookupMap: lookupMap, @@ -108,7 +136,7 @@ func NewDBSteward(lookupMap format.LookupMap) *DBSteward { return dbsteward } -func (dbsteward *DBSteward) Lookup() *format.Lookup { +func (dbsteward *DBSteward) Lookup() *Lookup { return dbsteward.lookupMap[dbsteward.SqlFormat] } @@ -554,7 +582,7 @@ func (dbsteward *DBSteward) doBuild(files []string, dataFiles []string, addendum dbsteward.fatalIfError(err, "saving file") } - err = dbsteward.Lookup().OperationsConstructor().Build(outputPrefix, dbDoc) + err = dbsteward.Lookup().OperationsConstructor(dbsteward).Build(outputPrefix, dbDoc) dbsteward.fatalIfError(err, "building") } func (dbsteward *DBSteward) doDiff(oldFiles []string, newFiles []string, dataFiles []string) { @@ -585,14 +613,14 @@ func (dbsteward *DBSteward) doDiff(oldFiles []string, newFiles []string, dataFil err = xml.SaveDefinition(dbsteward.Logger(), newCompositeFile, newDbDoc) dbsteward.fatalIfError(err, "saving file") - err = dbsteward.Lookup().OperationsConstructor().BuildUpgrade( + err = dbsteward.Lookup().OperationsConstructor(dbsteward).BuildUpgrade( oldOutputPrefix, oldCompositeFile, oldDbDoc, oldFiles, newOutputPrefix, newCompositeFile, newDbDoc, newFiles, ) dbsteward.fatalIfError(err, "building upgrade") } func (dbsteward *DBSteward) doExtract(dbHost string, dbPort uint, dbName, dbUser, dbPass string, outputFile string) { - output, err := dbsteward.Lookup().OperationsConstructor().ExtractSchema(dbHost, dbPort, dbName, dbUser, dbPass) + output, err := dbsteward.Lookup().OperationsConstructor(dbsteward).ExtractSchema(dbHost, dbPort, dbName, dbUser, dbPass) dbsteward.fatalIfError(err, "extracting") dbsteward.Info("Saving extracted database schema to %s", outputFile) err = xml.SaveDefinition(dbsteward.Logger(), outputFile, output) @@ -621,13 +649,13 @@ func (dbsteward *DBSteward) doDbDataDiff(files []string, dataFiles []string, add err = xml.SaveDefinition(dbsteward.Logger(), compositeFile, dbDoc) dbsteward.fatalIfError(err, "saving file") - output, err := dbsteward.Lookup().OperationsConstructor().CompareDbData(dbDoc, dbHost, dbPort, dbName, dbUser, dbPass) + output, err := dbsteward.Lookup().OperationsConstructor(dbsteward).CompareDbData(dbDoc, dbHost, dbPort, dbName, dbUser, dbPass) dbsteward.fatalIfError(err, "comparing data") err = xml.SaveDefinition(dbsteward.Logger(), compositeFile, output) dbsteward.fatalIfError(err, "saving file") } func (dbsteward *DBSteward) doSqlDiff(oldSql, newSql []string, outputFile string) { - dbsteward.Lookup().OperationsConstructor().SqlDiff(oldSql, newSql, outputFile) + dbsteward.Lookup().OperationsConstructor(dbsteward).SqlDiff(oldSql, newSql, outputFile) } func (dbsteward *DBSteward) doSlonikConvert(file string, outputFile string) { // TODO(go,nth) is there a nicer way to handle this output idiom? @@ -640,8 +668,8 @@ func (dbsteward *DBSteward) doSlonikConvert(file string, outputFile string) { } } func (dbsteward *DBSteward) doSlonyCompare(file string) { - dbsteward.lookupMap[ir.SqlFormatPgsql8].OperationsConstructor().(format.SlonyOperations).SlonyCompare(file) + dbsteward.lookupMap[ir.SqlFormatPgsql8].OperationsConstructor(dbsteward).(SlonyOperations).SlonyCompare(file) } func (dbsteward *DBSteward) doSlonyDiff(oldFile string, newFile string) { - dbsteward.lookupMap[ir.SqlFormatPgsql8].OperationsConstructor().(format.SlonyOperations).SlonyDiff(oldFile, newFile) + dbsteward.lookupMap[ir.SqlFormatPgsql8].OperationsConstructor(dbsteward).(SlonyOperations).SlonyDiff(oldFile, newFile) } diff --git a/lib/dbsteward_main_test.go b/lib/dbsteward_main_test.go deleted file mode 100644 index f6fbca3..0000000 --- a/lib/dbsteward_main_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package lib_test - -import ( - "os" - "testing" - - "github.com/dbsteward/dbsteward/lib/format/pgsql8" - - "github.com/dbsteward/dbsteward/lib" - "github.com/dbsteward/dbsteward/lib/format" - "github.com/dbsteward/dbsteward/lib/ir" -) - -func TestMain(m *testing.M) { - resetGlobalDBSteward() - os.Exit(m.Run()) -} - -func resetGlobalDBSteward() { - lib.GlobalDBSteward = lib.NewDBSteward(format.LookupMap{ - ir.SqlFormatPgsql8: pgsql8.GlobalLookup, - }) - lib.GlobalDBSteward.SqlFormat = ir.SqlFormatPgsql8 -} diff --git a/lib/format/interface.go b/lib/format/interface.go deleted file mode 100644 index 4d465d5..0000000 --- a/lib/format/interface.go +++ /dev/null @@ -1,44 +0,0 @@ -package format - -import ( - "github.com/dbsteward/dbsteward/lib/config" - "github.com/dbsteward/dbsteward/lib/ir" - "github.com/dbsteward/dbsteward/lib/output" -) - -type Operations interface { - Build(outputPrefix string, dbDoc *ir.Definition) error - BuildUpgrade( - oldOutputPrefix, oldCompositeFile string, oldDbDoc *ir.Definition, oldFiles []string, - newOutputPrefix, newCompositeFile string, newDbDoc *ir.Definition, newFiles []string, - ) error - ExtractSchema(host string, port uint, name, user, pass string) (*ir.Definition, error) - CompareDbData(dbDoc *ir.Definition, host string, port uint, name, user, pass string) (*ir.Definition, error) - SqlDiff(old, new []string, outputFile string) - - GetQuoter() output.Quoter - SetConfig(*config.Args) -} - -type SlonyOperations interface { - SlonyCompare(file string) - SlonyDiff(oldFile, newFile string) -} - -type Schema interface { - GetCreationSql(*ir.Schema) ([]output.ToSql, error) - GetDropSql(*ir.Schema) []output.ToSql -} - -type Index interface { - BuildPrimaryKeyName(string) string - BuildForeignKeyName(string, string) string -} - -type Diff interface { - DiffDoc(oldFile, newFile string, oldDoc, newDoc *ir.Definition, upgradePrefix string) error - DiffDocWork(stage1, stage2, stage3, stage4 output.OutputFileSegmenter) error - - DropOldSchemas(output.OutputFileSegmenter) - CreateNewSchemas(output.OutputFileSegmenter) error -} diff --git a/lib/format/lookup.go b/lib/format/lookup.go deleted file mode 100644 index 231d4c5..0000000 --- a/lib/format/lookup.go +++ /dev/null @@ -1,10 +0,0 @@ -package format - -import "github.com/dbsteward/dbsteward/lib/ir" - -type LookupMap map[ir.SqlFormat]*Lookup - -type Lookup struct { - Schema Schema - OperationsConstructor func() Operations -} diff --git a/lib/format/pgsql8/constraint.go b/lib/format/pgsql8/constraint.go index 599fcd4..b721ab0 100644 --- a/lib/format/pgsql8/constraint.go +++ b/lib/format/pgsql8/constraint.go @@ -268,8 +268,8 @@ func getTableContraintCreationSql(constraint *sql99.TableConstraint) []output.To return nil } -func constraintDependsOnRenamedTable(l *slog.Logger, doc *ir.Definition, constraint *sql99.TableConstraint) (bool, error) { - if lib.GlobalDBSteward.IgnoreOldNames { +func constraintDependsOnRenamedTable(dbs *lib.DBSteward, doc *ir.Definition, constraint *sql99.TableConstraint) (bool, error) { + if dbs.IgnoreOldNames { return false, nil } @@ -294,16 +294,16 @@ func constraintDependsOnRenamedTable(l *slog.Logger, doc *ir.Definition, constra if refTable == nil { return false, nil } - isRenamed := lib.GlobalDBSteward.IgnoreOldNames + isRenamed := dbs.IgnoreOldNames if !isRenamed { var err error - isRenamed, err = lib.GlobalDBSteward.OldDatabase.IsRenamedTable(slog.Default(), refSchema, refTable) + isRenamed, err = dbs.OldDatabase.IsRenamedTable(slog.Default(), refSchema, refTable) if err != nil { return false, fmt.Errorf("while checking if constraint depends on renamed table: %w", err) } } if isRenamed { - l.Info(fmt.Sprintf("Constraint %s.%s.%s references renamed table %s.%s", constraint.Schema.Name, constraint.Table.Name, constraint.Name, refSchema.Name, refTable.Name)) + dbs.Logger().Info(fmt.Sprintf("Constraint %s.%s.%s references renamed table %s.%s", constraint.Schema.Name, constraint.Table.Name, constraint.Name, refSchema.Name, refTable.Name)) return true, nil } return false, nil diff --git a/lib/format/pgsql8/diff.go b/lib/format/pgsql8/diff.go index 87e561b..9baed7a 100644 --- a/lib/format/pgsql8/diff.go +++ b/lib/format/pgsql8/diff.go @@ -6,7 +6,6 @@ import ( "os" "time" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/format/sql99" "github.com/dbsteward/dbsteward/lib/ir" @@ -15,12 +14,16 @@ import ( type diff struct { quoter output.Quoter + ops *Operations OldTableDependency []*ir.TableRef NewTableDependency []*ir.TableRef } -func newDiff(q output.Quoter) *diff { - return &diff{quoter: q} +func newDiff(ops *Operations, q output.Quoter) *diff { + return &diff{ + ops: ops, + quoter: q, + } } func (d *diff) Quoter() output.Quoter { @@ -56,31 +59,30 @@ func (d *diff) UpdateDatabaseConfigParameters(ofs output.OutputFileSegmenter, ol } func (d *diff) DiffDocWork(stage1, stage2, stage3, stage4 output.OutputFileSegmenter) error { - dbsteward := lib.GlobalDBSteward // this shouldn't be called if we're not generating slonik, it looks for // a slony element in which most likely won't be there if // we're not interested in slony replication - if dbsteward.GenerateSlonik { + if d.ops.dbsteward.GenerateSlonik { // TODO(go,slony) } // stage 1 and 3 should not be in a transaction as they will be submitted via slonik EXECUTE SCRIPT - if !dbsteward.GenerateSlonik { + if !d.ops.dbsteward.GenerateSlonik { stage1.AppendHeader(output.NewRawSQL("\nBEGIN;\n\n")) stage1.AppendFooter(output.NewRawSQL("\nCOMMIT;\n")) } else { stage1.AppendHeader(sql.NewComment("generateslonik specified: pgsql8 STAGE1 upgrade omitting BEGIN. slonik EXECUTE SCRIPT will wrap stage 1 DDL and DCL in a transaction")) } - if !dbsteward.SingleStageUpgrade { + if !d.ops.dbsteward.SingleStageUpgrade { stage2.AppendHeader(output.NewRawSQL("\nBEGIN;\n\n")) stage2.AppendFooter(output.NewRawSQL("\nCOMMIT;\n")) stage4.AppendHeader(output.NewRawSQL("\nBEGIN;\n\n")) stage4.AppendFooter(output.NewRawSQL("\nCOMMIT;\n")) // stage 1 and 3 should not be in a transaction as they will be submitted via slonik EXECUTE SCRIPT - if !dbsteward.GenerateSlonik { + if !d.ops.dbsteward.GenerateSlonik { stage3.AppendHeader(output.NewRawSQL("\nBEGIN;\n\n")) stage3.AppendFooter(output.NewRawSQL("\nCOMMIT;\n")) } else { @@ -89,13 +91,13 @@ func (d *diff) DiffDocWork(stage1, stage2, stage3, stage4 output.OutputFileSegme } // start with pre-upgrade sql statements that prepare the database to take on its changes - buildStagedSql(dbsteward.NewDatabase, stage1, "STAGE1BEFORE") - buildStagedSql(dbsteward.NewDatabase, stage2, "STAGE2BEFORE") + buildStagedSql(d.ops.dbsteward.NewDatabase, stage1, "STAGE1BEFORE") + buildStagedSql(d.ops.dbsteward.NewDatabase, stage2, "STAGE2BEFORE") - d.Logger().Info("Drop Old Schemas") + d.ops.dbsteward.Logger().Info("Drop Old Schemas") d.DropOldSchemas(stage3) - d.Logger().Info("Create New Schemas") + d.ops.dbsteward.Logger().Info("Create New Schemas") err := d.CreateNewSchemas(stage1) if err != nil { return err @@ -106,37 +108,37 @@ func (d *diff) DiffDocWork(stage1, stage2, stage3, stage4 output.OutputFileSegme return err } - d.Logger().Info("Update Permissions") + d.ops.dbsteward.Logger().Info("Update Permissions") err = d.updatePermissions(stage1, stage3) if err != nil { return err } - d.UpdateDatabaseConfigParameters(stage1, dbsteward.NewDatabase, dbsteward.OldDatabase) + d.UpdateDatabaseConfigParameters(stage1, d.ops.dbsteward.NewDatabase, d.ops.dbsteward.OldDatabase) - d.Logger().Info("Update data") - if dbsteward.GenerateSlonik { + d.ops.dbsteward.Logger().Info("Update data") + if d.ops.dbsteward.GenerateSlonik { // TODO(go,slony) format::set_context_replica_set_to_natural_first(dbsteward::$new_database); } - err = d.updateData(d.Logger(), stage2, true) + err = d.updateData(stage2, true) if err != nil { return err } - err = d.updateData(d.Logger(), stage4, false) + err = d.updateData(stage4, false) if err != nil { return err } // append any literal sql in new not in old at the end of data stage 1 // TODO(feat) this relies on exact string match - is there a better way? - for _, newSql := range dbsteward.NewDatabase.Sql { + for _, newSql := range d.ops.dbsteward.NewDatabase.Sql { // ignore upgrade staged sql elements if newSql.Stage != "" { continue } found := false - for _, oldSql := range dbsteward.OldDatabase.Sql { + for _, oldSql := range d.ops.dbsteward.OldDatabase.Sql { // ignore upgrade staged sql elements if oldSql.Stage != "" { continue @@ -154,14 +156,14 @@ func (d *diff) DiffDocWork(stage1, stage2, stage3, stage4 output.OutputFileSegme } // append stage defined sql statements to appropriate stage file - if dbsteward.GenerateSlonik { + if d.ops.dbsteward.GenerateSlonik { // TODO(go,slony) format::set_context_replica_set_to_natural_first(dbsteward::$new_database); } - buildStagedSql(dbsteward.NewDatabase, stage1, "STAGE1") - buildStagedSql(dbsteward.NewDatabase, stage2, "STAGE2") - buildStagedSql(dbsteward.NewDatabase, stage3, "STAGE3") - buildStagedSql(dbsteward.NewDatabase, stage4, "STAGE4") + buildStagedSql(d.ops.dbsteward.NewDatabase, stage1, "STAGE1") + buildStagedSql(d.ops.dbsteward.NewDatabase, stage2, "STAGE2") + buildStagedSql(d.ops.dbsteward.NewDatabase, stage3, "STAGE3") + buildStagedSql(d.ops.dbsteward.NewDatabase, stage4, "STAGE4") return nil } @@ -170,50 +172,54 @@ func (d *diff) DiffSql(old, new []string, upgradePrefix string) { } func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output.OutputFileSegmenter) error { - d.Logger().Info("Update Structure") - dbsteward := lib.GlobalDBSteward + logger := d.ops.dbsteward.Logger() + logger.Info("Update Structure") - err := diffLanguages(d.Logger(), stage1) + err := diffLanguages(d.ops.dbsteward, stage1) if err != nil { return err } // drop all views in all schemas, regardless whether dependency order is known or not // TODO(go,4) would be so cool if we could parse the view def and only recreate what's required - dropViewsOrdered(stage1, dbsteward.OldDatabase, dbsteward.NewDatabase) + dropViewsOrdered(stage1, d.ops.dbsteward.OldDatabase, d.ops.dbsteward.NewDatabase) // TODO(go,3) should we just always use table deps? if len(d.NewTableDependency) == 0 { - d.Logger().Debug("not using table dependencies") - for _, newSchema := range dbsteward.NewDatabase.Schemas { - oldSchema := dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) - err := diffTypes(d.Logger(), d, stage1, oldSchema, newSchema) + logger.Debug("not using table dependencies") + for _, newSchema := range d.ops.dbsteward.NewDatabase.Schemas { + l := logger.With(slog.String("new schema", newSchema.Name)) + oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) + if oldSchema != nil { + l = l.With(slog.String("old schema", oldSchema.Name)) + } + err := diffTypes(d.ops.dbsteward, d, stage1, oldSchema, newSchema) if err != nil { return err } - err = diffFunctions(d.Logger(), stage1, stage3, oldSchema, newSchema) + err = diffFunctions(d.ops.dbsteward, stage1, stage3, oldSchema, newSchema) if err != nil { return err } - err = diffSequences(d.Logger(), stage1, oldSchema, newSchema) + err = diffSequences(d.ops.dbsteward, stage1, oldSchema, newSchema) if err != nil { return fmt.Errorf("while diffing sequences: %w", err) } // remove old constraints before table constraints, so the sql statements succeed - err = dropConstraints(d.Logger(), stage1, oldSchema, newSchema, sql99.ConstraintTypeConstraint) + err = dropConstraints(d.ops.dbsteward, stage1, oldSchema, newSchema, sql99.ConstraintTypeConstraint) if err != nil { return err } - err = dropConstraints(d.Logger(), stage1, oldSchema, newSchema, sql99.ConstraintTypePrimaryKey) + err = dropConstraints(d.ops.dbsteward, stage1, oldSchema, newSchema, sql99.ConstraintTypePrimaryKey) if err != nil { return err } - dropTables(stage1, oldSchema, newSchema) - err = createTables(d.Logger(), stage1, oldSchema, newSchema) + dropTables(d.ops.dbsteward, stage1, oldSchema, newSchema) + err = createTables(d.ops.dbsteward, stage1, oldSchema, newSchema) if err != nil { return fmt.Errorf("while creating tables: %w", err) } - err = diffTables(d.Logger(), stage1, stage3, oldSchema, newSchema) + err = diffTables(d.ops.dbsteward, stage1, stage3, oldSchema, newSchema) if err != nil { return fmt.Errorf("while diffing tables: %w", err) } @@ -222,7 +228,7 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. return err } diffClusters(stage1, oldSchema, newSchema) - createConstraints(d.Logger(), stage1, oldSchema, newSchema, sql99.ConstraintTypePrimaryKey) + createConstraints(d.ops.dbsteward, stage1, oldSchema, newSchema, sql99.ConstraintTypePrimaryKey) err = diffTriggers(stage1, oldSchema, newSchema) if err != nil { return err @@ -230,25 +236,25 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. } // non-primary key constraints may be inter-schema dependant, and dependant on other's primary keys // and therefore should be done after object creation sections - for _, newSchema := range dbsteward.NewDatabase.Schemas { - oldSchema := dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) - createConstraints(d.Logger(), stage1, oldSchema, newSchema, sql99.ConstraintTypeConstraint) + for _, newSchema := range d.ops.dbsteward.NewDatabase.Schemas { + oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) + createConstraints(d.ops.dbsteward, stage1, oldSchema, newSchema, sql99.ConstraintTypeConstraint) } } else { - d.Logger().Debug("using table dependencies") + logger.Debug("using table dependencies") // use table dependency order to do structural changes in an intelligent order // make sure we only process each schema once processedSchemas := map[string]bool{} for _, newEntry := range d.NewTableDependency { newSchema := newEntry.Schema - oldSchema := dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) + oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) if !processedSchemas[newSchema.Name] { - err := diffTypes(d.Logger(), d, stage1, oldSchema, newSchema) + err := diffTypes(d.ops.dbsteward, d, stage1, oldSchema, newSchema) if err != nil { return err } - err = diffFunctions(d.Logger(), stage1, stage3, oldSchema, newSchema) + err = diffFunctions(d.ops.dbsteward, stage1, stage3, oldSchema, newSchema) if err != nil { return err } @@ -262,7 +268,7 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. oldSchema := oldEntry.Schema oldTable := oldEntry.Table - newSchema := dbsteward.NewDatabase.TryGetSchemaNamed(oldSchema.Name) + newSchema := d.ops.dbsteward.NewDatabase.TryGetSchemaNamed(oldSchema.Name) var newTable *ir.Table if newSchema != nil { newTable = newSchema.TryGetTableNamed(oldTable.Name) @@ -270,11 +276,11 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. // NOTE: when dropping constraints, GlobalDBX.RenamedTableCheckPointer() is not called for oldTable // as GlobalDiffConstraints.DiffConstraintsTable() will do rename checking when recreating constraints for renamed tables - err := dropConstraintsTable(d.Logger(), stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeConstraint) + err := dropConstraintsTable(d.ops.dbsteward, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeConstraint) if err != nil { return err } - err = dropConstraintsTable(d.Logger(), stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypePrimaryKey) + err = dropConstraintsTable(d.ops.dbsteward, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypePrimaryKey) if err != nil { return err } @@ -283,13 +289,13 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. processedSchemas = map[string]bool{} for _, newEntry := range d.NewTableDependency { newSchema := newEntry.Schema - oldSchema := dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) + oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) // schema level stuff should only be done once, keep track of which ones we have done // see above for pre table creation stuff // see below for post table creation stuff if !processedSchemas[newSchema.Name] { - err := diffSequences(d.Logger(), stage1, oldSchema, newSchema) + err := diffSequences(d.ops.dbsteward, stage1, oldSchema, newSchema) if err != nil { return fmt.Errorf("while diffing sequences: %w", err) } @@ -307,15 +313,15 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. // when a table has an oldTableName oldSchemaName specified, // GlobalDBX.RenamedTableCheckPointer() will modify these pointers to be the old table var err error - oldSchema, oldTable, err = lib.GlobalDBSteward.OldDatabase.NewTableName(oldSchema, oldTable, newSchema, newTable) + oldSchema, oldTable, err = d.ops.dbsteward.OldDatabase.NewTableName(oldSchema, oldTable, newSchema, newTable) if err != nil { return fmt.Errorf("getting new table name: %w", err) } - err = createTable(d.Logger(), stage1, oldSchema, newSchema, newTable) + err = createTable(d.ops.dbsteward, stage1, oldSchema, newSchema, newTable) if err != nil { return fmt.Errorf("while creating table %s.%s: %w", newSchema.Name, newTable.Name, err) } - err = diffTable(d.Logger(), stage1, stage3, oldSchema, oldTable, newSchema, newTable) + err = diffTable(d.ops.dbsteward, stage1, stage3, oldSchema, oldTable, newSchema, newTable) if err != nil { return fmt.Errorf("while diffing table %s.%s: %w", newSchema.Name, newTable.Name, err) } @@ -324,7 +330,7 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. return err } diffClustersTable(stage1, oldTable, newSchema, newTable) - err = createConstraintsTable(d.Logger(), stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypePrimaryKey) + err = createConstraintsTable(d.ops.dbsteward, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypePrimaryKey) if err != nil { return err } @@ -335,7 +341,7 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. // HACK: For now, we'll generate foreign key constraints in stage 4 in updateData below // https://github.com/dbsteward/dbsteward/issues/142 - err = createConstraintsTable(d.Logger(), stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeConstraint&^sql99.ConstraintTypeForeign) + err = createConstraintsTable(d.ops.dbsteward, stage1, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeConstraint&^sql99.ConstraintTypeForeign) if err != nil { return err } @@ -347,26 +353,30 @@ func (d *diff) updateStructure(stage1 output.OutputFileSegmenter, stage3 output. oldSchema := oldEntry.Schema oldTable := oldEntry.Table - newSchema := dbsteward.NewDatabase.TryGetSchemaNamed(oldSchema.Name) - dropTable(stage3, oldSchema, oldTable, newSchema) + newSchema := d.ops.dbsteward.NewDatabase.TryGetSchemaNamed(oldSchema.Name) + dropTable(d.ops.dbsteward, stage3, oldSchema, oldTable, newSchema) } } - return createViewsOrdered(d.Logger(), stage3, dbsteward.OldDatabase, dbsteward.NewDatabase) + return createViewsOrdered(d.ops.dbsteward, stage3, d.ops.dbsteward.OldDatabase, d.ops.dbsteward.NewDatabase) } func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 output.OutputFileSegmenter) error { // TODO(feat) what if readonly user changed? we need to rebuild those grants // TODO(feat) what about removed permissions, shouldn't we REVOKE those? - newDoc := lib.GlobalDBSteward.NewDatabase - oldDoc := lib.GlobalDBSteward.OldDatabase + newDoc := d.ops.dbsteward.NewDatabase + oldDoc := d.ops.dbsteward.OldDatabase + logger := d.ops.dbsteward.Logger() for _, newSchema := range newDoc.Schemas { + l := logger.With(slog.String("new schema", newSchema.Name)) oldSchema := oldDoc.TryGetSchemaNamed(newSchema.Name) - + if oldSchema != nil { + l = l.With(slog.String("old schema", oldSchema.Name)) + } for _, newGrant := range newSchema.Grants { if oldSchema == nil || !ir.HasPermissionsOf(oldSchema, newGrant, ir.SqlFormatPgsql8) { - s, err := GlobalSchema.GetGrantSql(newDoc, newSchema, newGrant) + s, err := GlobalSchema.GetGrantSql(d.ops.dbsteward, newDoc, newSchema, newGrant) if err != nil { return err } @@ -376,7 +386,7 @@ func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 outpu for _, newTable := range newSchema.Tables { oldTable := oldSchema.TryGetTableNamed(newTable.Name) - isRenamed, err := lib.GlobalDBSteward.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) + isRenamed, err := d.ops.dbsteward.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) if err != nil { return fmt.Errorf("while updating permissions: %w", err) } @@ -387,7 +397,7 @@ func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 outpu } for _, newGrant := range newTable.Grants { if oldTable == nil || !ir.HasPermissionsOf(oldTable, newGrant, ir.SqlFormatPgsql8) { - s, err := getTableGrantSql(d.Logger(), newSchema, newTable, newGrant) + s, err := getTableGrantSql(d.ops.dbsteward, newSchema, newTable, newGrant) if err != nil { return err } @@ -400,7 +410,7 @@ func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 outpu oldSeq := oldSchema.TryGetSequenceNamed(newSeq.Name) for _, newGrant := range newSeq.Grants { if oldSeq == nil || !ir.HasPermissionsOf(oldSeq, newGrant, ir.SqlFormatPgsql8) { - s, err := getSequenceGrantSql(d.Logger(), newSchema, newSeq, newGrant) + s, err := getSequenceGrantSql(d.ops.dbsteward, newSchema, newSeq, newGrant) if err != nil { return err } @@ -413,7 +423,7 @@ func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 outpu oldFunc := oldSchema.TryGetFunctionMatching(newFunc) for _, newGrant := range newFunc.Grants { if oldFunc == nil || !ir.HasPermissionsOf(oldFunc, newGrant, ir.SqlFormatPgsql8) { - grants, err := getFunctionGrantSql(d.Logger(), newSchema, newFunc, newGrant) + grants, err := getFunctionGrantSql(d.ops.dbsteward, newSchema, newFunc, newGrant) if err != nil { return err } @@ -425,8 +435,8 @@ func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 outpu for _, newView := range newSchema.Views { oldView := oldSchema.TryGetViewNamed(newView.Name) for _, newGrant := range newView.Grants { - if lib.GlobalDBSteward.AlwaysRecreateViews || oldView == nil || !ir.HasPermissionsOf(oldView, newGrant, ir.SqlFormatPgsql8) || !oldView.Equals(newView, ir.SqlFormatPgsql8) { - s, err := getViewGrantSql(d.Logger(), newDoc, newSchema, newView, newGrant) + if d.ops.dbsteward.AlwaysRecreateViews || oldView == nil || !ir.HasPermissionsOf(oldView, newGrant, ir.SqlFormatPgsql8) || !oldView.Equals(newView, ir.SqlFormatPgsql8) { + s, err := getViewGrantSql(d.ops.dbsteward, newDoc, newSchema, newView, newGrant) if err != nil { return err } @@ -438,7 +448,7 @@ func (d *diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 outpu return nil } -func (d *diff) updateData(l *slog.Logger, ofs output.OutputFileSegmenter, deleteMode bool) error { +func (d *diff) updateData(ofs output.OutputFileSegmenter, deleteMode bool) error { if len(d.NewTableDependency) > 0 { for i := 0; i < len(d.NewTableDependency); i += 1 { item := d.NewTableDependency[i] @@ -446,32 +456,32 @@ func (d *diff) updateData(l *slog.Logger, ofs output.OutputFileSegmenter, delete if deleteMode { item = d.NewTableDependency[len(d.NewTableDependency)-1-i] } - + l := d.ops.dbsteward.Logger().With(slog.String("table", item.String())) newSchema := item.Schema newTable := item.Table - oldSchema := lib.GlobalDBSteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) + oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) oldTable := oldSchema.TryGetTableNamed(newTable.Name) - isRenamed, err := lib.GlobalDBSteward.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) + isRenamed, err := d.ops.dbsteward.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) if err != nil { return fmt.Errorf("while updatign data: %w", err) } if isRenamed { - d.Logger().Info(fmt.Sprintf("%s.%s used to be called %s - will diff data against that definition", newSchema.Name, newTable.Name, newTable.OldTableName)) - oldSchema = lib.GlobalDBSteward.OldDatabase.GetOldTableSchema(newSchema, newTable) - oldTable = lib.GlobalDBSteward.OldDatabase.GetOldTable(newSchema, newTable) + l.Info(fmt.Sprintf("%s.%s used to be called %s - will diff data against that definition", newSchema.Name, newTable.Name, newTable.OldTableName)) + oldSchema = d.ops.dbsteward.OldDatabase.GetOldTableSchema(newSchema, newTable) + oldTable = d.ops.dbsteward.OldDatabase.GetOldTable(newSchema, newTable) } if deleteMode { // TODO(go,3) clean up inconsistencies between e.g. GetDeleteDataSql and DiffData wrt writing sql to an ofs // TODO(feat) aren't deletes supposed to go in stage 2? - s, err := getDeleteDataSql(l, oldSchema, oldTable, newSchema, newTable) + s, err := getDeleteDataSql(d.ops, oldSchema, oldTable, newSchema, newTable) if err != nil { return err } ofs.WriteSql(s...) } else { - s, err := getCreateDataSql(l, oldSchema, oldTable, newSchema, newTable) + s, err := getCreateDataSql(d.ops, oldSchema, oldTable, newSchema, newTable) if err != nil { return err } @@ -479,7 +489,7 @@ func (d *diff) updateData(l *slog.Logger, ofs output.OutputFileSegmenter, delete // HACK: For now, we'll generate foreign key constraints in stage 4 after inserting data // https://github.com/dbsteward/dbsteward/issues/142 - err = createConstraintsTable(d.Logger(), ofs, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeForeign) + err = createConstraintsTable(d.ops.dbsteward, ofs, oldSchema, oldTable, newSchema, newTable, sql99.ConstraintTypeForeign) if err != nil { return err } @@ -488,9 +498,9 @@ func (d *diff) updateData(l *slog.Logger, ofs output.OutputFileSegmenter, delete } else { // dependency order unknown, hit them in natural order // TODO(feat) the above switches on deleteMode, this does not. we never delete data if table dep order is unknown? - for _, newSchema := range lib.GlobalDBSteward.NewDatabase.Schemas { - oldSchema := lib.GlobalDBSteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) - return diffData(l, ofs, oldSchema, newSchema) + for _, newSchema := range d.ops.dbsteward.NewDatabase.Schemas { + oldSchema := d.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) + return diffData(d.ops, ofs, oldSchema, newSchema) } } return nil @@ -505,61 +515,55 @@ func (d *diff) DropSchemaSQL(s *ir.Schema) ([]output.ToSql, error) { // CreateSchemaSQL this implementation is a bit hacky as it's a // transitional step as I factor away global variables func (d *diff) CreateSchemaSQL(s *ir.Schema) ([]output.ToSql, error) { - return GlobalSchema.GetCreationSql(s) -} - -func (diff *diff) Logger() *slog.Logger { - // Hack to work around instantion order issues - return lib.GlobalDBSteward.Logger() + return GlobalSchema.GetCreationSql(d.ops.dbsteward, s) } func (diff *diff) DiffDoc(oldFile, newFile string, oldDoc, newDoc *ir.Definition, upgradePrefix string) error { - dbsteward := lib.GlobalDBSteward timestamp := time.Now().Format(time.RFC1123Z) oldSetNewSet := fmt.Sprintf("-- Old definition: %s\n-- New definition %s\n", oldFile, newFile) var stage1, stage2, stage3, stage4 output.OutputFileSegmenter quoter := diff.Quoter() - - if dbsteward.SingleStageUpgrade { + logger := diff.ops.dbsteward.Logger() + if diff.ops.dbsteward.SingleStageUpgrade { fileName := upgradePrefix + "_single_stage.sql" file, err := os.Create(fileName) if err != nil { return fmt.Errorf("failed to open %s for write: %w", fileName, err) } - stage1 = output.NewOutputFileSegmenterToFile(diff.Logger(), quoter, fileName, 1, file, fileName, dbsteward.OutputFileStatementLimit) + stage1 = output.NewOutputFileSegmenterToFile(logger, quoter, fileName, 1, file, fileName, diff.ops.dbsteward.OutputFileStatementLimit) stage1.SetHeader(sql.NewComment("DBsteward single stage upgrade changes - generated %s\n%s", timestamp, oldSetNewSet)) defer stage1.Close() stage2 = stage1 stage3 = stage1 stage4 = stage1 } else { - stage1 = output.NewOutputFileSegmenter(diff.Logger(), quoter, upgradePrefix+"_stage1_schema", 1, dbsteward.OutputFileStatementLimit) + stage1 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage1_schema", 1, diff.ops.dbsteward.OutputFileStatementLimit) stage1.SetHeader(sql.NewComment("DBSteward stage 1 structure additions and modifications - generated %s\n%s", timestamp, oldSetNewSet)) defer stage1.Close() - stage2 = output.NewOutputFileSegmenter(diff.Logger(), quoter, upgradePrefix+"_stage2_data", 1, dbsteward.OutputFileStatementLimit) + stage2 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage2_data", 1, diff.ops.dbsteward.OutputFileStatementLimit) stage2.SetHeader(sql.NewComment("DBSteward stage 2 data definitions removed - generated %s\n%s", timestamp, oldSetNewSet)) defer stage2.Close() - stage3 = output.NewOutputFileSegmenter(diff.Logger(), quoter, upgradePrefix+"_stage3_schema", 1, dbsteward.OutputFileStatementLimit) + stage3 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage3_schema", 1, diff.ops.dbsteward.OutputFileStatementLimit) stage3.SetHeader(sql.NewComment("DBSteward stage 3 structure changes, constraints, and removals - generated %s\n%s", timestamp, oldSetNewSet)) defer stage3.Close() - stage4 = output.NewOutputFileSegmenter(diff.Logger(), quoter, upgradePrefix+"_stage4_data", 1, dbsteward.OutputFileStatementLimit) + stage4 = output.NewOutputFileSegmenter(logger, quoter, upgradePrefix+"_stage4_data", 1, diff.ops.dbsteward.OutputFileStatementLimit) stage4.SetHeader(sql.NewComment("DBSteward stage 4 data definition changes and additions - generated %s\n%s", timestamp, oldSetNewSet)) defer stage4.Close() } - dbsteward.OldDatabase = oldDoc - dbsteward.NewDatabase = newDoc + diff.ops.dbsteward.OldDatabase = oldDoc + diff.ops.dbsteward.NewDatabase = newDoc return diff.DiffDocWork(stage1, stage2, stage3, stage4) } func (diff *diff) DropOldSchemas(ofs output.OutputFileSegmenter) { // TODO(feat) support oldname following? - for _, oldSchema := range lib.GlobalDBSteward.OldDatabase.Schemas { - if lib.GlobalDBSteward.NewDatabase.TryGetSchemaNamed(oldSchema.Name) == nil { - diff.Logger().Info(fmt.Sprintf("Drop old schema: %s", oldSchema.Name)) + for _, oldSchema := range diff.ops.dbsteward.OldDatabase.Schemas { + if diff.ops.dbsteward.NewDatabase.TryGetSchemaNamed(oldSchema.Name) == nil { + diff.ops.dbsteward.Logger().Info(fmt.Sprintf("Drop old schema: %s", oldSchema.Name)) ofs.MustWriteSql(diff.DropSchemaSQL(oldSchema)) } } @@ -567,9 +571,9 @@ func (diff *diff) DropOldSchemas(ofs output.OutputFileSegmenter) { func (diff *diff) CreateNewSchemas(ofs output.OutputFileSegmenter) error { // TODO(feat) support oldname following? - for _, newSchema := range lib.GlobalDBSteward.NewDatabase.Schemas { - if lib.GlobalDBSteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) == nil { - diff.Logger().Info(fmt.Sprintf("Create new schema: %s", newSchema.Name)) + for _, newSchema := range diff.ops.dbsteward.NewDatabase.Schemas { + if diff.ops.dbsteward.OldDatabase.TryGetSchemaNamed(newSchema.Name) == nil { + diff.ops.dbsteward.Logger().Info(fmt.Sprintf("Create new schema: %s", newSchema.Name)) ofs.MustWriteSql(diff.CreateSchemaSQL(newSchema)) } } diff --git a/lib/format/pgsql8/diff_constraints.go b/lib/format/pgsql8/diff_constraints.go index 36b8e18..eebd4ff 100644 --- a/lib/format/pgsql8/diff_constraints.go +++ b/lib/format/pgsql8/diff_constraints.go @@ -10,25 +10,25 @@ import ( "github.com/dbsteward/dbsteward/lib/output" ) -func createConstraints(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, constraintType sql99.ConstraintType) { +func createConstraints(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, constraintType sql99.ConstraintType) { for _, newTable := range newSchema.Tables { var oldTable *ir.Table if oldSchema != nil { // TODO(feat) what about renames? oldTable = oldSchema.TryGetTableNamed(newTable.Name) } - createConstraintsTable(l, ofs, oldSchema, oldTable, newSchema, newTable, constraintType) + createConstraintsTable(dbs, ofs, oldSchema, oldTable, newSchema, newTable, constraintType) } } -func createConstraintsTable(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) error { - isRenamed, err := lib.GlobalDBSteward.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) +func createConstraintsTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) error { + isRenamed, err := dbs.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) if err != nil { return fmt.Errorf("while checking if table was renamed: %w", err) } if isRenamed { // remove all constraints and recreate with new table name conventions - constraints, err := getTableConstraints(lib.GlobalDBSteward.OldDatabase, oldSchema, oldTable, constraintType) + constraints, err := getTableConstraints(dbs.OldDatabase, oldSchema, oldTable, constraintType) if err != nil { return err } @@ -42,7 +42,7 @@ func createConstraintsTable(l *slog.Logger, ofs output.OutputFileSegmenter, oldS } // add all still-defined constraints back and any new ones to the table - constraints, err = getTableConstraints(lib.GlobalDBSteward.NewDatabase, newSchema, newTable, constraintType) + constraints, err = getTableConstraints(dbs.NewDatabase, newSchema, newTable, constraintType) if err != nil { return err } @@ -52,7 +52,7 @@ func createConstraintsTable(l *slog.Logger, ofs output.OutputFileSegmenter, oldS return nil } - constraints, err := getNewConstraints(l, oldSchema, oldTable, newSchema, newTable, constraintType) + constraints, err := getNewConstraints(dbs, oldSchema, oldTable, newSchema, newTable, constraintType) if err != nil { return err } @@ -62,14 +62,14 @@ func createConstraintsTable(l *slog.Logger, ofs output.OutputFileSegmenter, oldS return nil } -func dropConstraints(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, constraintType sql99.ConstraintType) error { +func dropConstraints(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, constraintType sql99.ConstraintType) error { for _, newTable := range newSchema.Tables { var oldTable *ir.Table if oldSchema != nil { // TODO(feat) what about renames? oldTable = oldSchema.TryGetTableNamed(newTable.Name) } - err := dropConstraintsTable(l, ofs, oldSchema, oldTable, newSchema, newTable, constraintType) + err := dropConstraintsTable(dbs, ofs, oldSchema, oldTable, newSchema, newTable, constraintType) if err != nil { return err } @@ -77,8 +77,8 @@ func dropConstraints(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema, return nil } -func dropConstraintsTable(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) error { - constraints, err := getOldConstraints(l, oldSchema, oldTable, newSchema, newTable, constraintType) +func dropConstraintsTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) error { + constraints, err := getOldConstraints(dbs, oldSchema, oldTable, newSchema, newTable, constraintType) if err != nil { return err } @@ -88,11 +88,11 @@ func dropConstraintsTable(l *slog.Logger, ofs output.OutputFileSegmenter, oldSch return nil } -func getOldConstraints(l *slog.Logger, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) ([]*sql99.TableConstraint, error) { +func getOldConstraints(dbs *lib.DBSteward, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) ([]*sql99.TableConstraint, error) { out := []*sql99.TableConstraint{} if newTable != nil && oldTable != nil { - oldDb := lib.GlobalDBSteward.OldDatabase - newDb := lib.GlobalDBSteward.NewDatabase + oldDb := dbs.OldDatabase + newDb := dbs.NewDatabase constraints, err := getTableConstraints(oldDb, oldSchema, oldTable, constraintType) if err != nil { return nil, err @@ -106,11 +106,11 @@ func getOldConstraints(l *slog.Logger, oldSchema *ir.Schema, oldTable *ir.Table, out = append(out, oldConstraint) continue } - oldConstraintWithRenamedTable, err := constraintDependsOnRenamedTable(l, newDb, oldConstraint) + oldConstraintWithRenamedTable, err := constraintDependsOnRenamedTable(dbs, newDb, oldConstraint) if err != nil { return nil, err } - newConstraintWithRenamedTable, err := constraintDependsOnRenamedTable(l, newDb, newConstraint) + newConstraintWithRenamedTable, err := constraintDependsOnRenamedTable(dbs, newDb, newConstraint) if err != nil { return nil, err } @@ -122,11 +122,11 @@ func getOldConstraints(l *slog.Logger, oldSchema *ir.Schema, oldTable *ir.Table, return out, nil } -func getNewConstraints(l *slog.Logger, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) ([]*sql99.TableConstraint, error) { +func getNewConstraints(dbs *lib.DBSteward, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) ([]*sql99.TableConstraint, error) { out := []*sql99.TableConstraint{} if newTable != nil { - oldDb := lib.GlobalDBSteward.OldDatabase - newDb := lib.GlobalDBSteward.NewDatabase + oldDb := dbs.OldDatabase + newDb := dbs.NewDatabase newConstraints, err := getTableConstraints(newDb, newSchema, newTable, constraintType) if err != nil { return nil, err @@ -136,7 +136,7 @@ func getNewConstraints(l *slog.Logger, oldSchema *ir.Schema, oldTable *ir.Table, if err != nil { return nil, err } - renamedTable, err := constraintDependsOnRenamedTable(l, newDb, newConstraint) + renamedTable, err := constraintDependsOnRenamedTable(dbs, newDb, newConstraint) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/diff_constraints_test.go b/lib/format/pgsql8/diff_constraints_test.go index 319c3b9..7984f32 100644 --- a/lib/format/pgsql8/diff_constraints_test.go +++ b/lib/format/pgsql8/diff_constraints_test.go @@ -1,9 +1,9 @@ package pgsql8 import ( - "log/slog" "testing" + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/format/sql99" @@ -135,14 +135,17 @@ func TestDiffConstraints_DropCreate_ChangePrimaryKeyNameAndTable(t *testing.T) { newDoc := &ir.Definition{ Schemas: []*ir.Schema{newSchema}, } - ofs := output.NewSegmenter(defaultQuoter(slog.Default())) - differ := newDiff(defaultQuoter(slog.Default())) - setOldNewDocs(differ, oldDoc, newDoc) - err := dropConstraintsTable(slog.Default(), ofs, oldSchema, oldSchema.Tables[0], newSchema, nil, sql99.ConstraintTypePrimaryKey) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + ofs := output.NewSegmenter(defaultQuoter(dbs)) + differ := newDiff(NewOperations(dbs).(*Operations), defaultQuoter(dbs)) + setOldNewDocs(dbs, differ, oldDoc, newDoc) + err := dropConstraintsTable(dbs, ofs, oldSchema, oldSchema.Tables[0], newSchema, nil, sql99.ConstraintTypePrimaryKey) if err != nil { t.Fatal(err) } - err = createConstraintsTable(slog.Default(), ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], sql99.ConstraintTypePrimaryKey) + err = createConstraintsTable(dbs, ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], sql99.ConstraintTypePrimaryKey) if err != nil { t.Fatal(err) } @@ -381,14 +384,17 @@ func diffConstraintsTableCommon(t *testing.T, oldSchema, newSchema *ir.Schema, c newDoc := &ir.Definition{ Schemas: []*ir.Schema{newSchema}, } - ofs := output.NewSegmenter(defaultQuoter(slog.Default())) - differ := newDiff(defaultQuoter(slog.Default())) - setOldNewDocs(differ, oldDoc, newDoc) - err := dropConstraintsTable(slog.Default(), ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], ctype) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + ofs := output.NewSegmenter(defaultQuoter(dbs)) + differ := newDiff(NewOperations(dbs).(*Operations), defaultQuoter(dbs)) + setOldNewDocs(dbs, differ, oldDoc, newDoc) + err := dropConstraintsTable(dbs, ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], ctype) if err != nil { t.Fatal(err) } - err = createConstraintsTable(slog.Default(), ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], ctype) + err = createConstraintsTable(dbs, ofs, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0], ctype) if err != nil { t.Fatal(err) } diff --git a/lib/format/pgsql8/diff_functions.go b/lib/format/pgsql8/diff_functions.go index feb204a..f86329d 100644 --- a/lib/format/pgsql8/diff_functions.go +++ b/lib/format/pgsql8/diff_functions.go @@ -1,14 +1,13 @@ package pgsql8 import ( - "log/slog" - + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" ) -func diffFunctions(l *slog.Logger, stage1 output.OutputFileSegmenter, stage3 output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error { +func diffFunctions(dbs *lib.DBSteward, stage1 output.OutputFileSegmenter, stage3 output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error { // drop functions that no longer exist in stage 3 if oldSchema != nil { for _, oldFunction := range oldSchema.Functions { @@ -22,14 +21,14 @@ func diffFunctions(l *slog.Logger, stage1 output.OutputFileSegmenter, stage3 out for _, newFunction := range newSchema.Functions { oldFunction := oldSchema.TryGetFunctionMatching(newFunction) if oldFunction == nil || !oldFunction.Equals(newFunction, ir.SqlFormatPgsql8) { - create, err := getFunctionCreationSql(l, newSchema, newFunction) + create, err := getFunctionCreationSql(dbs, newSchema, newFunction) if err != nil { return nil } stage1.WriteSql(create...) } else if newFunction.ForceRedefine { stage1.WriteSql(sql.NewComment("Function %s.%s has forceRedefine set to true", newSchema.Name, newFunction.Name)) - create, err := getFunctionCreationSql(l, newSchema, newFunction) + create, err := getFunctionCreationSql(dbs, newSchema, newFunction) if err != nil { return nil } @@ -39,7 +38,7 @@ func diffFunctions(l *slog.Logger, stage1 output.OutputFileSegmenter, stage3 out newReturnType := newSchema.TryGetTypeNamed(newFunction.Returns) if oldReturnType != nil && newReturnType != nil && !oldReturnType.Equals(newReturnType) { stage1.WriteSql(sql.NewComment("Function %s.%s return type %s has changed", newSchema.Name, newFunction.Name, newReturnType.Name)) - create, err := getFunctionCreationSql(l, newSchema, newFunction) + create, err := getFunctionCreationSql(dbs, newSchema, newFunction) if err != nil { return nil } diff --git a/lib/format/pgsql8/diff_languages.go b/lib/format/pgsql8/diff_languages.go index 4fac442..4678112 100644 --- a/lib/format/pgsql8/diff_languages.go +++ b/lib/format/pgsql8/diff_languages.go @@ -1,24 +1,22 @@ package pgsql8 import ( - "log/slog" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/output" ) -func diffLanguages(l *slog.Logger, ofs output.OutputFileSegmenter) error { +func diffLanguages(dbs *lib.DBSteward, ofs output.OutputFileSegmenter) error { // TODO(go,pgsql) this is a different flow than old dbsteward: // we do equality comparison inside these two methods, instead of a separate loop // need to validate that this behavior is still correct - dropLanguages(ofs) - return createLanguages(l, ofs) + dropLanguages(dbs, ofs) + return createLanguages(dbs, ofs) } -func dropLanguages(ofs output.OutputFileSegmenter) { - newDoc := lib.GlobalDBSteward.NewDatabase - oldDoc := lib.GlobalDBSteward.OldDatabase +func dropLanguages(dbs *lib.DBSteward, ofs output.OutputFileSegmenter) { + newDoc := dbs.NewDatabase + oldDoc := dbs.OldDatabase // drop languages that either do not exist in the new schema or have changed if oldDoc != nil { @@ -31,15 +29,15 @@ func dropLanguages(ofs output.OutputFileSegmenter) { } } -func createLanguages(l *slog.Logger, ofs output.OutputFileSegmenter) error { - newDoc := lib.GlobalDBSteward.NewDatabase - oldDoc := lib.GlobalDBSteward.OldDatabase +func createLanguages(dbs *lib.DBSteward, ofs output.OutputFileSegmenter) error { + newDoc := dbs.NewDatabase + oldDoc := dbs.OldDatabase // create languages that either do not exist in the old schema or have changed for _, newLang := range newDoc.Languages { oldLang := oldDoc.TryGetLanguageNamed(newLang.Name) if oldLang == nil || !oldLang.Equals(newLang) { - s, err := getCreateLanguageSql(l, newLang) + s, err := getCreateLanguageSql(dbs, newLang) if err != nil { return err } diff --git a/lib/format/pgsql8/diff_sequences.go b/lib/format/pgsql8/diff_sequences.go index 58b279e..4539ef3 100644 --- a/lib/format/pgsql8/diff_sequences.go +++ b/lib/format/pgsql8/diff_sequences.go @@ -1,14 +1,13 @@ package pgsql8 import ( - "log/slog" - + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" ) -func diffSequences(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error { +func diffSequences(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error { // drop old sequences if oldSchema != nil { for _, oldSeq := range oldSchema.Sequences { @@ -21,7 +20,7 @@ func diffSequences(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema *ir for _, newSeq := range newSchema.Sequences { oldSeq := oldSchema.TryGetSequenceNamed(newSeq.Name) if oldSeq == nil { - sql, err := getCreateSequenceSql(l, newSchema, newSeq) + sql, err := getCreateSequenceSql(dbs, newSchema, newSeq) if err != nil { return err } diff --git a/lib/format/pgsql8/diff_tables.go b/lib/format/pgsql8/diff_tables.go index 10d611c..031550c 100644 --- a/lib/format/pgsql8/diff_tables.go +++ b/lib/format/pgsql8/diff_tables.go @@ -16,7 +16,7 @@ import ( // TODO(go,core) lift much of this up to sql99 // applies transformations to tables that exist in both old and new -func diffTables(l *slog.Logger, stage1, stage3 output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) error { +func diffTables(dbs *lib.DBSteward, stage1, stage3 output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) error { // note: old dbsteward called create_tables here, but because we split out DiffTable, we can't call it both places, // so callers were updated to call createTables or CreateTable just before calling DiffTables or DiffTable, respectively @@ -26,11 +26,11 @@ func diffTables(l *slog.Logger, stage1, stage3 output.OutputFileSegmenter, oldSc for _, newTable := range newSchema.Tables { oldTable := oldSchema.TryGetTableNamed(newTable.Name) var err error - oldSchema, oldTable, err = lib.GlobalDBSteward.OldDatabase.NewTableName(oldSchema, oldTable, newSchema, newTable) + oldSchema, oldTable, err = dbs.OldDatabase.NewTableName(oldSchema, oldTable, newSchema, newTable) if err != nil { return err } - err = diffTable(l, stage1, stage3, oldSchema, oldTable, newSchema, newTable) + err = diffTable(dbs, stage1, stage3, oldSchema, oldTable, newSchema, newTable) if err != nil { return errors.Wrapf(err, "while diffing table %s.%s", newSchema.Name, newTable.Name) } @@ -38,17 +38,17 @@ func diffTables(l *slog.Logger, stage1, stage3 output.OutputFileSegmenter, oldSc return nil } -func diffTable(l *slog.Logger, stage1, stage3 output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { +func diffTable(dbs *lib.DBSteward, stage1, stage3 output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { if oldTable == nil || newTable == nil { // create and drop are handled elsewhere return nil } - err := updateTableOptions(l, stage1, oldTable, newSchema, newTable) + err := updateTableOptions(dbs.Logger(), stage1, oldTable, newSchema, newTable) if err != nil { return errors.Wrap(err, "while diffing table options") } - err = updateTableColumns(l, stage1, stage3, oldTable, newSchema, newTable) + err = updateTableColumns(dbs, stage1, stage3, oldTable, newSchema, newTable) if err != nil { return errors.Wrap(err, "while diffing table columns") } @@ -163,22 +163,22 @@ type updateTableColumnsAgg struct { after3 []output.ToSql } -func updateTableColumns(l *slog.Logger, stage1, stage3 output.OutputFileSegmenter, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { +func updateTableColumns(dbs *lib.DBSteward, stage1, stage3 output.OutputFileSegmenter, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { agg := &updateTableColumnsAgg{} // TODO(go,pgsql) old dbsteward interleaved commands into a single list, and output in the same order // meaning that a BEFORE3 could be output before a BEFORE1 in a single-stage upgrade. in this implementation, // _all_ BEFORE1s are printed before BEFORE3s. Double check that this doesn't break anything. - err := addDropTableColumns(agg, oldTable, newTable) + err := addDropTableColumns(dbs, agg, oldTable, newTable) if err != nil { return err } - err = addCreateTableColumns(l, agg, oldTable, newSchema, newTable) + err = addCreateTableColumns(dbs, agg, oldTable, newSchema, newTable) if err != nil { return err } - err = addModifyTableColumns(l, agg, oldTable, newSchema, newTable) + err = addModifyTableColumns(dbs, agg, oldTable, newSchema, newTable) if err != nil { return err } @@ -212,7 +212,7 @@ func updateTableColumns(l *slog.Logger, stage1, stage3 output.OutputFileSegmente return nil } -func addDropTableColumns(agg *updateTableColumnsAgg, oldTable, newTable *ir.Table) error { +func addDropTableColumns(dbs *lib.DBSteward, agg *updateTableColumnsAgg, oldTable, newTable *ir.Table) error { for _, oldColumn := range oldTable.Columns { if newTable.TryGetColumnNamed(oldColumn.Name) != nil { // new column exists, not dropping it @@ -220,7 +220,7 @@ func addDropTableColumns(agg *updateTableColumnsAgg, oldTable, newTable *ir.Tabl } renamedColumn := newTable.TryGetColumnOldNamed(oldColumn.Name) - if !lib.GlobalDBSteward.IgnoreOldNames && renamedColumn != nil { + if !dbs.IgnoreOldNames && renamedColumn != nil { agg.after3 = append(agg.after3, sql.NewComment( "%s DROP COLUMN %s omitted: new column %s indicates it is the replacement for %s", oldTable.Name, oldColumn.Name, renamedColumn.Name, oldColumn.Name, @@ -232,10 +232,10 @@ func addDropTableColumns(agg *updateTableColumnsAgg, oldTable, newTable *ir.Tabl return nil } -func addCreateTableColumns(l *slog.Logger, agg *updateTableColumnsAgg, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { +func addCreateTableColumns(dbs *lib.DBSteward, agg *updateTableColumnsAgg, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { // note that postgres treats identifiers as case-sensitive when quoted // TODO(go,3) find a way to generalize/streamline this - caseSensitive := lib.GlobalDBSteward.QuoteAllNames || lib.GlobalDBSteward.QuoteColumnNames + caseSensitive := dbs.QuoteAllNames || dbs.QuoteColumnNames for _, newColumn := range newTable.Columns { if oldTable.TryGetColumnNamedCase(newColumn.Name, caseSensitive) != nil { @@ -243,7 +243,7 @@ func addCreateTableColumns(l *slog.Logger, agg *updateTableColumnsAgg, oldTable continue } - isRenamed, err := isRenamedColumn(l, oldTable, newTable, newColumn) + isRenamed, err := isRenamedColumn(dbs, oldTable, newTable, newColumn) if err != nil { return errors.Wrapf(err, "while adding new table columns") } @@ -260,7 +260,7 @@ func addCreateTableColumns(l *slog.Logger, agg *updateTableColumnsAgg, oldTable // notice $include_null_definition is false // this is because ADD COLUMNs with NOT NULL will fail when there are existing rows - colDef, err := getFullColumnDefinition(l, lib.GlobalDBSteward.NewDatabase, newSchema, newTable, newColumn, false, true) + colDef, err := getFullColumnDefinition(dbs.Logger(), dbs.NewDatabase, newSchema, newTable, newColumn, false, true) if err != nil { return err } @@ -334,9 +334,7 @@ func addCreateTableColumns(l *slog.Logger, agg *updateTableColumnsAgg, oldTable return nil } -func addModifyTableColumns(l *slog.Logger, agg *updateTableColumnsAgg, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { - dbsteward := lib.GlobalDBSteward - +func addModifyTableColumns(dbsteward *lib.DBSteward, agg *updateTableColumnsAgg, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { // note that postgres treats identifiers as case-sensitive when quoted // TODO(go,3) find a way to generalize/streamline this caseSensitive := dbsteward.QuoteAllNames || dbsteward.QuoteColumnNames @@ -347,7 +345,7 @@ func addModifyTableColumns(l *slog.Logger, agg *updateTableColumnsAgg, oldTable // old table does not contain column, CREATE handled by addCreateTableColumns continue } - isRenamed, err := isRenamedColumn(l, oldTable, newTable, newColumn) + isRenamed, err := isRenamedColumn(dbsteward, oldTable, newTable, newColumn) if err != nil { return errors.Wrapf(err, "while diffing table columns") } @@ -358,11 +356,11 @@ func addModifyTableColumns(l *slog.Logger, agg *updateTableColumnsAgg, oldTable } // TODO(go,pgsql) orig code calls (oldDB, *newSchema*, oldTable, oldColumn) but that seems wrong, need to validate this - oldType, err := getColumnType(l, dbsteward.OldDatabase, newSchema, oldTable, oldColumn) + oldType, err := getColumnType(dbsteward.Logger(), dbsteward.OldDatabase, newSchema, oldTable, oldColumn) if err != nil { return err } - newType, err := getColumnType(l, dbsteward.NewDatabase, newSchema, newTable, newColumn) + newType, err := getColumnType(dbsteward.Logger(), dbsteward.NewDatabase, newSchema, newTable, newColumn) if err != nil { return err } @@ -452,7 +450,8 @@ func checkPartition(oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Sche return errors.Errorf("Changing a parititioned table's name is not supported: %s.%s", oldSchema.Name, oldTable.Name) } // XmlParser has the rest of this knowledge - return GlobalXmlParser.CheckPartitionChange(oldSchema, oldTable, newSchema, newTable) + xmlParser := NewXmlParser(quoter) + return xmlParser.CheckPartitionChange(oldSchema, oldTable, newSchema, newTable) } func checkInherits(oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) error { @@ -489,8 +488,7 @@ func addAlterStatistics(stage1 output.OutputFileSegmenter, oldTable *ir.Table, n return nil } -func isRenamedColumn(l *slog.Logger, oldTable, newTable *ir.Table, newColumn *ir.Column) (bool, error) { - dbsteward := lib.GlobalDBSteward +func isRenamedColumn(dbsteward *lib.DBSteward, oldTable, newTable *ir.Table, newColumn *ir.Column) (bool, error) { if dbsteward.IgnoreOldNames { return false, nil } @@ -525,19 +523,19 @@ func isRenamedColumn(l *slog.Logger, oldTable, newTable *ir.Table, newColumn *ir // newColumn.OldColumnName exists in old schema // newColumn.OldColumnName does not exist in new schema if oldTable.TryGetColumnNamedCase(newColumn.OldColumnName, caseSensitive) != nil && newTable.TryGetColumnNamedCase(newColumn.OldColumnName, caseSensitive) == nil { - l.Info(fmt.Sprintf("Column %s.%s used to be called %s", newTable.Name, newColumn.Name, newColumn.OldColumnName)) + dbsteward.Logger().Info(fmt.Sprintf("Column %s.%s used to be called %s", newTable.Name, newColumn.Name, newColumn.OldColumnName)) return true, nil } return false, nil } -func createTables(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) error { +func createTables(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) error { if newSchema == nil { // if the new schema is nil, there's no tables to create return nil } for _, newTable := range newSchema.Tables { - err := createTable(l, ofs, oldSchema, newSchema, newTable) + err := createTable(dbs, ofs, oldSchema, newSchema, newTable) if err != nil { return err } @@ -545,8 +543,8 @@ func createTables(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema, new return nil } -func createTable(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, newTable *ir.Table) error { - l = l.With( +func createTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema, newTable *ir.Table) error { + l := dbs.Logger().With( slog.String("function", "createTable()"), slog.String("old schema", oldSchema.Name), slog.String("new schema", newSchema.Name), @@ -563,15 +561,15 @@ func createTable(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema, newS return nil } - isRenamed, err := lib.GlobalDBSteward.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) + isRenamed, err := dbs.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) if err != nil { return err } if isRenamed { l.Debug("table renamed") // this is a renamed table, so rename it instead of creating a new one - oldTableSchema := lib.GlobalDBSteward.OldDatabase.GetOldTableSchema(newSchema, newTable) - oldTable := lib.GlobalDBSteward.OldDatabase.GetOldTable(newSchema, newTable) + oldTableSchema := dbs.OldDatabase.GetOldTableSchema(newSchema, newTable) + oldTable := dbs.OldDatabase.GetOldTable(newSchema, newTable) // ALTER TABLE ... RENAME TO does not accept schema qualifiers ... oldRef := sql.TableRef{Schema: oldTableSchema.Name, Table: oldTable.Name} @@ -594,7 +592,7 @@ func createTable(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema, newS } } else { l.Debug("table not renamed") - createTableSQL, err := getCreateTableSql(l, newSchema, newTable) + createTableSQL, err := getCreateTableSql(dbs, newSchema, newTable) if err != nil { return err } @@ -610,23 +608,23 @@ func createTable(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema, newS return nil } -func dropTables(ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) { +func dropTables(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) { // if newSchema is nil, we'll have already dropped all the tables in it if oldSchema != nil && newSchema != nil { for _, oldTable := range oldSchema.Tables { - dropTable(ofs, oldSchema, oldTable, newSchema) + dropTable(dbs, ofs, oldSchema, oldTable, newSchema) } } } -func dropTable(ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema) { +func dropTable(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema) { newTable := newSchema.TryGetTableNamed(oldTable.Name) if newTable != nil { // table exists, nothing to do return } - if !lib.GlobalDBSteward.IgnoreOldNames { - renamedRef := lib.GlobalDBSteward.NewDatabase.TryGetTableFormerlyKnownAs(oldSchema, oldTable) + if !dbs.IgnoreOldNames { + renamedRef := dbs.NewDatabase.TryGetTableFormerlyKnownAs(oldSchema, oldTable) if renamedRef != nil { ofs.WriteSql(sql.NewComment("DROP TABLE %s.%s omitted: new table %s indicates it is her replacement", oldSchema.Name, oldTable.Name, renamedRef)) return @@ -652,24 +650,24 @@ func diffClustersTable(ofs output.OutputFileSegmenter, oldTable *ir.Table, newSc } } -func diffData(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) error { +func diffData(ops *Operations, ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) error { for _, newTable := range newSchema.Tables { - isRenamed, err := lib.GlobalDBSteward.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) + isRenamed, err := ops.dbsteward.OldDatabase.IsRenamedTable(slog.Default(), newSchema, newTable) if err != nil { return fmt.Errorf("while diffing data: %w", err) } if isRenamed { // if the table was renamed, get old definition pointers, diff that - oldSchema := lib.GlobalDBSteward.OldDatabase.GetOldTableSchema(newSchema, newTable) - oldTable := lib.GlobalDBSteward.OldDatabase.GetOldTable(newSchema, newTable) - s, err := getCreateDataSql(l, oldSchema, oldTable, newSchema, newTable) + oldSchema := ops.dbsteward.OldDatabase.GetOldTableSchema(newSchema, newTable) + oldTable := ops.dbsteward.OldDatabase.GetOldTable(newSchema, newTable) + s, err := getCreateDataSql(ops, oldSchema, oldTable, newSchema, newTable) if err != nil { return err } ofs.WriteSql(s...) } else { oldTable := oldSchema.TryGetTableNamed(newTable.Name) - s, err := getCreateDataSql(l, oldSchema, oldTable, newSchema, newTable) + s, err := getCreateDataSql(ops, oldSchema, oldTable, newSchema, newTable) if err != nil { return err } @@ -679,13 +677,13 @@ func diffData(l *slog.Logger, ofs output.OutputFileSegmenter, oldSchema, newSche return nil } -func getCreateDataSql(l *slog.Logger, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) ([]output.ToSql, error) { +func getCreateDataSql(ops *Operations, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) ([]output.ToSql, error) { newRows, updatedRows := getNewAndChangedRows(oldTable, newTable) // cut back on allocations - we know that there's going to be _at least_ one statement for every new and updated row, and likely 1 for the serial start out := make([]output.ToSql, 0, len(newRows)+len(updatedRows)+1) for _, updatedRow := range updatedRows { - update, err := buildDataUpdate(l, newSchema, newTable, updatedRow) + update, err := buildDataUpdate(ops, newSchema, newTable, updatedRow) if err != nil { return nil, err } @@ -693,7 +691,7 @@ func getCreateDataSql(l *slog.Logger, oldSchema *ir.Schema, oldTable *ir.Table, } for _, newRow := range newRows { // TODO(go,3) batch inserts - insert, err := buildDataInsert(l, newSchema, newTable, newRow) + insert, err := buildDataInsert(ops, newSchema, newTable, newRow) if err != nil { return nil, err } @@ -713,12 +711,12 @@ func getCreateDataSql(l *slog.Logger, oldSchema *ir.Schema, oldTable *ir.Table, return out, nil } -func getDeleteDataSql(l *slog.Logger, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) ([]output.ToSql, error) { +func getDeleteDataSql(ops *Operations, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) ([]output.ToSql, error) { oldRows := getOldRows(oldTable, newTable) out := make([]output.ToSql, len(oldRows)) var err error for i, oldRow := range oldRows { - out[i], err = buildDataDelete(l, oldSchema, oldTable, oldRow) + out[i], err = buildDataDelete(ops, oldSchema, oldTable, oldRow) if err != nil { return nil, err } @@ -805,13 +803,13 @@ func getOldRows(oldTable, newTable *ir.Table) []*ir.DataRow { return oldRows } -func buildDataInsert(l *slog.Logger, schema *ir.Schema, table *ir.Table, row *ir.DataRow) (output.ToSql, error) { +func buildDataInsert(ops *Operations, schema *ir.Schema, table *ir.Table, row *ir.DataRow) (output.ToSql, error) { util.Assert(table.Rows != nil, "table.Rows should not be nil when calling buildDataInsert") util.Assert(!row.Delete, "do not call buildDataInsert for a row marked for deletion") values := make([]sql.ToSqlValue, len(row.Columns)) var err error for i, col := range table.Rows.Columns { - values[i], err = columnValueDefault(l, schema, table, col, row.Columns[i]) + values[i], err = ops.columnValueDefault(ops.logger, schema, table, col, row.Columns[i]) if err != nil { return nil, err } @@ -823,7 +821,7 @@ func buildDataInsert(l *slog.Logger, schema *ir.Schema, table *ir.Table, row *ir }, nil } -func buildDataUpdate(l *slog.Logger, schema *ir.Schema, table *ir.Table, change *changedRow) (output.ToSql, error) { +func buildDataUpdate(ops *Operations, schema *ir.Schema, table *ir.Table, change *changedRow) (output.ToSql, error) { // TODO(feat) deal with column renames util.Assert(table.Rows != nil, "table.Rows should not be nil when calling buildDataUpdate") util.Assert(!change.newRow.Delete, "do not call buildDataUpdate for a row marked for deletion") @@ -836,7 +834,7 @@ func buildDataUpdate(l *slog.Logger, schema *ir.Schema, table *ir.Table, change oldColIdx := util.IStrsIndex(change.oldCols, newColName) if oldColIdx < 0 || !change.oldRow.Columns[oldColIdx].Equals(newCol) { updateCols = append(updateCols, newColName) - cvd, err := columnValueDefault(l, schema, table, newColName, newCol) + cvd, err := ops.columnValueDefault(ops.logger, schema, table, newColName, newCol) if err != nil { return nil, err } @@ -848,7 +846,7 @@ func buildDataUpdate(l *slog.Logger, schema *ir.Schema, table *ir.Table, change pkColMap := table.Rows.GetColMapKeys(change.newRow, table.PrimaryKey) for name, col := range pkColMap { // TODO(go,pgsql) orig code in dbx::primary_key_expression uses `format::value_escape`, but that doesn't account for null, empty, sql, etc - cvd, err := columnValueDefault(l, schema, table, name, col) + cvd, err := ops.columnValueDefault(ops.logger, schema, table, name, col) if err != nil { return nil, err } @@ -864,12 +862,12 @@ func buildDataUpdate(l *slog.Logger, schema *ir.Schema, table *ir.Table, change }, nil } -func buildDataDelete(l *slog.Logger, schema *ir.Schema, table *ir.Table, row *ir.DataRow) (output.ToSql, error) { +func buildDataDelete(ops *Operations, schema *ir.Schema, table *ir.Table, row *ir.DataRow) (output.ToSql, error) { keyVals := []sql.ToSqlValue{} pkColMap := table.Rows.GetColMapKeys(row, table.PrimaryKey) for name, col := range pkColMap { // TODO(go,pgsql) orig code in dbx::primary_key_expression uses `format::value_escape`, but that doesn't account for null, empty, sql, etc - val, err := columnValueDefault(l, schema, table, name, col) + val, err := ops.columnValueDefault(ops.logger, schema, table, name, col) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/diff_tables_escape_char_test.go b/lib/format/pgsql8/diff_tables_escape_char_test.go index 95e3ab0..3b0e23a 100644 --- a/lib/format/pgsql8/diff_tables_escape_char_test.go +++ b/lib/format/pgsql8/diff_tables_escape_char_test.go @@ -1,9 +1,9 @@ package pgsql8 import ( - "log/slog" "testing" + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/output" @@ -19,7 +19,7 @@ func TestDiffTables_GetDataSql_EscapeCharacters(t *testing.T) { schema := &ir.Schema{ Name: "public", Tables: []*ir.Table{ - &ir.Table{ + { Name: "i_test", PrimaryKey: []string{"pk"}, Columns: []*ir.Column{ @@ -29,7 +29,7 @@ func TestDiffTables_GetDataSql_EscapeCharacters(t *testing.T) { Rows: &ir.DataRows{ Columns: []string{"pk", "col1"}, Rows: []*ir.DataRow{ - &ir.DataRow{ + { Columns: []*ir.DataCol{ {Text: "1"}, {Text: "hi"}, @@ -40,18 +40,24 @@ func TestDiffTables_GetDataSql_EscapeCharacters(t *testing.T) { }, }, } - - ddl, err := getCreateDataSql(slog.Default(), nil, nil, schema, schema.Tables[0]) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + dbs.NewDatabase = &ir.Definition{ + Schemas: []*ir.Schema{schema}, + } + ops := NewOperations(dbs).(*Operations) + ddl, err := getCreateDataSql(ops, nil, nil, schema, schema.Tables[0]) if err != nil { t.Fatal(err) } assert.Equal(t, []output.ToSql{ &sql.DataInsert{ - Table: sql.TableRef{"public", "i_test"}, + Table: sql.TableRef{Schema: "public", Table: "i_test"}, Columns: []string{"pk", "col1"}, Values: []sql.ToSqlValue{ - &sql.TypedValue{"int", "1", false}, - &sql.TypedValue{"char(10)", "hi", false}, + &sql.TypedValue{Type: "int", Value: "1", IsNull: false}, + &sql.TypedValue{Type: "char(10)", Value: "hi", IsNull: false}, }, }, }, ddl) diff --git a/lib/format/pgsql8/diff_tables_test.go b/lib/format/pgsql8/diff_tables_test.go index 1346d35..12cccd7 100644 --- a/lib/format/pgsql8/diff_tables_test.go +++ b/lib/format/pgsql8/diff_tables_test.go @@ -1,7 +1,6 @@ package pgsql8 import ( - "log/slog" "strings" "testing" @@ -14,7 +13,6 @@ import ( ) func TestDiffTables_DiffTables_ColumnCaseChange(t *testing.T) { - defer resetGlobalDBSteward() lower := &ir.Schema{ Name: "test0", Tables: []*ir.Table{ @@ -51,19 +49,23 @@ func TestDiffTables_DiffTables_ColumnCaseChange(t *testing.T) { }, } - lib.GlobalDBSteward.IgnoreOldNames = false + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + dbs.IgnoreOldNames = false + ops := NewOperations(dbs).(*Operations) // when quoting is off, a change in case is a no-op - lib.GlobalDBSteward.QuoteAllNames = false - lib.GlobalDBSteward.QuoteColumnNames = false - ddl1, ddl3 := diffTablesCommon(t, lower, upperWithoutOldName) + dbs.QuoteAllNames = false + dbs.QuoteColumnNames = false + ddl1, ddl3 := diffTablesCommon(t, ops, lower, upperWithoutOldName) assert.Empty(t, ddl1) assert.Empty(t, ddl3) // when quoting is on, a change in case results in a rename, if there's an oldname - lib.GlobalDBSteward.QuoteAllNames = true - lib.GlobalDBSteward.QuoteColumnNames = true - ddl1, ddl3 = diffTablesCommon(t, lower, upperWithOldName) + dbs.QuoteAllNames = true + dbs.QuoteColumnNames = true + ddl1, ddl3 = diffTablesCommon(t, ops, lower, upperWithOldName) assert.Equal(t, []output.ToSql{ &sql.ColumnRename{ Column: sql.ColumnRef{Schema: "test0", Table: "table", Column: "column"}, @@ -73,7 +75,7 @@ func TestDiffTables_DiffTables_ColumnCaseChange(t *testing.T) { assert.Empty(t, ddl3) // but, if oldColumnName is not given when doing case sensitive renames, it should error - _, _, err := diffTablesCommonErr(lower, upperWithoutOldName) + _, _, err := diffTablesCommonErr(ops, lower, upperWithoutOldName) if assert.Error(t, err) { assert.Contains(t, strings.ToLower(err.Error()), "ambiguous operation") } @@ -96,8 +98,11 @@ func TestDiffTables_DiffTables_TableOptions_NoChange(t *testing.T) { }, }, } - - ddl1, ddl3 := diffTablesCommon(t, schema, schema) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + ops := NewOperations(dbs).(*Operations) + ddl1, ddl3 := diffTablesCommon(t, ops, schema, schema) assert.Empty(t, ddl1) assert.Empty(t, ddl3) } @@ -128,8 +133,11 @@ func TestDiffTables_DiffTables_TableOptions_AddWith(t *testing.T) { }, }, } - - ddl1, ddl3 := diffTablesCommon(t, oldSchema, newSchema) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + ops := NewOperations(dbs).(*Operations) + ddl1, ddl3 := diffTablesCommon(t, ops, oldSchema, newSchema) assert.Equal(t, []output.ToSql{ sql.NewTableAlter( sql.TableRef{"public", "test"}, @@ -176,8 +184,11 @@ func TestDiffTables_DiffTables_TableOptions_AlterWith(t *testing.T) { }, }, } - - ddl1, ddl3 := diffTablesCommon(t, oldSchema, newSchema) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + ops := NewOperations(dbs).(*Operations) + ddl1, ddl3 := diffTablesCommon(t, ops, oldSchema, newSchema) assert.Equal(t, []output.ToSql{ sql.NewTableAlter( sql.TableRef{"public", "test"}, @@ -229,8 +240,11 @@ func TestDiffTables_DiffTables_TableOptions_AddTablespaceAlterWith(t *testing.T) }, }, } - - ddl1, ddl3 := diffTablesCommon(t, oldSchema, newSchema) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + ops := NewOperations(dbs).(*Operations) + ddl1, ddl3 := diffTablesCommon(t, ops, oldSchema, newSchema) assert.Equal(t, []output.ToSql{ &sql.TableMoveTablespaceIndexes{ Table: sql.TableRef{"public", "test"}, @@ -287,8 +301,11 @@ func TestDiffTables_DiffTables_TableOptions_DropTablespace(t *testing.T) { }, }, } - - ddl1, ddl3 := diffTablesCommon(t, oldSchema, newSchema) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + ops := NewOperations(dbs).(*Operations) + ddl1, ddl3 := diffTablesCommon(t, ops, oldSchema, newSchema) assert.Equal(t, []output.ToSql{ &sql.TableResetTablespace{ Table: sql.TableRef{"public", "test"}, @@ -339,12 +356,15 @@ func TestDiffTables_GetDeleteCreateDataSql_AddSerialColumn(t *testing.T) { }, }, } - - delddl, err := getDeleteDataSql(slog.Default(), oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0]) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + ops := NewOperations(dbs).(*Operations) + delddl, err := getDeleteDataSql(ops, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0]) if err != nil { t.Fatal(err) } - addddl, err := getCreateDataSql(slog.Default(), oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0]) + addddl, err := getCreateDataSql(ops, oldSchema, oldSchema.Tables[0], newSchema, newSchema.Tables[0]) if err != nil { t.Fatal(err) } @@ -360,33 +380,33 @@ func TestDiffTables_GetDeleteCreateDataSql_AddSerialColumn(t *testing.T) { }, addddl) } -func diffTablesCommon(t *testing.T, oldSchema, newSchema *ir.Schema) ([]output.ToSql, []output.ToSql) { - ofs1, ofs3, err := diffTablesCommonErr(oldSchema, newSchema) +func diffTablesCommon(t *testing.T, ops *Operations, oldSchema, newSchema *ir.Schema) ([]output.ToSql, []output.ToSql) { + ofs1, ofs3, err := diffTablesCommonErr(ops, oldSchema, newSchema) if err != nil { t.Fatal(err) } return ofs1, ofs3 } -func diffTablesCommonErr(oldSchema, newSchema *ir.Schema) ([]output.ToSql, []output.ToSql, error) { +func diffTablesCommonErr(ops *Operations, oldSchema, newSchema *ir.Schema) ([]output.ToSql, []output.ToSql, error) { oldDoc := &ir.Definition{ Schemas: []*ir.Schema{oldSchema}, } newDoc := &ir.Definition{ Schemas: []*ir.Schema{newSchema}, } - differ := newDiff(defaultQuoter(slog.Default())) - setOldNewDocs(differ, oldDoc, newDoc) - ofs1 := output.NewAnnotationStrippingSegmenter(defaultQuoter(slog.Default())) - ofs3 := output.NewAnnotationStrippingSegmenter(defaultQuoter(slog.Default())) + differ := newDiff(ops, defaultQuoter(ops.dbsteward)) + setOldNewDocs(ops.dbsteward, differ, oldDoc, newDoc) + ofs1 := output.NewAnnotationStrippingSegmenter(defaultQuoter(ops.dbsteward)) + ofs3 := output.NewAnnotationStrippingSegmenter(defaultQuoter(ops.dbsteward)) // note: v1 only used DiffTables, v2 split into CreateTables+DiffTables - err := createTables(slog.Default(), ofs1, oldSchema, newSchema) + err := createTables(ops.dbsteward, ofs1, oldSchema, newSchema) if err != nil { return ofs1.Body, ofs3.Body, err } - err = diffTables(slog.Default(), ofs1, ofs3, oldSchema, newSchema) + err = diffTables(ops.dbsteward, ofs1, ofs3, oldSchema, newSchema) if err != nil { return ofs1.Body, ofs3.Body, err } diff --git a/lib/format/pgsql8/diff_types.go b/lib/format/pgsql8/diff_types.go index 6815646..144d165 100644 --- a/lib/format/pgsql8/diff_types.go +++ b/lib/format/pgsql8/diff_types.go @@ -2,15 +2,15 @@ package pgsql8 import ( "fmt" - "log/slog" "strings" + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" ) -func diffTypes(l *slog.Logger, differ *diff, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error { +func diffTypes(dbs *lib.DBSteward, differ *diff, ofs output.OutputFileSegmenter, oldSchema *ir.Schema, newSchema *ir.Schema) error { dropTypes(ofs, oldSchema, newSchema) err := createTypes(ofs, oldSchema, newSchema) if err != nil { @@ -41,7 +41,7 @@ func diffTypes(l *slog.Logger, differ *diff, ofs output.OutputFileSegmenter, old ofs.WriteSql(getFunctionDropSql(oldSchema, oldFunc)...) } - columns, sql, err := alterColumnTypePlaceholder(l, differ, oldType) + columns, sql, err := alterColumnTypePlaceholder(dbs, differ, oldType) if err != nil { return err } @@ -63,7 +63,7 @@ func diffTypes(l *slog.Logger, differ *diff, ofs output.OutputFileSegmenter, old // functions are only recreated if they changed elsewise, so need to create them here for _, newFunc := range GlobalSchema.GetFunctionsDependingOnType(newSchema, newType) { - s, err := getFunctionCreationSql(l, newSchema, newFunc) + s, err := getFunctionCreationSql(dbs, newSchema, newFunc) if err != nil { return err } diff --git a/lib/format/pgsql8/diff_types_domains_test.go b/lib/format/pgsql8/diff_types_domains_test.go index 0a5eed2..65373ce 100644 --- a/lib/format/pgsql8/diff_types_domains_test.go +++ b/lib/format/pgsql8/diff_types_domains_test.go @@ -1,11 +1,11 @@ package pgsql8 import ( - "log/slog" "testing" "github.com/stretchr/testify/assert" + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" @@ -323,10 +323,14 @@ func diffTypesForTest(t *testing.T, oldSchema, newSchema *ir.Schema) []output.To newDoc := &ir.Definition{ Schemas: []*ir.Schema{newSchema}, } - differ := newDiff(defaultQuoter(slog.Default())) - setOldNewDocs(differ, oldDoc, newDoc) - ofs := output.NewAnnotationStrippingSegmenter(defaultQuoter(slog.Default())) - err := diffTypes(slog.Default(), differ, ofs, oldSchema, newSchema) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + ops := NewOperations(dbs).(*Operations) + differ := newDiff(ops, defaultQuoter(dbs)) + setOldNewDocs(dbs, differ, oldDoc, newDoc) + ofs := output.NewAnnotationStrippingSegmenter(defaultQuoter(dbs)) + err := diffTypes(dbs, differ, ofs, oldSchema, newSchema) if err != nil { t.Fatal(err) } diff --git a/lib/format/pgsql8/diff_types_test.go b/lib/format/pgsql8/diff_types_test.go index e4750c0..1214302 100644 --- a/lib/format/pgsql8/diff_types_test.go +++ b/lib/format/pgsql8/diff_types_test.go @@ -1,9 +1,9 @@ package pgsql8 import ( - "log/slog" "testing" + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" @@ -94,9 +94,12 @@ func TestDiffTypes_DiffTypes_RecreateDependentFunctions(t *testing.T) { }, } - ofs := output.NewAnnotationStrippingSegmenter(defaultQuoter(slog.Default())) - - err := diffTypes(slog.Default(), newDiff(defaultQuoter(slog.Default())), ofs, oldSchema, newSchema) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + ops := NewOperations(dbs).(*Operations) + ofs := output.NewAnnotationStrippingSegmenter(defaultQuoter(dbs)) + err := diffTypes(dbs, newDiff(ops, defaultQuoter(dbs)), ofs, oldSchema, newSchema) if err != nil { t.Fatal(err) } diff --git a/lib/format/pgsql8/diff_views.go b/lib/format/pgsql8/diff_views.go index 5d8cca6..508d2a5 100644 --- a/lib/format/pgsql8/diff_views.go +++ b/lib/format/pgsql8/diff_views.go @@ -10,7 +10,8 @@ import ( // TODO(go,core) lift some of these to sql99 -func createViewsOrdered(l *slog.Logger, ofs output.OutputFileSegmenter, oldDoc *ir.Definition, newDoc *ir.Definition) error { +func createViewsOrdered(dbs *lib.DBSteward, ofs output.OutputFileSegmenter, oldDoc *ir.Definition, newDoc *ir.Definition) error { + l := dbs.Logger() return forEachViewInDepOrder(newDoc, func(newRef ir.ViewRef) error { ll := l.With(slog.String("view", newRef.String())) ll.Debug("consider creating") @@ -23,11 +24,11 @@ func createViewsOrdered(l *slog.Logger, ofs output.OutputFileSegmenter, oldDoc * if oldView != nil { ll = ll.With(slog.String("old view", oldView.Name)) } - if shouldCreateView(oldView, newRef.View) { + if shouldCreateView(dbs, oldView, newRef.View) { ll.Debug("shouldCreateView returned true") - s, err := getCreateViewSql(l, newRef.Schema, newRef.View) + s, err := getCreateViewSql(dbs, newRef.Schema, newRef.View) for _, s1 := range s { - ll.Debug(s1.ToSql(defaultQuoter(ll))) + ll.Debug(s1.ToSql(defaultQuoter(dbs))) } if err != nil { return err @@ -40,8 +41,8 @@ func createViewsOrdered(l *slog.Logger, ofs output.OutputFileSegmenter, oldDoc * }) } -func shouldCreateView(oldView, newView *ir.View) bool { - return oldView == nil || lib.GlobalDBSteward.AlwaysRecreateViews || !oldView.Equals(newView, ir.SqlFormatPgsql8) +func shouldCreateView(dbs *lib.DBSteward, oldView, newView *ir.View) bool { + return oldView == nil || dbs.AlwaysRecreateViews || !oldView.Equals(newView, ir.SqlFormatPgsql8) } func dropViewsOrdered(ofs output.OutputFileSegmenter, oldDoc *ir.Definition, newDoc *ir.Definition) error { diff --git a/lib/format/pgsql8/diff_views_test.go b/lib/format/pgsql8/diff_views_test.go index 11f31e3..6b2c7ef 100644 --- a/lib/format/pgsql8/diff_views_test.go +++ b/lib/format/pgsql8/diff_views_test.go @@ -1,9 +1,9 @@ package pgsql8 import ( - "log/slog" "testing" + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" @@ -53,9 +53,12 @@ func newSingleView() *ir.Definition { } func TestCreateViewsOrdered(t *testing.T) { - q := defaultQuoter(slog.Default()) + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + q := defaultQuoter(dbs) ofs := output.NewAnnotationStrippingSegmenter(q) - err := createViewsOrdered(slog.Default(), ofs, oldSingleView(), newSingleView()) + err := createViewsOrdered(dbs, ofs, oldSingleView(), newSingleView()) if err != nil { t.Fatal(err) } diff --git a/lib/format/pgsql8/function.go b/lib/format/pgsql8/function.go index e3dfa9b..d0dac8c 100644 --- a/lib/format/pgsql8/function.go +++ b/lib/format/pgsql8/function.go @@ -2,7 +2,6 @@ package pgsql8 import ( "fmt" - "log/slog" "strings" "github.com/dbsteward/dbsteward/lib" @@ -35,7 +34,7 @@ func functionDefinitionReferencesTable(definition *ir.FunctionDefinition) *lib.Q return &parsed } -func getFunctionCreationSql(l *slog.Logger, schema *ir.Schema, function *ir.Function) ([]output.ToSql, error) { +func getFunctionCreationSql(dbs *lib.DBSteward, schema *ir.Schema, function *ir.Function) ([]output.ToSql, error) { ref := sql.FunctionRef{Schema: schema.Name, Function: function.Name, Params: function.ParamSigs()} def := function.TryGetDefinition(ir.SqlFormatPgsql8) out := []output.ToSql{ @@ -50,7 +49,7 @@ func getFunctionCreationSql(l *slog.Logger, schema *ir.Schema, function *ir.Func } if function.Owner != "" { - role, err := roleEnum(l, lib.GlobalDBSteward.NewDatabase, function.Owner) + role, err := roleEnum(dbs.Logger(), dbs.NewDatabase, function.Owner, dbs.IgnoreCustomRoles) if err != nil { return nil, err } @@ -94,11 +93,11 @@ func normalizeFunctionParameterType(paramType string) string { return paramType } -func getFunctionGrantSql(l *slog.Logger, schema *ir.Schema, fn *ir.Function, grant *ir.Grant) ([]output.ToSql, error) { +func getFunctionGrantSql(dbs *lib.DBSteward, schema *ir.Schema, fn *ir.Function, grant *ir.Grant) ([]output.ToSql, error) { roles := make([]string, len(grant.Roles)) var err error for i, role := range grant.Roles { - roles[i], err = roleEnum(l, lib.GlobalDBSteward.NewDatabase, role) + roles[i], err = roleEnum(dbs.Logger(), dbs.NewDatabase, role, dbs.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/language.go b/lib/format/pgsql8/language.go index 57039c9..fe2ff4b 100644 --- a/lib/format/pgsql8/language.go +++ b/lib/format/pgsql8/language.go @@ -1,15 +1,13 @@ package pgsql8 import ( - "log/slog" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" ) -func getCreateLanguageSql(l *slog.Logger, lang *ir.Language) ([]output.ToSql, error) { +func getCreateLanguageSql(dbsteward *lib.DBSteward, lang *ir.Language) ([]output.ToSql, error) { out := []output.ToSql{ &sql.LanguageCreate{ Language: lang.Name, @@ -21,7 +19,7 @@ func getCreateLanguageSql(l *slog.Logger, lang *ir.Language) ([]output.ToSql, er } if lang.Owner != "" { - role, err := roleEnum(l, lib.GlobalDBSteward.NewDatabase, lang.Owner) + role, err := roleEnum(dbsteward.Logger(), dbsteward.NewDatabase, lang.Owner, dbsteward.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/oneeighty_test.go b/lib/format/pgsql8/oneeighty_test.go index cc92510..3ea073c 100644 --- a/lib/format/pgsql8/oneeighty_test.go +++ b/lib/format/pgsql8/oneeighty_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/dbsteward/dbsteward/lib" - "github.com/dbsteward/dbsteward/lib/format" "github.com/dbsteward/dbsteward/lib/ir" "github.com/stretchr/testify/assert" ) @@ -29,11 +28,11 @@ func TestOneEighty(t *testing.T) { } defer Teardowndb(t, c, "pg") role := os.Getenv("DB_USER") - lib.GlobalDBSteward = lib.NewDBSteward(format.LookupMap{ + dbs := lib.NewDBSteward(lib.LookupMap{ ir.SqlFormatPgsql8: GlobalLookup, }) - lib.GlobalDBSteward.SqlFormat = ir.SqlFormatPgsql8 - ops := NewOperations().(*Operations) + dbs.SqlFormat = ir.SqlFormatPgsql8 + ops := NewOperations(dbs).(*Operations) statements, err := ops.CreateStatements(ir.FullFeatureSchema(role)) if err != nil { t.Fatal(err) diff --git a/lib/format/pgsql8/operations.go b/lib/format/pgsql8/operations.go index 123bd46..9f25a52 100644 --- a/lib/format/pgsql8/operations.go +++ b/lib/format/pgsql8/operations.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/dbsteward/dbsteward/lib/format" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/format/sql99" "github.com/dbsteward/dbsteward/lib/output" @@ -23,17 +22,16 @@ import ( ) type Operations struct { - *sql99.Operations - logger *slog.Logger - differ *diff + logger *slog.Logger + dbsteward *lib.DBSteward + differ *diff } var quoter output.Quoter -func defaultQuoter(logger *slog.Logger) output.Quoter { - dbsteward := lib.GlobalDBSteward +func defaultQuoter(dbsteward *lib.DBSteward) output.Quoter { return &sql.Quoter{ - Logger: logger, + Logger: dbsteward.Logger(), ShouldQuoteSchemaNames: dbsteward.QuoteAllNames || dbsteward.QuoteSchemaNames, ShouldQuoteTableNames: dbsteward.QuoteAllNames || dbsteward.QuoteTableNames, ShouldQuoteColumnNames: dbsteward.QuoteAllNames || dbsteward.QuoteColumnNames, @@ -45,15 +43,14 @@ func defaultQuoter(logger *slog.Logger) output.Quoter { } } -func NewOperations() format.Operations { - quoter = defaultQuoter(lib.GlobalDBSteward.Logger()) - pgsql := &Operations{ - Operations: sql99.NewOperations(), - logger: lib.GlobalDBSteward.Logger(), - differ: newDiff(quoter), +func NewOperations(dbs *lib.DBSteward) lib.Operations { + quoter = defaultQuoter(dbs) + ops := &Operations{ + logger: dbs.Logger(), + dbsteward: dbs, } - pgsql.Operations.Operations = pgsql - return pgsql + ops.differ = newDiff(ops, quoter) + return ops } func (ops *Operations) GetQuoter() output.Quoter { @@ -71,8 +68,6 @@ func (ops *Operations) CreateStatements(def ir.Definition) ([]output.DDLStatemen } func (ops *Operations) Build(outputPrefix string, dbDoc *ir.Definition) error { - dbsteward := lib.GlobalDBSteward - buildFileName := outputPrefix + "_build.sql" ops.logger.Info(fmt.Sprintf("Building complete file %s", buildFileName)) @@ -81,7 +76,7 @@ func (ops *Operations) Build(outputPrefix string, dbDoc *ir.Definition) error { return fmt.Errorf("failed to open file %s for output: %w", buildFileName, err) } - buildFileOfs := output.NewOutputFileSegmenterToFile(ops.logger, ops.GetQuoter(), buildFileName, 1, buildFile, buildFileName, dbsteward.OutputFileStatementLimit) + buildFileOfs := output.NewOutputFileSegmenterToFile(ops.logger, ops.GetQuoter(), buildFileName, 1, buildFile, buildFileName, ops.dbsteward.OutputFileStatementLimit) err = ops.build(buildFileOfs, dbDoc) if err != nil { return err @@ -91,13 +86,11 @@ func (ops *Operations) Build(outputPrefix string, dbDoc *ir.Definition) error { func (ops *Operations) build(buildFileOfs output.OutputFileSegmenter, dbDoc *ir.Definition) error { // TODO(go,4) can we just consider a build(def) to be diff(null, def)? - // some shortcuts, since we're going to be typing a lot here - dbsteward := lib.GlobalDBSteward - if len(dbsteward.LimitToTables) == 0 { + if len(ops.dbsteward.LimitToTables) == 0 { buildFileOfs.WriteSql(sql.NewComment("full database definition file generated %s\n", time.Now().Format(time.RFC1123Z))) } - if !dbsteward.GenerateSlonik { + if !ops.dbsteward.GenerateSlonik { buildFileOfs.WriteSql(output.NewRawSQL("BEGIN;\n\n")) } @@ -108,12 +101,12 @@ func (ops *Operations) build(buildFileOfs output.OutputFileSegmenter, dbDoc *ir. } // database-specific implementation code refers to dbsteward::$new_database when looking up roles/values/conflicts etc - dbsteward.NewDatabase = dbDoc + ops.dbsteward.NewDatabase = dbDoc // language definitions - if dbsteward.CreateLanguages { + if ops.dbsteward.CreateLanguages { for _, language := range dbDoc.Languages { - s, err := getCreateLanguageSql(ops.logger, language) + s, err := getCreateLanguageSql(ops.dbsteward, language) if err != nil { return err } @@ -166,23 +159,23 @@ outer: ops.logger.Info(setCheckFunctionBodiesInfo) } - if dbsteward.OnlySchemaSql || !dbsteward.OnlyDataSql { + if ops.dbsteward.OnlySchemaSql || !ops.dbsteward.OnlyDataSql { ops.logger.Info("Defining structure") err := ops.buildSchema(dbDoc, buildFileOfs, tableDependency) if err != nil { return err } } - if !dbsteward.OnlySchemaSql || dbsteward.OnlyDataSql { + if !ops.dbsteward.OnlySchemaSql || ops.dbsteward.OnlyDataSql { ops.logger.Info("Defining data inserts") - err = buildData(ops.logger, dbDoc, buildFileOfs, tableDependency) + err = ops.buildData(ops.logger, dbDoc, buildFileOfs, tableDependency) if err != nil { return err } } - dbsteward.NewDatabase = nil + ops.dbsteward.NewDatabase = nil - if !dbsteward.GenerateSlonik { + if !ops.dbsteward.GenerateSlonik { buildFileOfs.WriteSql(output.NewRawSQL("COMMIT;\n\n")) } @@ -230,8 +223,8 @@ func (ops *Operations) Upgrade(l *slog.Logger, oldDoc *ir.Definition, newDoc *ir if err != nil { return nil, fmt.Errorf("new document: %w", err) } - lib.GlobalDBSteward.OldDatabase = oldDoc - lib.GlobalDBSteward.NewDatabase = newDoc + ops.dbsteward.OldDatabase = oldDoc + ops.dbsteward.NewDatabase = newDoc stage1 := output.NewSegmenter(ops.GetQuoter()) stage2 := output.NewSegmenter(ops.GetQuoter()) @@ -964,7 +957,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // TODO(go,3) roll this into diffing nil -> doc // schema creation for _, schema := range doc.Schemas { - s, err := GlobalSchema.GetCreationSql(schema) + s, err := GlobalSchema.GetCreationSql(ops.dbsteward, schema) if err != nil { return err } @@ -972,7 +965,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // schema grants for _, grant := range schema.Grants { - s, err := GlobalSchema.GetGrantSql(doc, schema, grant) + s, err := GlobalSchema.GetGrantSql(ops.dbsteward, doc, schema, grant) if err != nil { return err } @@ -997,7 +990,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm includeColumnDefaultNextvalInCreateSql = false for _, table := range schema.Tables { // table definition - s, err := getCreateTableSql(ops.logger, schema, table) + s, err := getCreateTableSql(ops.dbsteward, schema, table) if err != nil { return err } @@ -1011,7 +1004,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // table grants for _, grant := range table.Grants { - s, err := getTableGrantSql(ops.logger, schema, table, grant) + s, err := getTableGrantSql(ops.dbsteward, schema, table, grant) if err != nil { return err } @@ -1023,7 +1016,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // sequences contained in the schema for _, sequence := range schema.Sequences { if sequence.OwnedByColumn == "" { - sql, err := getCreateSequenceSql(ops.logger, schema, sequence) + sql, err := getCreateSequenceSql(ops.dbsteward, schema, sequence) if err != nil { return err } @@ -1036,7 +1029,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // sequence permission grants for _, grant := range sequence.Grants { - s, err := getSequenceGrantSql(ops.logger, schema, sequence, grant) + s, err := getSequenceGrantSql(ops.dbsteward, schema, sequence, grant) if err != nil { return err } @@ -1056,7 +1049,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm for _, schema := range doc.Schemas { for _, function := range schema.Functions { if function.HasDefinition(ir.SqlFormatPgsql8) { - s, err := getFunctionCreationSql(ops.logger, schema, function) + s, err := getFunctionCreationSql(ops.dbsteward, schema, function) if err != nil { return err } @@ -1065,7 +1058,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // they are not included in pg_function::get_creation_sql() for _, grant := range function.Grants { - grant, err := getFunctionGrantSql(ops.logger, schema, function, grant) + grant, err := getFunctionGrantSql(ops.dbsteward, schema, function, grant) if err != nil { return err } @@ -1086,7 +1079,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // define table primary keys before foreign keys so unique requirements are always met for FOREIGN KEY constraints for _, schema := range doc.Schemas { for _, table := range schema.Tables { - err := createConstraintsTable(ops.logger, ofs, nil, nil, schema, table, sql99.ConstraintTypePrimaryKey) + err := createConstraintsTable(ops.dbsteward, ofs, nil, nil, schema, table, sql99.ConstraintTypePrimaryKey) if err != nil { return err } @@ -1097,7 +1090,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm // use the dependency order to specify foreign keys in an order that will satisfy nested foreign keys and etc // TODO(feat) shouldn't this consider GlobalDBSteward.LimitToTables like BuildData does? for _, entry := range tableDep { - err := createConstraintsTable(ops.logger, ofs, nil, nil, entry.Schema, entry.Table, sql99.ConstraintTypeConstraint) + err := createConstraintsTable(ops.dbsteward, ofs, nil, nil, entry.Schema, entry.Table, sql99.ConstraintTypeConstraint) if err != nil { return err } @@ -1116,7 +1109,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm } } - err := createViewsOrdered(ops.logger, ofs, nil, doc) + err := createViewsOrdered(ops.dbsteward, ofs, nil, doc) if err != nil { return err } @@ -1125,7 +1118,7 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm for _, schema := range doc.Schemas { for _, view := range schema.Views { for _, grant := range view.Grants { - s, err := getViewGrantSql(ops.logger, doc, schema, view, grant) + s, err := getViewGrantSql(ops.dbsteward, doc, schema, view, grant) if err != nil { return err } @@ -1138,8 +1131,8 @@ func (ops *Operations) buildSchema(doc *ir.Definition, ofs output.OutputFileSegm return nil } -func buildData(l *slog.Logger, doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep []*ir.TableRef) error { - limitToTables := lib.GlobalDBSteward.LimitToTables +func (ops *Operations) buildData(_ *slog.Logger, doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep []*ir.TableRef) error { + limitToTables := ops.dbsteward.LimitToTables // use the dependency order to then write out the actual data inserts into the data sql file for _, entry := range tableDep { @@ -1157,7 +1150,7 @@ func buildData(l *slog.Logger, doc *ir.Definition, ofs output.OutputFileSegmente continue } } - s, err := getCreateDataSql(l, nil, nil, schema, table) + s, err := getCreateDataSql(ops, nil, nil, schema, table) if err != nil { return err } @@ -1211,7 +1204,7 @@ func buildData(l *slog.Logger, doc *ir.Definition, ofs output.OutputFileSegmente return nil } -func columnValueDefault(l *slog.Logger, schema *ir.Schema, table *ir.Table, columnName string, dataCol *ir.DataCol) (sql.ToSqlValue, error) { +func (ops *Operations) columnValueDefault(l *slog.Logger, schema *ir.Schema, table *ir.Table, columnName string, dataCol *ir.DataCol) (sql.ToSqlValue, error) { // if the column represents NULL, return a NULL value if dataCol.Null { return sql.ValueNull, nil @@ -1229,7 +1222,7 @@ func columnValueDefault(l *slog.Logger, schema *ir.Schema, table *ir.Table, colu } } - col, err := lib.GlobalDBSteward.NewDatabase.TryInheritanceGetColumn(schema, table, columnName) + col, err := ops.dbsteward.NewDatabase.TryInheritanceGetColumn(schema, table, columnName) if err != nil { return nil, fmt.Errorf("TryInheritanceGetColumn %w", err) } @@ -1251,7 +1244,7 @@ func columnValueDefault(l *slog.Logger, schema *ir.Schema, table *ir.Table, colu return sql.RawSql(col.Default), nil } - colType, err := getColumnType(l, lib.GlobalDBSteward.NewDatabase, schema, table, col) + colType, err := getColumnType(l, ops.dbsteward.NewDatabase, schema, table, col) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/operations_column_value_default_test.go b/lib/format/pgsql8/operations_column_value_default_test.go index f9091d5..93ce4b6 100644 --- a/lib/format/pgsql8/operations_column_value_default_test.go +++ b/lib/format/pgsql8/operations_column_value_default_test.go @@ -92,7 +92,6 @@ func TestOperations_ColumnValueDefault_UsesLiteralForInt(t *testing.T) { } func TestOperations_ColumnValueDefaultQuotesStrings(t *testing.T) { - defer resetGlobalDBSteward() val, err := getColumnValueDefault(&ir.Column{ Name: "foo", Type: "text", @@ -128,14 +127,17 @@ func getColumnValueDefault(def *ir.Column, data *ir.DataCol) (string, error) { }, }, } - lib.GlobalDBSteward.NewDatabase = doc + dbs := lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }) + dbs.NewDatabase = doc schema := doc.Schemas[0] table := schema.Tables[0] - ops := NewOperations().(*Operations) + ops := NewOperations(dbs).(*Operations) // TODO(go,nth) can we do this without also testing GetValueSql? - toVal, err := columnValueDefault(slog.Default(), schema, table, def.Name, data) + toVal, err := ops.columnValueDefault(slog.Default(), schema, table, def.Name, data) if err != nil { return "", err } diff --git a/lib/format/pgsql8/operations_extract_schema_test.go b/lib/format/pgsql8/operations_extract_schema_test.go index 8a76447..9d1bc35 100644 --- a/lib/format/pgsql8/operations_extract_schema_test.go +++ b/lib/format/pgsql8/operations_extract_schema_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/util" "github.com/jackc/pgtype" @@ -106,7 +107,9 @@ func TestOperations_ExtractSchema_Indexes(t *testing.T) { }, }, } - ops := NewOperations().(*Operations) + ops := NewOperations(lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + })).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -162,7 +165,9 @@ func TestOperations_ExtractSchema_CompoundUniqueConstraint(t *testing.T) { }, }, } - ops := NewOperations().(*Operations) + ops := NewOperations(lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + })).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -216,7 +221,9 @@ func TestOperations_ExtractSchema_TableComments(t *testing.T) { }, }, } - ops := NewOperations().(*Operations) + ops := NewOperations(lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + })).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -258,7 +265,9 @@ END; }, }, } - ops := NewOperations().(*Operations) + ops := NewOperations(lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + })).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -342,7 +351,9 @@ func TestOperations_ExtractSchema_FunctionArgs(t *testing.T) { }, }, } - ops := NewOperations().(*Operations) + ops := NewOperations(lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + })).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -385,7 +396,9 @@ func TestOperations_ExtractSchema_TableArrayType(t *testing.T) { }, }, } - ops := NewOperations().(*Operations) + ops := NewOperations(lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + })).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -445,7 +458,9 @@ func TestOperations_ExtractSchema_FKReferentialConstraints(t *testing.T) { }, }, } - ops := NewOperations().(*Operations) + ops := NewOperations(lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + })).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) @@ -504,7 +519,9 @@ func TestOperations_ExtractSchema_Sequences(t *testing.T) { {Schema: "public", Table: "user", Name: "user_pkey", Type: "p", Columns: []string{"user_id"}}, }, } - ops := NewOperations().(*Operations) + ops := NewOperations(lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + })).(*Operations) actual, err := ops.pgToIR(pgDoc) if err != nil { t.Fatalf("Conversion failed: %+v", err) diff --git a/lib/format/pgsql8/pgsql8.go b/lib/format/pgsql8/pgsql8.go index 8d89d6e..ed0a6d4 100644 --- a/lib/format/pgsql8/pgsql8.go +++ b/lib/format/pgsql8/pgsql8.go @@ -1,11 +1,10 @@ package pgsql8 -import "github.com/dbsteward/dbsteward/lib/format" +import "github.com/dbsteward/dbsteward/lib" var GlobalSchema = NewSchema() -var GlobalXmlParser = NewXmlParser() -var GlobalLookup = &format.Lookup{ +var GlobalLookup = &lib.Lookup{ Schema: GlobalSchema, OperationsConstructor: NewOperations, } diff --git a/lib/format/pgsql8/pgsql8_main_test.go b/lib/format/pgsql8/pgsql8_main_test.go index c582789..e5e20de 100644 --- a/lib/format/pgsql8/pgsql8_main_test.go +++ b/lib/format/pgsql8/pgsql8_main_test.go @@ -1,22 +1,13 @@ package pgsql8 import ( - "os" - "testing" - "github.com/dbsteward/dbsteward/lib" - "github.com/dbsteward/dbsteward/lib/format" "github.com/dbsteward/dbsteward/lib/ir" ) -func TestMain(m *testing.M) { - resetGlobalDBSteward() - os.Exit(m.Run()) -} - -func setOldNewDocs(differ *diff, old, new *ir.Definition) { - lib.GlobalDBSteward.OldDatabase = old - lib.GlobalDBSteward.NewDatabase = new +func setOldNewDocs(dbs *lib.DBSteward, differ *diff, old, new *ir.Definition) { + dbs.OldDatabase = old + dbs.NewDatabase = new var err error if old != nil { differ.OldTableDependency, err = old.TableDependencyOrder() @@ -31,10 +22,3 @@ func setOldNewDocs(differ *diff, old, new *ir.Definition) { } } } - -func resetGlobalDBSteward() { - lib.GlobalDBSteward = lib.NewDBSteward(format.LookupMap{ - ir.SqlFormatPgsql8: GlobalLookup, - }) - lib.GlobalDBSteward.SqlFormat = ir.SqlFormatPgsql8 -} diff --git a/lib/format/pgsql8/role.go b/lib/format/pgsql8/role.go index 1582763..d9821f9 100644 --- a/lib/format/pgsql8/role.go +++ b/lib/format/pgsql8/role.go @@ -5,7 +5,6 @@ import ( "log/slog" "strings" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/util" ) @@ -113,7 +112,7 @@ func (ri *roleIndex) get(r string) string { return r } -func roleEnum(l *slog.Logger, doc *ir.Definition, role string) (string, error) { +func roleEnum(l *slog.Logger, doc *ir.Definition, role string, ignoreCustomRoles bool) (string, error) { if doc.Database == nil { // TODO(go,nth) somehow was incompletely constructed doc.Database = &ir.Database{ @@ -146,7 +145,7 @@ func roleEnum(l *slog.Logger, doc *ir.Definition, role string) (string, error) { return role, nil } - if !lib.GlobalDBSteward.IgnoreCustomRoles { + if !ignoreCustomRoles { l.Error(fmt.Sprintf("'%s' not in %+v", role, roles)) return "", fmt.Errorf("failed to confirm custom role: %s", role) } diff --git a/lib/format/pgsql8/schema.go b/lib/format/pgsql8/schema.go index 70de102..5490e0c 100644 --- a/lib/format/pgsql8/schema.go +++ b/lib/format/pgsql8/schema.go @@ -2,7 +2,6 @@ package pgsql8 import ( "fmt" - "log/slog" "strings" "github.com/dbsteward/dbsteward/lib/ir" @@ -20,12 +19,7 @@ func NewSchema() *Schema { return &Schema{} } -func (sc *Schema) logger() *slog.Logger { - // Hack until I can get proper constuctor ordering - return lib.GlobalDBSteward.Logger() -} - -func (s *Schema) GetCreationSql(schema *ir.Schema) ([]output.ToSql, error) { +func (s *Schema) GetCreationSql(dbs *lib.DBSteward, schema *ir.Schema) ([]output.ToSql, error) { // don't create the public schema if strings.EqualFold(schema.Name, "public") { return nil, nil @@ -36,7 +30,7 @@ func (s *Schema) GetCreationSql(schema *ir.Schema) ([]output.ToSql, error) { } if schema.Owner != "" { - owner, err := roleEnum(s.logger(), lib.GlobalDBSteward.NewDatabase, schema.Owner) + owner, err := roleEnum(dbs.Logger(), dbs.NewDatabase, schema.Owner, dbs.IgnoreCustomRoles) if err != nil { return nil, err } @@ -59,11 +53,11 @@ func (s *Schema) GetDropSql(schema *ir.Schema) []output.ToSql { } } -func (s *Schema) GetGrantSql(doc *ir.Definition, schema *ir.Schema, grant *ir.Grant) ([]output.ToSql, error) { +func (s *Schema) GetGrantSql(dbs *lib.DBSteward, doc *ir.Definition, schema *ir.Schema, grant *ir.Grant) ([]output.ToSql, error) { roles := make([]string, len(grant.Roles)) var err error for i, role := range grant.Roles { - roles[i], err = roleEnum(s.logger(), lib.GlobalDBSteward.NewDatabase, role) + roles[i], err = roleEnum(dbs.Logger(), dbs.NewDatabase, role, dbs.IgnoreCustomRoles) if err != nil { return nil, err } @@ -87,7 +81,7 @@ func (s *Schema) GetGrantSql(doc *ir.Definition, schema *ir.Schema, grant *ir.Gr // SCHEMA IMPLICIT GRANTS // READYONLY USER PROVISION: grant usage on the schema for the readonly user // TODO(go,3) move this out of here, let this create just a single grant - roRole, err := roleEnum(s.logger(), lib.GlobalDBSteward.NewDatabase, ir.RoleReadOnly) + roRole, err := roleEnum(dbs.Logger(), dbs.NewDatabase, ir.RoleReadOnly, dbs.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/sequence.go b/lib/format/pgsql8/sequence.go index 0edaf17..2f67179 100644 --- a/lib/format/pgsql8/sequence.go +++ b/lib/format/pgsql8/sequence.go @@ -2,7 +2,6 @@ package pgsql8 import ( "fmt" - "log/slog" "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" @@ -11,7 +10,7 @@ import ( "github.com/dbsteward/dbsteward/lib/util" ) -func getCreateSequenceSql(l *slog.Logger, schema *ir.Schema, sequence *ir.Sequence) ([]output.ToSql, error) { +func getCreateSequenceSql(dbs *lib.DBSteward, schema *ir.Schema, sequence *ir.Sequence) ([]output.ToSql, error) { // TODO(go,3) put validation elsewhere cache, cacheValueSet := sequence.Cache.Maybe() if !cacheValueSet { @@ -39,7 +38,7 @@ func getCreateSequenceSql(l *slog.Logger, schema *ir.Schema, sequence *ir.Sequen if sequence.Owner != "" { // NOTE: Old dbsteward uses ALTER TABLE for this, which is valid according to docs, however // ALTER SEQUENCE also works in pgsql 8, and that's more correct - role, err := roleEnum(l, lib.GlobalDBSteward.NewDatabase, sequence.Owner) + role, err := roleEnum(dbs.Logger(), dbs.NewDatabase, sequence.Owner, dbs.IgnoreCustomRoles) if err != nil { return nil, err } @@ -67,11 +66,11 @@ func getDropSequenceSql(schema *ir.Schema, sequence *ir.Sequence) []output.ToSql } } -func getSequenceGrantSql(l *slog.Logger, schema *ir.Schema, seq *ir.Sequence, grant *ir.Grant) ([]output.ToSql, error) { +func getSequenceGrantSql(dbs *lib.DBSteward, schema *ir.Schema, seq *ir.Sequence, grant *ir.Grant) ([]output.ToSql, error) { roles := make([]string, len(grant.Roles)) var err error for i, role := range grant.Roles { - roles[i], err = roleEnum(l, lib.GlobalDBSteward.NewDatabase, role) + roles[i], err = roleEnum(dbs.Logger(), dbs.NewDatabase, role, dbs.IgnoreCustomRoles) if err != nil { return nil, err } @@ -100,7 +99,7 @@ func getSequenceGrantSql(l *slog.Logger, schema *ir.Schema, seq *ir.Sequence, gr // SEQUENCE IMPLICIT GRANTS // READYONLY USER PROVISION: generate a SELECT on the sequence for the readonly user // TODO(go,3) move this out of here, let this create just a single grant - roRole, err := roleEnum(l, lib.GlobalDBSteward.NewDatabase, ir.RoleReadOnly) + roRole, err := roleEnum(dbs.Logger(), dbs.NewDatabase, ir.RoleReadOnly, dbs.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/table.go b/lib/format/pgsql8/table.go index 61662c4..032b0d2 100644 --- a/lib/format/pgsql8/table.go +++ b/lib/format/pgsql8/table.go @@ -15,11 +15,16 @@ import ( var includeColumnDefaultNextvalInCreateSql bool -func getCreateTableSql(l *slog.Logger, schema *ir.Schema, table *ir.Table) ([]output.ToSql, error) { +func getCreateTableSql(dbs *lib.DBSteward, schema *ir.Schema, table *ir.Table) ([]output.ToSql, error) { + l := dbs.Logger().With( + slog.String("table", table.Name), + slog.String("schema", schema.Name), + ) cols := []sql.ColumnDefinition{} colSetup := []output.ToSql{} for _, col := range table.Columns { - newCol, err := getReducedColumnDefinition(l, lib.GlobalDBSteward.NewDatabase, schema, table, col) + ll := l.With(slog.String("column", col.Name)) + newCol, err := getReducedColumnDefinition(ll, dbs.NewDatabase, schema, table, col) if err != nil { return nil, err } @@ -62,7 +67,7 @@ func getCreateTableSql(l *slog.Logger, schema *ir.Schema, table *ir.Table) ([]ou ddl = append(ddl, colSetup...) if table.Owner != "" { - role, err := roleEnum(l, lib.GlobalDBSteward.NewDatabase, table.Owner) + role, err := roleEnum(l, dbs.NewDatabase, table.Owner, dbs.IgnoreCustomRoles) if err != nil { return nil, err } @@ -120,11 +125,11 @@ func defineTableColumnDefaults(l *slog.Logger, schema *ir.Schema, table *ir.Tabl return out } -func getTableGrantSql(l *slog.Logger, schema *ir.Schema, table *ir.Table, grant *ir.Grant) ([]output.ToSql, error) { +func getTableGrantSql(dbs *lib.DBSteward, schema *ir.Schema, table *ir.Table, grant *ir.Grant) ([]output.ToSql, error) { roles := make([]string, len(grant.Roles)) var err error for i, role := range grant.Roles { - roles[i], err = roleEnum(l, lib.GlobalDBSteward.NewDatabase, role) + roles[i], err = roleEnum(dbs.Logger(), dbs.NewDatabase, role, dbs.IgnoreCustomRoles) if err != nil { return nil, err } @@ -150,7 +155,7 @@ func getTableGrantSql(l *slog.Logger, schema *ir.Schema, table *ir.Table, grant // TABLE IMPLICIT GRANTS // READYONLY USER PROVISION: grant select on the table for the readonly user // TODO(go,3) move this out of here, let this create just a single grant - roRole, err := roleEnum(l, lib.GlobalDBSteward.NewDatabase, ir.RoleReadOnly) + roRole, err := roleEnum(dbs.Logger(), dbs.NewDatabase, ir.RoleReadOnly, dbs.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/table_test.go b/lib/format/pgsql8/table_test.go index 5262152..2fef5cd 100644 --- a/lib/format/pgsql8/table_test.go +++ b/lib/format/pgsql8/table_test.go @@ -1,9 +1,9 @@ package pgsql8 import ( - "log/slog" "testing" + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" @@ -43,7 +43,7 @@ func TestTable_GetCreationSql_TableOptions(t *testing.T) { }, } - ddl, err := getCreateTableSql(slog.Default(), schema, schema.Tables[0]) + ddl, err := getCreateTableSql(lib.NewDBSteward(lib.LookupMap{}), schema, schema.Tables[0]) if err != nil { t.Fatal(err) } diff --git a/lib/format/pgsql8/type.go b/lib/format/pgsql8/type.go index 0c9d028..ad5ae21 100644 --- a/lib/format/pgsql8/type.go +++ b/lib/format/pgsql8/type.go @@ -2,7 +2,6 @@ package pgsql8 import ( "fmt" - "log/slog" "strings" "github.com/dbsteward/dbsteward/lib" @@ -120,12 +119,12 @@ func isIntType(spec string) bool { } // Change all table columns that are the given datatype to a placeholder type -func alterColumnTypePlaceholder(l *slog.Logger, differ *diff, datatype *ir.TypeDef) ([]*ir.ColumnRef, []output.ToSql, error) { +func alterColumnTypePlaceholder(dbs *lib.DBSteward, differ *diff, datatype *ir.TypeDef) ([]*ir.ColumnRef, []output.ToSql, error) { ddl := []output.ToSql{} cols := []*ir.ColumnRef{} for _, newTableRef := range differ.NewTableDependency { for _, newColumn := range newTableRef.Table.Columns { - columnType, err := getColumnType(l, lib.GlobalDBSteward.NewDatabase, newTableRef.Schema, newTableRef.Table, newColumn) + columnType, err := getColumnType(dbs.Logger(), dbs.NewDatabase, newTableRef.Schema, newTableRef.Table, newColumn) if err != nil { return nil, nil, err } diff --git a/lib/format/pgsql8/view.go b/lib/format/pgsql8/view.go index 36047b8..9be0416 100644 --- a/lib/format/pgsql8/view.go +++ b/lib/format/pgsql8/view.go @@ -2,7 +2,6 @@ package pgsql8 import ( "fmt" - "log/slog" "strings" "github.com/dbsteward/dbsteward/lib" @@ -12,7 +11,7 @@ import ( "github.com/dbsteward/dbsteward/lib/util" ) -func getCreateViewSql(l *slog.Logger, schema *ir.Schema, view *ir.View) ([]output.ToSql, error) { +func getCreateViewSql(dbs *lib.DBSteward, schema *ir.Schema, view *ir.View) ([]output.ToSql, error) { ref := sql.ViewRef{Schema: schema.Name, View: view.Name} query := view.TryGetViewQuery(ir.SqlFormatPgsql8) util.Assert(query != nil, "Calling View.GetCreationSql for a view not defined for this sqlformat") @@ -31,7 +30,7 @@ func getCreateViewSql(l *slog.Logger, schema *ir.Schema, view *ir.View) ([]outpu }) } if view.Owner != "" { - role, err := roleEnum(l, lib.GlobalDBSteward.NewDatabase, view.Owner) + role, err := roleEnum(dbs.Logger(), dbs.NewDatabase, view.Owner, dbs.IgnoreCustomRoles) if err != nil { return nil, err } @@ -52,12 +51,12 @@ func getDropViewSql(schema *ir.Schema, view *ir.View) []output.ToSql { } } -func getViewGrantSql(l *slog.Logger, doc *ir.Definition, schema *ir.Schema, view *ir.View, grant *ir.Grant) ([]output.ToSql, error) { +func getViewGrantSql(dbs *lib.DBSteward, doc *ir.Definition, schema *ir.Schema, view *ir.View, grant *ir.Grant) ([]output.ToSql, error) { // NOTE: pgsql views use table grants! roles := make([]string, len(grant.Roles)) var err error for i, role := range grant.Roles { - roles[i], err = roleEnum(l, lib.GlobalDBSteward.NewDatabase, role) + roles[i], err = roleEnum(dbs.Logger(), dbs.NewDatabase, role, dbs.IgnoreCustomRoles) if err != nil { return nil, err } diff --git a/lib/format/pgsql8/xml_parser.go b/lib/format/pgsql8/xml_parser.go index 3a6ca72..af5cb49 100644 --- a/lib/format/pgsql8/xml_parser.go +++ b/lib/format/pgsql8/xml_parser.go @@ -5,14 +5,14 @@ import ( "log/slog" "strconv" - "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/ir" + "github.com/dbsteward/dbsteward/lib/output" "github.com/dbsteward/dbsteward/lib/util" "github.com/pkg/errors" ) type XmlParser struct { - logger *slog.Logger + quoter output.Quoter } type slonyRange struct { @@ -49,23 +49,15 @@ func tryNewSlonyRange(firstStr, lastStr string, parts int) (*slonyRange, error) return &slonyRange{first, last}, nil } -func NewXmlParser() *XmlParser { - return &XmlParser{} +func NewXmlParser(quoter output.Quoter) *XmlParser { + return &XmlParser{quoter: quoter} } -// @hack until we get proper instantiation ordering (i.e. remove all globals) -func (parser *XmlParser) Logger() *slog.Logger { - if parser.logger == nil { - parser.logger = lib.GlobalDBSteward.Logger() - } - return parser.logger -} - -func (parser *XmlParser) Process(doc *ir.Definition) error { +func (parser *XmlParser) Process(l *slog.Logger, doc *ir.Definition) error { for _, schema := range doc.Schemas { for _, table := range schema.Tables { if table.Partitioning != nil { - parser.Logger().Warn(fmt.Sprintf("Table %s.%s definies partition which is only partially supported at this time", schema.Name, table.Name)) + l.Warn(fmt.Sprintf("Table %s.%s definies partition which is only partially supported at this time", schema.Name, table.Name)) return parser.expandPartitionedTable(doc, schema, table) } } diff --git a/lib/format/pgsql8/xml_parser_partition_modulo.go b/lib/format/pgsql8/xml_parser_partition_modulo.go index 69dca4b..f3f94d4 100644 --- a/lib/format/pgsql8/xml_parser_partition_modulo.go +++ b/lib/format/pgsql8/xml_parser_partition_modulo.go @@ -95,7 +95,7 @@ func (p *XmlParser) createModuloPartitionTables(schema *ir.Schema, table *ir.Tab Name: fmt.Sprintf("%s_p_%s_chk", table.Name, partNum), Type: ir.ConstraintTypeCheck, // TODO(go,3) use higher level rep instead of xml rep here to resolve need for string-level quoting at this point - Definition: fmt.Sprintf("((%s %% %d) = %d)", NewOperations().GetQuoter().QuoteColumn(opts.column), opts.parts, i), + Definition: fmt.Sprintf("((%s %% %d) = %d)", p.quoter.QuoteColumn(opts.column), opts.parts, i), }) for _, index := range table.Indexes { @@ -147,14 +147,14 @@ func (p *XmlParser) createModuloPartitionTables(schema *ir.Schema, table *ir.Tab func (p *XmlParser) createModuloPartitionTrigger(schema *ir.Schema, table *ir.Table, partSchema *ir.Schema, opts *moduloPartition) { funcDef := fmt.Sprintf("DECLARE\n\tmod_result INT;\nBEGIN\n\tmod_result := NEW.%s %% %d;\n", - NewOperations().GetQuoter().QuoteColumn(opts.column), opts.parts) + quoter.QuoteColumn(opts.column), opts.parts) for i := 0; i < opts.parts; i++ { funcDef += "\t" if i != 0 { funcDef += "ELSE" } funcDef += fmt.Sprintf("IF (mod_result = %d) THEN\n\t\tINSERT INTO %s VALUES (NEW.*);\n", - i, NewOperations().GetQuoter().QualifyTable(partSchema.Name, opts.tableName(i))) + i, quoter.QualifyTable(partSchema.Name, opts.tableName(i))) } funcDef += "\tEND IF;\n\tRETURN NULL;\nEND;" diff --git a/lib/format/pgsql8/xml_parser_test.go b/lib/format/pgsql8/xml_parser_test.go index cedbbfb..2f667ec 100644 --- a/lib/format/pgsql8/xml_parser_test.go +++ b/lib/format/pgsql8/xml_parser_test.go @@ -1,10 +1,11 @@ -package pgsql8_test +package pgsql8 import ( "fmt" + "log/slog" "testing" - "github.com/dbsteward/dbsteward/lib/format/pgsql8" + "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/ir" "github.com/stretchr/testify/assert" ) @@ -162,7 +163,10 @@ END;`, } // note that Process mutates the document in place - err := pgsql8.GlobalXmlParser.Process(doc) + xmlParser := NewXmlParser(defaultQuoter(lib.NewDBSteward(lib.LookupMap{ + ir.SqlFormatPgsql8: GlobalLookup, + }))) + err := xmlParser.Process(slog.Default(), doc) if err != nil { t.Fatal(err) } diff --git a/lib/format/sql99/operations.go b/lib/format/sql99/operations.go deleted file mode 100644 index a4ea2ff..0000000 --- a/lib/format/sql99/operations.go +++ /dev/null @@ -1,18 +0,0 @@ -package sql99 - -import ( - "github.com/dbsteward/dbsteward/lib/format" -) - -type Operations struct { - format.Operations -} - -// NOTE: Sql99.OperationsIface will need to be provided after invoking: -// parent := &sql99.Sql99{} -// child := &pgsql8.Pgsql8{parent} -// child.sql99.OperationsIface = child -// Yes, this is super weird, and a holdover from PHP. TODO(go,3) get rid of this -func NewOperations() *Operations { - return &Operations{} -} diff --git a/main.go b/main.go index 101fc6b..688a1de 100644 --- a/main.go +++ b/main.go @@ -2,16 +2,15 @@ package main import ( "github.com/dbsteward/dbsteward/lib" - "github.com/dbsteward/dbsteward/lib/format" "github.com/dbsteward/dbsteward/lib/format/pgsql8" "github.com/dbsteward/dbsteward/lib/ir" ) func main() { // correlates to bin/dbsteward - lib.GlobalDBSteward = lib.NewDBSteward(format.LookupMap{ + dbsteward := lib.NewDBSteward(lib.LookupMap{ ir.SqlFormatPgsql8: pgsql8.GlobalLookup, }) - lib.GlobalDBSteward.ArgParse() - lib.GlobalDBSteward.Info("Done") + dbsteward.ArgParse() + dbsteward.Info("Done") } diff --git a/xmlpostgresintegration_test.go b/xmlpostgresintegration_test.go index 10c64ca..c0b6b4e 100644 --- a/xmlpostgresintegration_test.go +++ b/xmlpostgresintegration_test.go @@ -9,7 +9,6 @@ import ( "github.com/dbsteward/dbsteward/lib" "github.com/dbsteward/dbsteward/lib/encoding/xml" - "github.com/dbsteward/dbsteward/lib/format" "github.com/dbsteward/dbsteward/lib/format/pgsql8" "github.com/dbsteward/dbsteward/lib/ir" ) @@ -37,10 +36,10 @@ func TestXMLPostgresIngegration(t *testing.T) { if err != nil { t.Fatal(err) } - lib.GlobalDBSteward = lib.NewDBSteward(format.LookupMap{ + dbs := lib.NewDBSteward(lib.LookupMap{ ir.SqlFormatPgsql8: pgsql8.GlobalLookup, }) - lib.GlobalDBSteward.SqlFormat = ir.SqlFormatPgsql8 + dbs.SqlFormat = ir.SqlFormatPgsql8 err = pgsql8.CreateRoleIfNotExists(c, def1.Database.Roles.Application) if err != nil { t.Fatal(err) @@ -57,7 +56,7 @@ func TestXMLPostgresIngegration(t *testing.T) { if err != nil { t.Fatal(err) } - ops := pgsql8.NewOperations().(*pgsql8.Operations) + ops := pgsql8.NewOperations(dbs).(*pgsql8.Operations) statements, err := ops.CreateStatements(*def1) if err != nil { t.Fatal(err) @@ -82,7 +81,7 @@ func TestXMLPostgresIngegration(t *testing.T) { if err != nil { t.Fatal(err) } - ops = pgsql8.NewOperations().(*pgsql8.Operations) + ops = pgsql8.NewOperations(dbs).(*pgsql8.Operations) statements, err = ops.Upgrade(slog.Default(), def1, def2) if err != nil { t.Fatal(err) @@ -103,7 +102,7 @@ func TestXMLPostgresIngegration(t *testing.T) { if err != nil { t.Fatal(err) } - ops = pgsql8.NewOperations().(*pgsql8.Operations) + ops = pgsql8.NewOperations(dbs).(*pgsql8.Operations) _, err = ops.ExtractSchemaConn(context.TODO(), c) if err != nil { t.Fatal(err)