Skip to content

Commit

Permalink
Remove GlobalConstraints, refactor it into standalone functions, and …
Browse files Browse the repository at this point in the history
…some general cleanup of that code
  • Loading branch information
williammoran committed Apr 25, 2024
1 parent 84652ba commit af30da3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 29 deletions.
25 changes: 9 additions & 16 deletions lib/format/pgsql8/constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,11 @@ import (
"github.com/dbsteward/dbsteward/lib/util"
)

type Constraint struct {
}

func NewConstraint() *Constraint {
return &Constraint{}
}

// TODO(go,pgsql) make sure this is tested _thoroughly_
// TODO(go,core) lift this to sql99
// ConstraintTypeAll includes PrimaryKey,Constraint,Foreign
// sql99.ConstraintType
func (self *Constraint) GetTableConstraints(doc *ir.Definition, schema *ir.Schema, table *ir.Table, ct sql99.ConstraintType) []*sql99.TableConstraint {
func getTableConstraints(doc *ir.Definition, schema *ir.Schema, table *ir.Table, ct sql99.ConstraintType) []*sql99.TableConstraint {
if table == nil {
return nil
}
Expand Down Expand Up @@ -191,27 +184,27 @@ func (self *Constraint) GetTableConstraints(doc *ir.Definition, schema *ir.Schem
return constraints
}

func (self *Constraint) TryGetTableConstraintNamed(doc *ir.Definition, schema *ir.Schema, table *ir.Table, name string, constraintType sql99.ConstraintType) *sql99.TableConstraint {
func tryGetTableConstraintNamed(doc *ir.Definition, schema *ir.Schema, table *ir.Table, name string, constraintType sql99.ConstraintType) *sql99.TableConstraint {
// TODO(feat) can make this a little more performant if we pass constraint type in
for _, constraint := range self.GetTableConstraints(doc, schema, table, constraintType) {
for _, constraint := range getTableConstraints(doc, schema, table, constraintType) {
if strings.EqualFold(constraint.Name, name) {
return constraint
}
}
return nil
}

func (self *Constraint) GetDropSql(constraint *sql99.TableConstraint) []output.ToSql {
func getTableConstraintDropSql(constraint *sql99.TableConstraint) []output.ToSql {
return []output.ToSql{
&sql.ConstraintDrop{
Table: sql.TableRef{constraint.Schema.Name, constraint.Table.Name},
Table: sql.TableRef{Schema: constraint.Schema.Name, Table: constraint.Table.Name},
Constraint: constraint.Name,
},
}
}

func (self *Constraint) GetCreationSql(constraint *sql99.TableConstraint) []output.ToSql {
table := sql.TableRef{constraint.Schema.Name, constraint.Table.Name}
func getTableContraintCreationSql(constraint *sql99.TableConstraint) []output.ToSql {
table := sql.TableRef{Schema: constraint.Schema.Name, Table: constraint.Table.Name}

// if there's a text definition, prefer that; it should have come verbatim from the xml
if constraint.TextDefinition != "" {
Expand Down Expand Up @@ -254,7 +247,7 @@ func (self *Constraint) GetCreationSql(constraint *sql99.TableConstraint) []outp
Table: table,
Constraint: constraint.Name,
LocalColumns: localCols,
ForeignTable: sql.TableRef{constraint.ForeignSchema.Name, constraint.ForeignTable.Name},
ForeignTable: sql.TableRef{Schema: constraint.ForeignSchema.Name, Table: constraint.ForeignTable.Name},
ForeignColumns: foreignCols,
OnUpdate: constraint.ForeignOnUpdate,
OnDelete: constraint.ForeignOnDelete,
Expand All @@ -266,7 +259,7 @@ func (self *Constraint) GetCreationSql(constraint *sql99.TableConstraint) []outp
return nil
}

func (self *Constraint) DependsOnRenamedTable(doc *ir.Definition, constraint *sql99.TableConstraint) bool {
func constraintDependsOnRenamedTable(doc *ir.Definition, constraint *sql99.TableConstraint) bool {
if lib.GlobalDBSteward.IgnoreOldNames {
return false
}
Expand Down
24 changes: 12 additions & 12 deletions lib/format/pgsql8/diff_constraints.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,25 @@ func (self *DiffConstraints) CreateConstraintsTable(ofs output.OutputFileSegment
lib.GlobalDBSteward.FatalIfError(err, "while checking if table was renamed")
if isRenamed {
// remove all constraints and recreate with new table name conventions
for _, constraint := range GlobalConstraint.GetTableConstraints(lib.GlobalDBSteward.OldDatabase, oldSchema, oldTable, constraintType) {
for _, constraint := range getTableConstraints(lib.GlobalDBSteward.OldDatabase, oldSchema, oldTable, constraintType) {
// rewrite the constraint definer to refer to the new table
// so the constraint by the old, but part of the new table
// will be referenced properly in the drop statement
constraint.Schema = newSchema
constraint.Table = newTable
ofs.WriteSql(GlobalConstraint.GetDropSql(constraint)...)
ofs.WriteSql(getTableConstraintDropSql(constraint)...)
}

// add all still-defined constraints back and any new ones to the table
for _, constraint := range GlobalConstraint.GetTableConstraints(lib.GlobalDBSteward.NewDatabase, newSchema, newTable, constraintType) {
ofs.WriteSql(GlobalConstraint.GetCreationSql(constraint)...)
for _, constraint := range getTableConstraints(lib.GlobalDBSteward.NewDatabase, newSchema, newTable, constraintType) {
ofs.WriteSql(getTableContraintCreationSql(constraint)...)
}

return
}

for _, constraint := range self.GetNewConstraints(oldSchema, oldTable, newSchema, newTable, constraintType) {
ofs.WriteSql(GlobalConstraint.GetCreationSql(constraint)...)
ofs.WriteSql(getTableContraintCreationSql(constraint)...)
}
}

Expand All @@ -65,7 +65,7 @@ func (self *DiffConstraints) DropConstraints(ofs output.OutputFileSegmenter, old

func (self *DiffConstraints) DropConstraintsTable(ofs output.OutputFileSegmenter, oldSchema *ir.Schema, oldTable *ir.Table, newSchema *ir.Schema, newTable *ir.Table, constraintType sql99.ConstraintType) {
for _, constraint := range self.GetOldConstraints(oldSchema, oldTable, newSchema, newTable, constraintType) {
ofs.WriteSql(GlobalConstraint.GetDropSql(constraint)...)
ofs.WriteSql(getTableConstraintDropSql(constraint)...)
}
}

Expand All @@ -74,9 +74,9 @@ func (self *DiffConstraints) GetOldConstraints(oldSchema *ir.Schema, oldTable *i
if newTable != nil && oldTable != nil {
oldDb := lib.GlobalDBSteward.OldDatabase
newDb := lib.GlobalDBSteward.NewDatabase
for _, oldConstraint := range GlobalConstraint.GetTableConstraints(oldDb, oldSchema, oldTable, constraintType) {
newConstraint := GlobalConstraint.TryGetTableConstraintNamed(newDb, newSchema, newTable, oldConstraint.Name, constraintType)
if newConstraint == nil || !newConstraint.Equals(oldConstraint) || GlobalConstraint.DependsOnRenamedTable(newDb, oldConstraint) || GlobalConstraint.DependsOnRenamedTable(newDb, newConstraint) {
for _, oldConstraint := range getTableConstraints(oldDb, oldSchema, oldTable, constraintType) {
newConstraint := tryGetTableConstraintNamed(newDb, newSchema, newTable, oldConstraint.Name, constraintType)
if newConstraint == nil || !newConstraint.Equals(oldConstraint) || constraintDependsOnRenamedTable(newDb, oldConstraint) || constraintDependsOnRenamedTable(newDb, newConstraint) {
out = append(out, oldConstraint)
}
}
Expand All @@ -89,9 +89,9 @@ func (self *DiffConstraints) GetNewConstraints(oldSchema *ir.Schema, oldTable *i
if newTable != nil {
oldDb := lib.GlobalDBSteward.OldDatabase
newDb := lib.GlobalDBSteward.NewDatabase
for _, newConstraint := range GlobalConstraint.GetTableConstraints(newDb, newSchema, newTable, constraintType) {
oldConstraint := GlobalConstraint.TryGetTableConstraintNamed(oldDb, oldSchema, oldTable, newConstraint.Name, constraintType)
if oldConstraint == nil || !oldConstraint.Equals(newConstraint) || GlobalConstraint.DependsOnRenamedTable(newDb, newConstraint) {
for _, newConstraint := range getTableConstraints(newDb, newSchema, newTable, constraintType) {
oldConstraint := tryGetTableConstraintNamed(oldDb, oldSchema, oldTable, newConstraint.Name, constraintType)
if oldConstraint == nil || !oldConstraint.Equals(newConstraint) || constraintDependsOnRenamedTable(newDb, newConstraint) {
out = append(out, newConstraint)
}
}
Expand Down
1 change: 0 additions & 1 deletion lib/format/pgsql8/pgsql8.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package pgsql8
import "github.com/dbsteward/dbsteward/lib/format"

var GlobalOperations = NewOperations()
var GlobalConstraint = NewConstraint()
var GlobalFunction = NewFunction()
var GlobalIndex = NewIndex()
var GlobalLanguage = NewLanguage()
Expand Down

0 comments on commit af30da3

Please sign in to comment.