diff --git a/lib/format/pgsql8/column.go b/lib/format/pgsql8/column.go index b4feae2..25f0bbf 100644 --- a/lib/format/pgsql8/column.go +++ b/lib/format/pgsql8/column.go @@ -1,7 +1,6 @@ package pgsql8 import ( - "fmt" "strings" "github.com/dbsteward/dbsteward/lib" @@ -72,7 +71,7 @@ func getColumnSetupSql(schema *ir.Schema, table *ir.Table, column *ir.Column) [] } func getColumnDefaultSql(schema *ir.Schema, table *ir.Table, column *ir.Column) []output.ToSql { - if !GlobalTable.IncludeColumnDefaultNextvalInCreateSql && hasDefaultNextval(column) { + if !includeColumnDefaultNextvalInCreateSql && hasDefaultNextval(column) { // if the default is a nextval expression, don't specify it in the regular full definition // because if the sequence has not been defined yet, // the nextval expression will be evaluated inline and fail @@ -155,22 +154,3 @@ func getReferenceType(coltype string) string { // TODO(feat) should this include enum types? return coltype } - -func getSerialStartDml(schema *ir.Schema, table *ir.Table, column *ir.Column) []output.ToSql { - if column.SerialStart == nil { - return nil - } - if !isSerialType(column) { - lib.GlobalDBSteward.Fatal("Expected serial type for column %s.%s.%s because serialStart='%d' was defined, found type %s", - schema.Name, table.Name, column.Name, *column.SerialStart, column.Type) - } - return []output.ToSql{ - &sql.Annotated{ - Annotation: fmt.Sprintf("serialStart %d specified for %s.%s.%s", *column.SerialStart, schema.Name, table.Name, column.Name), - Wrapped: &sql.SequenceSerialSetVal{ - Column: sql.ColumnRef{Schema: schema.Name, Table: table.Name, Column: column.Name}, - Value: *column.SerialStart, - }, - }, - } -} diff --git a/lib/format/pgsql8/diff.go b/lib/format/pgsql8/diff.go index 37544e6..7a0b540 100644 --- a/lib/format/pgsql8/diff.go +++ b/lib/format/pgsql8/diff.go @@ -306,7 +306,7 @@ func (self *Diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 ou } for _, newGrant := range newTable.Grants { if oldTable == nil || !ir.HasPermissionsOf(oldTable, newGrant, ir.SqlFormatPgsql8) { - stage1.WriteSql(GlobalTable.GetGrantSql(newDoc, newSchema, newTable, newGrant)...) + stage1.WriteSql(getTableGrantSql(newDoc, newSchema, newTable, newGrant)...) } } } diff --git a/lib/format/pgsql8/diff_tables.go b/lib/format/pgsql8/diff_tables.go index c1ea5e3..3aeae5f 100644 --- a/lib/format/pgsql8/diff_tables.go +++ b/lib/format/pgsql8/diff_tables.go @@ -100,7 +100,7 @@ func (self *DiffTables) applyTableOptionsDiff(stage1 output.OutputFileSegmenter, if strings.EqualFold(entry.Key, "with") { // ALTER TABLE ... SET (params) doesn't accept oids=true/false unlike CREATE TABLE // only WITH OIDS or WITHOUT OIDS - params := GlobalTable.ParseStorageParams(entry.Value) + params := parseStorageParams(entry.Value) if oids, ok := params["oids"]; ok { delete(params, "oids") if util.IsTruthy(oids) { @@ -129,7 +129,7 @@ func (self *DiffTables) applyTableOptionsDiff(stage1 output.OutputFileSegmenter, for _, entry := range deleteOpts.Entries() { if strings.EqualFold(entry.Key, "with") { - params := GlobalTable.ParseStorageParams(entry.Value) + params := parseStorageParams(entry.Value) // handle oids separately since pgsql doesn't recognize it as a storage parameter in an ALTER TABLE if _, ok := params["oids"]; ok { delete(params, "oids") @@ -575,8 +575,8 @@ func (self *DiffTables) CreateTable(ofs output.OutputFileSegmenter, oldSchema, n }) } } else { - ofs.WriteSql(GlobalTable.GetCreationSql(newSchema, newTable)...) - ofs.WriteSql(GlobalTable.DefineTableColumnDefaults(newSchema, newTable)...) + ofs.WriteSql(getCreateTableSql(newSchema, newTable)...) + ofs.WriteSql(defineTableColumnDefaults(newSchema, newTable)...) } return nil } @@ -604,7 +604,7 @@ func (self *DiffTables) DropTable(ofs output.OutputFileSegmenter, oldSchema *ir. } } - ofs.WriteSql(GlobalTable.GetDropSql(oldSchema, oldTable)...) + ofs.WriteSql(getDropTableSql(oldSchema, oldTable)...) } func (self *DiffTables) DiffClusters(ofs output.OutputFileSegmenter, oldSchema, newSchema *ir.Schema) { @@ -654,7 +654,7 @@ func (self *DiffTables) GetCreateDataSql(oldSchema *ir.Schema, oldTable *ir.Tabl if oldTable == nil { // if this is a fresh build, make sure serial starts are issued _after_ the hardcoded data inserts - out = append(out, GlobalTable.GetSerialStartDml(newSchema, newTable)...) + out = append(out, getSerialStartDml(newSchema, newTable, nil)...) return out } diff --git a/lib/format/pgsql8/operations.go b/lib/format/pgsql8/operations.go index 22e98b1..c044a9f 100644 --- a/lib/format/pgsql8/operations.go +++ b/lib/format/pgsql8/operations.go @@ -881,20 +881,20 @@ func buildSchema(doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep [] // table structure creation for _, schema := range doc.Schemas { // create defined tables - GlobalTable.IncludeColumnDefaultNextvalInCreateSql = false + includeColumnDefaultNextvalInCreateSql = false for _, table := range schema.Tables { // table definition - ofs.WriteSql(GlobalTable.GetCreationSql(schema, table)...) + ofs.WriteSql(getCreateTableSql(schema, table)...) // table indexes GlobalDiffIndexes.DiffIndexesTable(ofs, nil, nil, schema, table) // table grants for _, grant := range table.Grants { - ofs.WriteSql(GlobalTable.GetGrantSql(doc, schema, table, grant)...) + ofs.WriteSql(getTableGrantSql(doc, schema, table, grant)...) } } - GlobalTable.IncludeColumnDefaultNextvalInCreateSql = true + includeColumnDefaultNextvalInCreateSql = true // sequences contained in the schema for _, sequence := range schema.Sequences { @@ -909,7 +909,7 @@ func buildSchema(doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep [] // add table nextvals that were omitted for _, table := range schema.Tables { if table.HasDefaultNextVal() { - ofs.WriteSql(GlobalTable.GetDefaultNextvalSql(schema, table)...) + ofs.WriteSql(getDefaultNextvalSql(schema, table)...) } } } @@ -933,7 +933,7 @@ func buildSchema(doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep [] for _, schema := range doc.Schemas { for _, table := range schema.Tables { // TODO(go,nth) method name consistency - should be GetColumnDefaultsSql? - ofs.WriteSql(GlobalTable.DefineTableColumnDefaults(schema, table)...) + ofs.WriteSql(defineTableColumnDefaults(schema, table)...) } } diff --git a/lib/format/pgsql8/pgsql8.go b/lib/format/pgsql8/pgsql8.go index 125c07c..5549274 100644 --- a/lib/format/pgsql8/pgsql8.go +++ b/lib/format/pgsql8/pgsql8.go @@ -4,7 +4,6 @@ import "github.com/dbsteward/dbsteward/lib/format" var GlobalOperations = NewOperations() var GlobalSchema = NewSchema() -var GlobalTable = NewTable() var GlobalTrigger = NewTrigger() var GlobalDataType = NewDataType() var GlobalView = NewView() diff --git a/lib/format/pgsql8/table.go b/lib/format/pgsql8/table.go index 6d36d60..a183811 100644 --- a/lib/format/pgsql8/table.go +++ b/lib/format/pgsql8/table.go @@ -1,6 +1,7 @@ package pgsql8 import ( + "fmt" "strings" "github.com/dbsteward/dbsteward/lib/ir" @@ -11,15 +12,9 @@ import ( "github.com/dbsteward/dbsteward/lib/output" ) -type Table struct { - IncludeColumnDefaultNextvalInCreateSql bool -} - -func NewTable() *Table { - return &Table{} -} +var includeColumnDefaultNextvalInCreateSql bool -func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.ToSql { +func getCreateTableSql(schema *ir.Schema, table *ir.Table) []output.ToSql { cols := []sql.ColumnDefinition{} colSetup := []output.ToSql{} for _, col := range table.Columns { @@ -30,7 +25,7 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T opts := []sql.TableCreateOption{} for _, opt := range table.TableOptions { if opt.SqlFormat == ir.SqlFormatPgsql8 { - opts = append(opts, sql.TableCreateOption{opt.Name, opt.Value}) + opts = append(opts, sql.TableCreateOption{Option: opt.Name, Value: opt.Value}) } } @@ -45,7 +40,7 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T ddl := []output.ToSql{ &sql.TableCreate{ - Table: sql.TableRef{schema.Name, table.Name}, + Table: sql.TableRef{Schema: schema.Name, Table: table.Name}, Columns: cols, Inherits: inherits, OtherOptions: opts, @@ -54,7 +49,7 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T if table.Description != "" { ddl = append(ddl, &sql.TableSetComment{ - Table: sql.TableRef{schema.Name, table.Name}, + Table: sql.TableRef{Schema: schema.Name, Table: table.Name}, Comment: table.Description, }) } @@ -64,7 +59,7 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T if table.Owner != "" { role := lib.GlobalXmlParser.RoleEnum(lib.GlobalDBSteward.NewDatabase, table.Owner) ddl = append(ddl, &sql.TableAlterOwner{ - Table: sql.TableRef{schema.Name, table.Name}, + Table: sql.TableRef{Schema: schema.Name, Table: table.Name}, Role: role, }) @@ -74,7 +69,7 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T if isSerialType(col) { ident := buildSequenceName(schema.Name, table.Name, col.Name) ddl = append(ddl, &sql.TableAlterOwner{ - Table: sql.TableRef{schema.Name, ident}, + Table: sql.TableRef{Schema: schema.Name, Table: ident}, Role: role, }) } @@ -84,22 +79,22 @@ func (self *Table) GetCreationSql(schema *ir.Schema, table *ir.Table) []output.T return ddl } -func (self *Table) GetDropSql(schema *ir.Schema, table *ir.Table) []output.ToSql { +func getDropTableSql(schema *ir.Schema, table *ir.Table) []output.ToSql { return []output.ToSql{ &sql.TableDrop{ - Table: sql.TableRef{schema.Name, table.Name}, + Table: sql.TableRef{Schema: schema.Name, Table: table.Name}, }, } } -func (self *Table) GetDefaultNextvalSql(schema *ir.Schema, table *ir.Table) []output.ToSql { +func getDefaultNextvalSql(schema *ir.Schema, table *ir.Table) []output.ToSql { out := []output.ToSql{} for _, column := range table.Columns { if hasDefaultNextval(column) { lib.GlobalDBSteward.Info("Specifying skipped %s.%s.%s default expression \"%s\"", schema.Name, table.Name, column.Name, column.Default) out = append(out, &sql.Annotated{ Wrapped: &sql.ColumnSetDefault{ - Column: sql.ColumnRef{schema.Name, table.Name, column.Name}, + Column: sql.ColumnRef{Schema: schema.Name, Table: table.Name, Column: column.Name}, Default: sql.RawSql(column.Default), }, Annotation: "column default nextval expression being added post table creation", @@ -109,7 +104,7 @@ func (self *Table) GetDefaultNextvalSql(schema *ir.Schema, table *ir.Table) []ou return out } -func (self *Table) DefineTableColumnDefaults(schema *ir.Schema, table *ir.Table) []output.ToSql { +func defineTableColumnDefaults(schema *ir.Schema, table *ir.Table) []output.ToSql { out := []output.ToSql{} for _, column := range table.Columns { out = append(out, getColumnDefaultSql(schema, table, column)...) @@ -117,7 +112,7 @@ func (self *Table) DefineTableColumnDefaults(schema *ir.Schema, table *ir.Table) return out } -func (self *Table) GetGrantSql(doc *ir.Definition, schema *ir.Schema, table *ir.Table, grant *ir.Grant) []output.ToSql { +func getTableGrantSql(doc *ir.Definition, schema *ir.Schema, table *ir.Table, grant *ir.Grant) []output.ToSql { roles := make([]string, len(grant.Roles)) for i, role := range grant.Roles { roles[i] = lib.GlobalXmlParser.RoleEnum(lib.GlobalDBSteward.NewDatabase, role) @@ -134,7 +129,7 @@ func (self *Table) GetGrantSql(doc *ir.Definition, schema *ir.Schema, table *ir. ddl := []output.ToSql{ &sql.TableGrant{ - Table: sql.TableRef{schema.Name, table.Name}, + Table: sql.TableRef{Schema: schema.Name, Table: table.Name}, Perms: []string(grant.Permissions), Roles: roles, }, @@ -146,7 +141,7 @@ func (self *Table) GetGrantSql(doc *ir.Definition, schema *ir.Schema, table *ir. roRole := lib.GlobalXmlParser.RoleEnum(lib.GlobalDBSteward.NewDatabase, ir.RoleReadOnly) if roRole != "" { ddl = append(ddl, &sql.TableGrant{ - Table: sql.TableRef{schema.Name, table.Name}, + Table: sql.TableRef{Schema: schema.Name, Table: table.Name}, Perms: []string{ir.PermissionSelect}, Roles: []string{roRole}, CanGrant: false, @@ -206,14 +201,36 @@ func (self *Table) GetGrantSql(doc *ir.Definition, schema *ir.Schema, table *ir. return ddl } -func (self *Table) GetSerialStartDml(schema *ir.Schema, table *ir.Table) []output.ToSql { - out := []output.ToSql{} - for _, column := range table.Columns { - out = append(out, getSerialStartDml(schema, table, column)...) +func getSerialStartDml(schema *ir.Schema, table *ir.Table, column *ir.Column) []output.ToSql { + if column == nil { + out := []output.ToSql{} + for _, column := range table.Columns { + out = append(out, getSerialStartDml(schema, table, column)...) + } + return out + } + return _getSerialStartDml(schema, table, column) +} + +func _getSerialStartDml(schema *ir.Schema, table *ir.Table, column *ir.Column) []output.ToSql { + if column.SerialStart == nil { + return nil + } + if !isSerialType(column) { + lib.GlobalDBSteward.Fatal("Expected serial type for column %s.%s.%s because serialStart='%d' was defined, found type %s", + schema.Name, table.Name, column.Name, *column.SerialStart, column.Type) + } + return []output.ToSql{ + &sql.Annotated{ + Annotation: fmt.Sprintf("serialStart %d specified for %s.%s.%s", *column.SerialStart, schema.Name, table.Name, column.Name), + Wrapped: &sql.SequenceSerialSetVal{ + Column: sql.ColumnRef{Schema: schema.Name, Table: table.Name, Column: column.Name}, + Value: *column.SerialStart, + }, + }, } - return out } -func (self *Table) ParseStorageParams(value string) map[string]string { +func parseStorageParams(value string) map[string]string { return util.ParseKV(value[1:len(value)-1], ",", "=") } diff --git a/lib/format/pgsql8/table_test.go b/lib/format/pgsql8/table_test.go index ad1dc3c..77d4171 100644 --- a/lib/format/pgsql8/table_test.go +++ b/lib/format/pgsql8/table_test.go @@ -1,9 +1,8 @@ -package pgsql8_test +package pgsql8 import ( "testing" - "github.com/dbsteward/dbsteward/lib/format/pgsql8" "github.com/dbsteward/dbsteward/lib/format/pgsql8/sql" "github.com/dbsteward/dbsteward/lib/ir" "github.com/dbsteward/dbsteward/lib/output" @@ -43,7 +42,7 @@ func TestTable_GetCreationSql_TableOptions(t *testing.T) { }, } - ddl := pgsql8.GlobalTable.GetCreationSql(schema, schema.Tables[0]) + ddl := getCreateTableSql(schema, schema.Tables[0]) assert.Equal(t, []output.ToSql{ &sql.TableCreate{ Table: sql.TableRef{"public", "test"},