Skip to content

Commit

Permalink
Migrate functions from dbx -> ir where appropriate and reasonable to …
Browse files Browse the repository at this point in the history
…do so
  • Loading branch information
williammoran committed Apr 30, 2024
1 parent d49f365 commit ca3e5dd
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 145 deletions.
122 changes: 38 additions & 84 deletions lib/dbx.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ func NewDBX() *DBX {
return &DBX{}
}

func (self *DBX) SetDefaultSchema(def *ir.Definition, schema string) *ir.Schema {
self.defaultSchema = def.GetOrCreateSchemaNamed(schema)
return self.defaultSchema
func (dbx *DBX) SetDefaultSchema(def *ir.Definition, schema string) *ir.Schema {
dbx.defaultSchema = def.GetOrCreateSchemaNamed(schema)
return dbx.defaultSchema
}
func (self *DBX) GetDefaultSchema() *ir.Schema {
return self.defaultSchema
func (dbx *DBX) GetDefaultSchema() *ir.Schema {
return dbx.defaultSchema
}

func (self *DBX) BuildStagedSql(doc *ir.Definition, ofs output.OutputFileSegmenter, stage ir.SqlStage) {
func (dbx *DBX) BuildStagedSql(doc *ir.Definition, ofs output.OutputFileSegmenter, stage ir.SqlStage) {
if stage == "" {
ofs.Write("\n-- NON-STAGED SQL COMMANDS\n")
} else {
Expand All @@ -44,18 +44,18 @@ func (self *DBX) BuildStagedSql(doc *ir.Definition, ofs output.OutputFileSegment
ofs.Write("\n")
}

func (self *DBX) GetTerminalForeignColumn(doc *ir.Definition, schema *ir.Schema, table *ir.Table, column *ir.Column) *ir.Column {
fkey := self.ResolveForeignKeyColumn(doc, schema, table, column)
func (dbx *DBX) GetTerminalForeignColumn(doc *ir.Definition, schema *ir.Schema, table *ir.Table, column *ir.Column) *ir.Column {
fkey := dbx.ResolveForeignKeyColumn(doc, schema, table, column)
fcol := fkey.Columns[0]

if fcol.Type == "" && fcol.ForeignTable != "" {
GlobalDBSteward.Trace("Seeking nested foreign key for %s", fkey.String())
return self.GetTerminalForeignColumn(doc, fkey.Schema, fkey.Table, fcol)
return dbx.GetTerminalForeignColumn(doc, fkey.Schema, fkey.Table, fcol)
}
return fcol
}

func (self *DBX) ResolveForeignKeyColumn(doc *ir.Definition, schema *ir.Schema, table *ir.Table, column *ir.Column) ir.Key {
func (dbx *DBX) ResolveForeignKeyColumn(doc *ir.Definition, schema *ir.Schema, table *ir.Table, column *ir.Column) ir.Key {
// this used to be called format_constraint::foreign_key_lookup() in v1
// most of the functionality got split to the more general ResolveForeignKey
foreign := column.TryGetReferencedKey()
Expand All @@ -66,11 +66,12 @@ func (self *DBX) ResolveForeignKeyColumn(doc *ir.Definition, schema *ir.Schema,
Table: table,
Columns: []*ir.Column{column},
}
return self.ResolveForeignKey(doc, local, *foreign)
return dbx.ResolveForeignKey(doc, local, *foreign)
}

func (self *DBX) ResolveForeignKey(doc *ir.Definition, localKey ir.Key, foreignKey ir.KeyNames) ir.Key {
fref := self.ResolveSchemaTable(doc, localKey.Schema, foreignKey.Schema, foreignKey.Table, "foreign key")
func (dbx *DBX) ResolveForeignKey(doc *ir.Definition, localKey ir.Key, foreignKey ir.KeyNames) ir.Key {
fref, err := doc.ResolveSchemaTable(localKey.Schema, foreignKey.Schema, foreignKey.Table, "foreign key")
GlobalDBSteward.FatalIfError(err, "gathering foreign keys")

// if we didn't ask for specific foreign columns, but we have local columns, use those
if len(foreignKey.Columns) == 0 {
Expand All @@ -97,7 +98,8 @@ func (self *DBX) ResolveForeignKey(doc *ir.Definition, localKey ir.Key, foreignK
col = localKey.Columns[i].Name
}

fCol := self.TryInheritanceGetColumn(doc, fref.Schema, fref.Table, col)
fCol, err := doc.TryInheritanceGetColumn(fref.Schema, fref.Table, col)
GlobalDBSteward.FatalIfError(err, "TryInheritanceGetColumn")
if fCol == nil {
GlobalDBSteward.Fatal("Failed to find foreign column %s in %s referenced by %s", col, foreignKey.String(), localKey.String())
}
Expand All @@ -107,75 +109,37 @@ func (self *DBX) ResolveForeignKey(doc *ir.Definition, localKey ir.Key, foreignK
return out
}

func (self *DBX) ResolveSchemaTable(doc *ir.Definition, localSchema *ir.Schema, schemaName, tableName string, refType string) ir.TableRef {
fSchema := localSchema
if schemaName != "" {
fSchema = doc.TryGetSchemaNamed(schemaName)
if fSchema == nil {
GlobalDBSteward.Fatal("%s reference to unknown schema %s", refType, schemaName)
}
}

fTable := fSchema.TryGetTableNamed(tableName)
if fTable == nil {
GlobalDBSteward.Fatal("%s reference to unknown table %s.%s", refType, fSchema.Name, tableName)
}

return ir.TableRef{fSchema, fTable}
}

// attempts to find the new table that claims it is renamed from the old table
// this is the "forwards looking" version of RenamedTableCheckPointer
func (self *DBX) TryGetTableFormerlyKnownAs(newDoc *ir.Definition, oldSchema *ir.Schema, oldTable *ir.Table) *ir.TableRef {
// TODO(go,nth) can we remove the assertion in favor of just returning nil? or should callers continue to check IgnoreOldNames themselves?
util.Assert(!GlobalDBSteward.IgnoreOldNames, "Should not attempt to look up renamed tables if IgnoreOldNames is set")

// TODO(go,3) move to model, and/or compositing pass
for _, newSchema := range newDoc.Schemas {
for _, newTable := range newSchema.Tables {
if newTable.OldTableName != "" || newTable.OldSchemaName != "" {
oldTableName := util.CoalesceStr(newTable.OldTableName, newTable.Name)
oldSchemaName := util.CoalesceStr(newTable.OldSchemaName, newSchema.Name)
if strings.EqualFold(oldSchema.Name, oldSchemaName) && strings.EqualFold(oldTable.Name, oldTableName) {
return &ir.TableRef{newSchema, newTable}
}
}
}
}
return nil
}

// attempts to find, and sanity checks, the table pointed to by oldSchema/TableName attributes
// this is the "backwards looking" version of TryGetTableFormerlyKnownAs
// TODO(go,nth) rename this, clean it up
func (self *DBX) RenamedTableCheckPointer(oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) (*ir.Schema, *ir.Table) {
func (dbx *DBX) RenamedTableCheckPointer(oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table) (*ir.Schema, *ir.Table) {
if newSchema == nil || newTable == nil {
return oldSchema, oldTable
}

isRenamed, err := self.IsRenamedTable(newSchema, newTable)
isRenamed, err := dbx.IsRenamedTable(newSchema, newTable)
GlobalDBSteward.FatalIfError(err, "while checking table rename status")
if !isRenamed {
return oldSchema, oldTable
}

if newTable.OldSchemaName != "" {
oldSchema = self.GetOldTableSchema(newSchema, newTable)
oldSchema = dbx.GetOldTableSchema(newSchema, newTable)
if oldSchema == nil {
GlobalDBSteward.Fatal("Sanity failure: %s.%s has oldSchemaName attribute but old_schema not found", newSchema.Name, newTable.Name)
}
} else if oldSchema == nil {
GlobalDBSteward.Fatal("Sanity failure: %s.%s has oldTableName attribute but passed old_schema is not defined", newSchema.Name, newTable.Name)
}

oldTable = self.GetOldTable(newSchema, newTable)
oldTable = dbx.GetOldTable(newSchema, newTable)
if oldTable == nil {
GlobalDBSteward.Fatal("Sanity failure: %s.%s has oldTableName attribute, but table %s.%s not found", newSchema.Name, newTable.Name, oldSchema.Name, newTable.OldTableName)
}
return oldSchema, oldTable
}

func (self *DBX) IsRenamedTable(schema *ir.Schema, table *ir.Table) (bool, error) {
func (dbx *DBX) IsRenamedTable(schema *ir.Schema, table *ir.Table) (bool, error) {
if GlobalDBSteward.IgnoreOldNames {
return false, nil
}
Expand All @@ -188,7 +152,7 @@ func (self *DBX) IsRenamedTable(schema *ir.Schema, table *ir.Table) (bool, error
return true, errors.Errorf("oldTableName panic - new schema %s still contains table named %s", schema.Name, table.OldTableName)
}

oldSchema := self.GetOldTableSchema(schema, table)
oldSchema := dbx.GetOldTableSchema(schema, table)
if oldSchema != nil {
if oldSchema.TryGetTableNamed(table.OldTableName) == nil {
return true, errors.Errorf("oldTableName panic - old schema %s does not contain table named %s", oldSchema.Name, table.OldTableName)
Expand All @@ -205,7 +169,7 @@ func (self *DBX) IsRenamedTable(schema *ir.Schema, table *ir.Table) (bool, error
return false, nil
}

func (self *DBX) GetOldTableSchema(schema *ir.Schema, table *ir.Table) *ir.Schema {
func (dbx *DBX) GetOldTableSchema(schema *ir.Schema, table *ir.Table) *ir.Schema {
if table.OldSchemaName == "" {
return schema
}
Expand All @@ -215,15 +179,15 @@ func (self *DBX) GetOldTableSchema(schema *ir.Schema, table *ir.Table) *ir.Schem
return GlobalDBSteward.OldDatabase.TryGetSchemaNamed(table.OldSchemaName)
}

func (self *DBX) GetOldTable(schema *ir.Schema, table *ir.Table) *ir.Table {
func (dbx *DBX) GetOldTable(schema *ir.Schema, table *ir.Table) *ir.Table {
if table.OldTableName == "" {
return nil
}
oldSchema := self.GetOldTableSchema(schema, table)
oldSchema := dbx.GetOldTableSchema(schema, table)
return oldSchema.TryGetTableNamed(table.OldTableName)
}

func (self *DBX) TableDependencyOrder(doc *ir.Definition) []*ir.TableRef {
func (dbx *DBX) TableDependencyOrder(doc *ir.Definition) []*ir.TableRef {
// first, build forward and reverse adjacency lists
// forwards: a mapping of local table => foreign tables that it references
// reverse: a mapping of foreign table => local tables that reference it
Expand All @@ -237,7 +201,7 @@ func (self *DBX) TableDependencyOrder(doc *ir.Definition) []*ir.TableRef {

for _, schema := range doc.Schemas {
for _, table := range schema.Tables {
curr := ir.TableRef{schema, table}
curr := ir.TableRef{Schema: schema, Table: table}
if len(reverse[curr]) == 0 {
reverse[curr] = []ir.TableRef{}
}
Expand All @@ -246,7 +210,7 @@ func (self *DBX) TableDependencyOrder(doc *ir.Definition) []*ir.TableRef {
// add that dep as something this table depends on
// add this table as something depending on that dep
foreigns := forward.GetOrInit(curr, init)
for _, dep := range self.getTableDependencies(doc, schema, table) {
for _, dep := range dbx.getTableDependencies(doc, schema, table) {
*foreigns = append(*foreigns, dep)
reverse[dep] = append(reverse[dep], curr)
}
Expand Down Expand Up @@ -335,26 +299,29 @@ func (self *DBX) TableDependencyOrder(doc *ir.Definition) []*ir.TableRef {
return out
}

func (self *DBX) getTableDependencies(doc *ir.Definition, schema *ir.Schema, table *ir.Table) []ir.TableRef {
func (dbx *DBX) getTableDependencies(doc *ir.Definition, schema *ir.Schema, table *ir.Table) []ir.TableRef {
out := []ir.TableRef{}
// gather foreign keys on the columns
for _, column := range table.Columns {
if column.ForeignTable != "" {
fref := GlobalDBX.ResolveSchemaTable(doc, schema, column.ForeignSchema, column.ForeignTable, "column foreignKey")
fref, err := doc.ResolveSchemaTable(schema, column.ForeignSchema, column.ForeignTable, "column foreignKey")
GlobalDBSteward.FatalIfError(err, "gathering foreign keys")
out = append(out, fref)
}
}

// gather explicit foreign keys
for _, fk := range table.ForeignKeys {
fref := GlobalDBX.ResolveSchemaTable(doc, schema, fk.ForeignSchema, fk.ForeignTable, "foreignKey element")
fref, err := doc.ResolveSchemaTable(schema, fk.ForeignSchema, fk.ForeignTable, "foreignKey element")
GlobalDBSteward.FatalIfError(err, "gathering explicit foreign keys")
out = append(out, fref)
}

// gather constraints
for _, constraint := range table.Constraints {
if constraint.ForeignTable != "" {
fref := GlobalDBX.ResolveSchemaTable(doc, schema, constraint.ForeignSchema, constraint.ForeignTable, "FOREIGN KEY constraint")
fref, err := doc.ResolveSchemaTable(schema, constraint.ForeignSchema, constraint.ForeignTable, "FOREIGN KEY constraint")
GlobalDBSteward.FatalIfError(err, "gathering constraints")
out = append(out, fref)
}
}
Expand All @@ -365,25 +332,12 @@ func (self *DBX) getTableDependencies(doc *ir.Definition, schema *ir.Schema, tab
return out
}

func (self *DBX) TryInheritanceGetColumn(doc *ir.Definition, schema *ir.Schema, table *ir.Table, columnName string) *ir.Column {
// TODO(go,3) move to model
column := table.TryGetColumnNamed(columnName)

// just keep walking up the inheritance chain so long as there's a link
for column == nil && table.InheritsTable != "" {
ref := GlobalDBX.ResolveSchemaTable(doc, schema, table.InheritsSchema, table.InheritsTable, "inheritance")
table = ref.Table
column = table.TryGetColumnNamed(columnName)
}

return column
}

func (self *DBX) TryInheritanceGetColumns(doc *ir.Definition, schema *ir.Schema, table *ir.Table, columnNames []string) ([]*ir.Column, bool) {
func (dbx *DBX) TryInheritanceGetColumns(doc *ir.Definition, schema *ir.Schema, table *ir.Table, columnNames []string) ([]*ir.Column, bool) {
// TODO(go,nth) this could be more efficient (but more complicated) if we did all the columns at once, one table at a time
columns := make([]*ir.Column, len(columnNames))
for i, colName := range columnNames {
column := self.TryInheritanceGetColumn(doc, schema, table, colName)
column, err := doc.TryInheritanceGetColumn(schema, table, colName)
GlobalDBSteward.FatalIfError(err, "gathering explicit foreign keys")
if column == nil {
return nil, false
}
Expand Down
2 changes: 1 addition & 1 deletion lib/format/pgsql8/diff_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ func dropTable(ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *i
return
}
if !lib.GlobalDBSteward.IgnoreOldNames {
renamedRef := lib.GlobalDBX.TryGetTableFormerlyKnownAs(lib.GlobalDBSteward.NewDatabase, oldSchema, oldTable)
renamedRef := lib.GlobalDBSteward.NewDatabase.TryGetTableFormerlyKnownAs(oldSchema, oldTable)
if renamedRef != nil {
ofs.Write("-- DROP TABLE %s.%s omitted: new table %s indicates it is her replacement", oldSchema.Name, oldTable.Name, renamedRef)
return
Expand Down
14 changes: 10 additions & 4 deletions lib/format/pgsql8/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,10 @@ func (ops *Operations) pgToIR(pgDoc structure) (*ir.Definition, error) {

// NEW(2) if the table inherits from the parent, remove any inherited objects
if table.InheritsTable != "" || table.InheritsSchema != "" {
parentRef := lib.GlobalDBX.ResolveSchemaTable(doc, schema, table.InheritsSchema, table.InheritsTable, "inheritance")
parentRef, err := doc.ResolveSchemaTable(schema, table.InheritsSchema, table.InheritsTable, "inheritance")
if err != nil {
return nil, err
}
for _, parentColumn := range parentRef.Table.Columns {
column := table.TryGetColumnNamed(parentColumn.Name)
if column != nil && column.EqualsInherited(parentColumn) {
Expand Down Expand Up @@ -1050,7 +1053,8 @@ func buildData(doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep []*i
if util.Contains(dataCols, pkCol) {
// TODO(go,3) seems like this could be refactored better by putting much of the lookup
// into the model structs
pk := lib.GlobalDBX.TryInheritanceGetColumn(doc, schema, table, pkCol)
pk, err := doc.TryInheritanceGetColumn(schema, table, pkCol)
lib.GlobalDBSteward.FatalIfError(err, "TryInheritanceGetColumn")
if pk == nil {
lib.GlobalDBSteward.Fatal("Failed to find primary key column '%s' for %s.%s",
pkCol, schema.Name, table.Name)
Expand All @@ -1071,7 +1075,8 @@ func buildData(doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep []*i
// check if primary key columns are columns of this table
// TODO(go,3) does this check belong here? should there be some kind of post-parse validation?
for _, columnName := range table.PrimaryKey {
col := lib.GlobalDBX.TryInheritanceGetColumn(doc, schema, table, columnName)
col, err := doc.TryInheritanceGetColumn(schema, table, columnName)
lib.GlobalDBSteward.FatalIfError(err, "TryInheritanceGetColumn")
if col == nil {
lib.GlobalDBSteward.Fatal("Declared primary key column (%s) does not exist as column in table %s.%s",
columnName, schema.Name, table.Name)
Expand Down Expand Up @@ -1101,7 +1106,8 @@ func columnValueDefault(schema *ir.Schema, table *ir.Table, columnName string, d
}
}

col := lib.GlobalDBX.TryInheritanceGetColumn(lib.GlobalDBSteward.NewDatabase, schema, table, columnName)
col, err := lib.GlobalDBSteward.NewDatabase.TryInheritanceGetColumn(schema, table, columnName)
lib.GlobalDBSteward.FatalIfError(err, "TryInheritanceGetColumn")
if col == nil {
lib.GlobalDBSteward.Fatal("Failed to find table %s.%s column %s for default value check", schema.Name, table.Name, columnName)
}
Expand Down
Loading

0 comments on commit ca3e5dd

Please sign in to comment.