Skip to content

Commit

Permalink
Remove GlobalTable, refactor it into standalone functions, and some g…
Browse files Browse the repository at this point in the history
…eneral cleanup of that code
  • Loading branch information
williammoran committed Apr 25, 2024
1 parent cb8a532 commit 230eff5
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 65 deletions.
22 changes: 1 addition & 21 deletions lib/format/pgsql8/column.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package pgsql8

import (
"fmt"
"strings"

"github.com/dbsteward/dbsteward/lib"
Expand Down Expand Up @@ -72,7 +71,7 @@ func getColumnSetupSql(schema *ir.Schema, table *ir.Table, column *ir.Column) []
}

func getColumnDefaultSql(schema *ir.Schema, table *ir.Table, column *ir.Column) []output.ToSql {
if !GlobalTable.IncludeColumnDefaultNextvalInCreateSql && hasDefaultNextval(column) {
if !includeColumnDefaultNextvalInCreateSql && hasDefaultNextval(column) {
// if the default is a nextval expression, don't specify it in the regular full definition
// because if the sequence has not been defined yet,
// the nextval expression will be evaluated inline and fail
Expand Down Expand Up @@ -155,22 +154,3 @@ func getReferenceType(coltype string) string {
// TODO(feat) should this include enum types?
return coltype
}

func getSerialStartDml(schema *ir.Schema, table *ir.Table, column *ir.Column) []output.ToSql {
if column.SerialStart == nil {
return nil
}
if !isSerialType(column) {
lib.GlobalDBSteward.Fatal("Expected serial type for column %s.%s.%s because serialStart='%d' was defined, found type %s",
schema.Name, table.Name, column.Name, *column.SerialStart, column.Type)
}
return []output.ToSql{
&sql.Annotated{
Annotation: fmt.Sprintf("serialStart %d specified for %s.%s.%s", *column.SerialStart, schema.Name, table.Name, column.Name),
Wrapped: &sql.SequenceSerialSetVal{
Column: sql.ColumnRef{Schema: schema.Name, Table: table.Name, Column: column.Name},
Value: *column.SerialStart,
},
},
}
}
2 changes: 1 addition & 1 deletion lib/format/pgsql8/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func (self *Diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 ou
}
for _, newGrant := range newTable.Grants {
if oldTable == nil || !ir.HasPermissionsOf(oldTable, newGrant, ir.SqlFormatPgsql8) {
stage1.WriteSql(GlobalTable.GetGrantSql(newDoc, newSchema, newTable, newGrant)...)
stage1.WriteSql(getTableGrantSql(newDoc, newSchema, newTable, newGrant)...)
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions lib/format/pgsql8/diff_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (self *DiffTables) applyTableOptionsDiff(stage1 output.OutputFileSegmenter,
if strings.EqualFold(entry.Key, "with") {
// ALTER TABLE ... SET (params) doesn't accept oids=true/false unlike CREATE TABLE
// only WITH OIDS or WITHOUT OIDS
params := GlobalTable.ParseStorageParams(entry.Value)
params := parseStorageParams(entry.Value)
if oids, ok := params["oids"]; ok {
delete(params, "oids")
if util.IsTruthy(oids) {
Expand Down Expand Up @@ -129,7 +129,7 @@ func (self *DiffTables) applyTableOptionsDiff(stage1 output.OutputFileSegmenter,

for _, entry := range deleteOpts.Entries() {
if strings.EqualFold(entry.Key, "with") {
params := GlobalTable.ParseStorageParams(entry.Value)
params := parseStorageParams(entry.Value)
// handle oids separately since pgsql doesn't recognize it as a storage parameter in an ALTER TABLE
if _, ok := params["oids"]; ok {
delete(params, "oids")
Expand Down Expand Up @@ -575,8 +575,8 @@ func (self *DiffTables) CreateTable(ofs output.OutputFileSegmenter, oldSchema, n
})
}
} else {
ofs.WriteSql(GlobalTable.GetCreationSql(newSchema, newTable)...)
ofs.WriteSql(GlobalTable.DefineTableColumnDefaults(newSchema, newTable)...)
ofs.WriteSql(getCreateTableSql(newSchema, newTable)...)
ofs.WriteSql(defineTableColumnDefaults(newSchema, newTable)...)
}
return nil
}
Expand Down Expand Up @@ -604,7 +604,7 @@ func (self *DiffTables) DropTable(ofs output.OutputFileSegmenter, oldSchema *ir.
}
}

ofs.WriteSql(GlobalTable.GetDropSql(oldSchema, oldTable)...)
ofs.WriteSql(getDropTableSql(oldSchema, oldTable)...)
}

func (self *DiffTables) DiffClusters(ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) {
Expand Down Expand Up @@ -654,7 +654,7 @@ func (self *DiffTables) GetCreateDataSql(oldSchema *ir.Schema, oldTable *ir.Tabl

if oldTable == nil {
// if this is a fresh build, make sure serial starts are issued _after_ the hardcoded data inserts
out = append(out, GlobalTable.GetSerialStartDml(newSchema, newTable)...)
out = append(out, getSerialStartDml(newSchema, newTable, nil)...)
return out
}

Expand Down
12 changes: 6 additions & 6 deletions lib/format/pgsql8/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -881,20 +881,20 @@ func buildSchema(doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep []
// table structure creation
for _, schema := range doc.Schemas {
// create defined tables
GlobalTable.IncludeColumnDefaultNextvalInCreateSql = false
includeColumnDefaultNextvalInCreateSql = false
for _, table := range schema.Tables {
// table definition
ofs.WriteSql(GlobalTable.GetCreationSql(schema, table)...)
ofs.WriteSql(getCreateTableSql(schema, table)...)

// table indexes
GlobalDiffIndexes.DiffIndexesTable(ofs, nil, nil, schema, table)

// table grants
for _, grant := range table.Grants {
ofs.WriteSql(GlobalTable.GetGrantSql(doc, schema, table, grant)...)
ofs.WriteSql(getTableGrantSql(doc, schema, table, grant)...)
}
}
GlobalTable.IncludeColumnDefaultNextvalInCreateSql = true
includeColumnDefaultNextvalInCreateSql = true

// sequences contained in the schema
for _, sequence := range schema.Sequences {
Expand All @@ -909,7 +909,7 @@ func buildSchema(doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep []
// add table nextvals that were omitted
for _, table := range schema.Tables {
if table.HasDefaultNextVal() {
ofs.WriteSql(GlobalTable.GetDefaultNextvalSql(schema, table)...)
ofs.WriteSql(getDefaultNextvalSql(schema, table)...)
}
}
}
Expand All @@ -933,7 +933,7 @@ func buildSchema(doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep []
for _, schema := range doc.Schemas {
for _, table := range schema.Tables {
// TODO(go,nth) method name consistency - should be GetColumnDefaultsSql?
ofs.WriteSql(GlobalTable.DefineTableColumnDefaults(schema, table)...)
ofs.WriteSql(defineTableColumnDefaults(schema, table)...)
}
}

Expand Down
1 change: 0 additions & 1 deletion lib/format/pgsql8/pgsql8.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import "github.com/dbsteward/dbsteward/lib/format"

var GlobalOperations = NewOperations()
var GlobalSchema = NewSchema()
var GlobalTable = NewTable()
var GlobalTrigger = NewTrigger()
var GlobalDataType = NewDataType()
var GlobalView = NewView()
Expand Down
71 changes: 44 additions & 27 deletions lib/format/pgsql8/table.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pgsql8

import (
"fmt"
"strings"

"github.com/dbsteward/dbsteward/lib/ir"
Expand All @@ -11,15 +12,9 @@ import (
"github.com/dbsteward/dbsteward/lib/output"
)

type Table struct {
IncludeColumnDefaultNextvalInCreateSql bool
}

func NewTable() *Table {
return &Table{}
}
var includeColumnDefaultNextvalInCreateSql bool

func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.ToSql {
func getCreateTableSql(schema *ir.Schema, table *ir.Table) []output.ToSql {
cols := []sql.ColumnDefinition{}
colSetup := []output.ToSql{}
for _, col := range table.Columns {
Expand All @@ -30,7 +25,7 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T
opts := []sql.TableCreateOption{}
for _, opt := range table.TableOptions {
if opt.SqlFormat == ir.SqlFormatPgsql8 {
opts = append(opts, sql.TableCreateOption{opt.Name, opt.Value})
opts = append(opts, sql.TableCreateOption{Option: opt.Name, Value: opt.Value})
}
}

Expand All @@ -45,7 +40,7 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T

ddl := []output.ToSql{
&sql.TableCreate{
Table: sql.TableRef{schema.Name, table.Name},
Table: sql.TableRef{Schema: schema.Name, Table: table.Name},
Columns: cols,
Inherits: inherits,
OtherOptions: opts,
Expand All @@ -54,7 +49,7 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T

if table.Description != "" {
ddl = append(ddl, &sql.TableSetComment{
Table: sql.TableRef{schema.Name, table.Name},
Table: sql.TableRef{Schema: schema.Name, Table: table.Name},
Comment: table.Description,
})
}
Expand All @@ -64,7 +59,7 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T
if table.Owner != "" {
role := lib.GlobalXmlParser.RoleEnum(lib.GlobalDBSteward.NewDatabase, table.Owner)
ddl = append(ddl, &sql.TableAlterOwner{
Table: sql.TableRef{schema.Name, table.Name},
Table: sql.TableRef{Schema: schema.Name, Table: table.Name},
Role: role,
})

Expand All @@ -74,7 +69,7 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T
if isSerialType(col) {
ident := buildSequenceName(schema.Name, table.Name, col.Name)
ddl = append(ddl, &sql.TableAlterOwner{
Table: sql.TableRef{schema.Name, ident},
Table: sql.TableRef{Schema: schema.Name, Table: ident},
Role: role,
})
}
Expand All @@ -84,22 +79,22 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T
return ddl
}

func (self *Table) GetDropSql(schema *ir.Schema, table *ir.Table) []output.ToSql {
func getDropTableSql(schema *ir.Schema, table *ir.Table) []output.ToSql {
return []output.ToSql{
&sql.TableDrop{
Table: sql.TableRef{schema.Name, table.Name},
Table: sql.TableRef{Schema: schema.Name, Table: table.Name},
},
}
}

func (self *Table) GetDefaultNextvalSql(schema *ir.Schema, table *ir.Table) []output.ToSql {
func getDefaultNextvalSql(schema *ir.Schema, table *ir.Table) []output.ToSql {
out := []output.ToSql{}
for _, column := range table.Columns {
if hasDefaultNextval(column) {
lib.GlobalDBSteward.Info("Specifying skipped %s.%s.%s default expression \"%s\"", schema.Name, table.Name, column.Name, column.Default)
out = append(out, &sql.Annotated{
Wrapped: &sql.ColumnSetDefault{
Column: sql.ColumnRef{schema.Name, table.Name, column.Name},
Column: sql.ColumnRef{Schema: schema.Name, Table: table.Name, Column: column.Name},
Default: sql.RawSql(column.Default),
},
Annotation: "column default nextval expression being added post table creation",
Expand All @@ -109,15 +104,15 @@ func (self *Table) GetDefaultNextvalSql(schema *ir.Schema, table *ir.Table) []ou
return out
}

func (self *Table) DefineTableColumnDefaults(schema *ir.Schema, table *ir.Table) []output.ToSql {
func defineTableColumnDefaults(schema *ir.Schema, table *ir.Table) []output.ToSql {
out := []output.ToSql{}
for _, column := range table.Columns {
out = append(out, getColumnDefaultSql(schema, table, column)...)
}
return out
}

func (self *Table) GetGrantSql(doc *ir.Definition, schema *ir.Schema, table *ir.Table, grant *ir.Grant) []output.ToSql {
func getTableGrantSql(doc *ir.Definition, schema *ir.Schema, table *ir.Table, grant *ir.Grant) []output.ToSql {
roles := make([]string, len(grant.Roles))
for i, role := range grant.Roles {
roles[i] = lib.GlobalXmlParser.RoleEnum(lib.GlobalDBSteward.NewDatabase, role)
Expand All @@ -134,7 +129,7 @@ func (self *Table) GetGrantSql(doc *ir.Definition, schema *ir.Schema, table *ir.

ddl := []output.ToSql{
&sql.TableGrant{
Table: sql.TableRef{schema.Name, table.Name},
Table: sql.TableRef{Schema: schema.Name, Table: table.Name},
Perms: []string(grant.Permissions),
Roles: roles,
},
Expand All @@ -146,7 +141,7 @@ func (self *Table) GetGrantSql(doc *ir.Definition, schema *ir.Schema, table *ir.
roRole := lib.GlobalXmlParser.RoleEnum(lib.GlobalDBSteward.NewDatabase, ir.RoleReadOnly)
if roRole != "" {
ddl = append(ddl, &sql.TableGrant{
Table: sql.TableRef{schema.Name, table.Name},
Table: sql.TableRef{Schema: schema.Name, Table: table.Name},
Perms: []string{ir.PermissionSelect},
Roles: []string{roRole},
CanGrant: false,
Expand Down Expand Up @@ -206,14 +201,36 @@ func (self *Table) GetGrantSql(doc *ir.Definition, schema *ir.Schema, table *ir.
return ddl
}

func (self *Table) GetSerialStartDml(schema *ir.Schema, table *ir.Table) []output.ToSql {
out := []output.ToSql{}
for _, column := range table.Columns {
out = append(out, getSerialStartDml(schema, table, column)...)
func getSerialStartDml(schema *ir.Schema, table *ir.Table, column *ir.Column) []output.ToSql {
if column == nil {
out := []output.ToSql{}
for _, column := range table.Columns {
out = append(out, getSerialStartDml(schema, table, column)...)
}
return out
}
return _getSerialStartDml(schema, table, column)
}

func _getSerialStartDml(schema *ir.Schema, table *ir.Table, column *ir.Column) []output.ToSql {
if column.SerialStart == nil {
return nil
}
if !isSerialType(column) {
lib.GlobalDBSteward.Fatal("Expected serial type for column %s.%s.%s because serialStart='%d' was defined, found type %s",
schema.Name, table.Name, column.Name, *column.SerialStart, column.Type)
}
return []output.ToSql{
&sql.Annotated{
Annotation: fmt.Sprintf("serialStart %d specified for %s.%s.%s", *column.SerialStart, schema.Name, table.Name, column.Name),
Wrapped: &sql.SequenceSerialSetVal{
Column: sql.ColumnRef{Schema: schema.Name, Table: table.Name, Column: column.Name},
Value: *column.SerialStart,
},
},
}
return out
}

func (self *Table) ParseStorageParams(value string) map[string]string {
func parseStorageParams(value string) map[string]string {
return util.ParseKV(value[1:len(value)-1], ",", "=")
}
5 changes: 2 additions & 3 deletions lib/format/pgsql8/table_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package pgsql8_test
package pgsql8

import (
"testing"

"github.com/dbsteward/dbsteward/lib/format/pgsql8"
"github.com/dbsteward/dbsteward/lib/format/pgsql8/sql"
"github.com/dbsteward/dbsteward/lib/ir"
"github.com/dbsteward/dbsteward/lib/output"
Expand Down Expand Up @@ -43,7 +42,7 @@ func TestTable_GetCreationSql_TableOptions(t *testing.T) {
},
}

ddl := pgsql8.GlobalTable.GetCreationSql(schema, schema.Tables[0])
ddl := getCreateTableSql(schema, schema.Tables[0])
assert.Equal(t, []output.ToSql{
&sql.TableCreate{
Table: sql.TableRef{"public", "test"},
Expand Down

0 comments on commit 230eff5

Please sign in to comment.