Skip to content

Commit

Permalink
Remove GlobalFunctions, refactor it into standalone functions, and so…
Browse files Browse the repository at this point in the history
…me general cleanup of that code
  • Loading branch information
williammoran committed Apr 25, 2024
1 parent af30da3 commit ae512a2
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 30 deletions.
2 changes: 1 addition & 1 deletion lib/format/pgsql8/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ func (self *Diff) updatePermissions(stage1 output.OutputFileSegmenter, stage3 ou
oldFunc := oldSchema.TryGetFunctionMatching(newFunc)
for _, newGrant := range newFunc.Grants {
if oldFunc == nil || !ir.HasPermissionsOf(oldFunc, newGrant, ir.SqlFormatPgsql8) {
stage1.WriteSql(GlobalFunction.GetGrantSql(newDoc, newSchema, newFunc, newGrant)...)
stage1.WriteSql(getFunctionGrantSql(newSchema, newFunc, newGrant)...)
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions lib/format/pgsql8/diff_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (self *DiffFunctions) DiffFunctions(stage1 output.OutputFileSegmenter, stag
if oldSchema != nil {
for _, oldFunction := range oldSchema.Functions {
if newSchema.TryGetFunctionMatching(oldFunction) == nil {
stage3.WriteSql(GlobalFunction.GetDropSql(oldSchema, oldFunction)...)
stage3.WriteSql(getFunctionDropSql(oldSchema, oldFunction)...)
}
}
}
Expand All @@ -27,16 +27,16 @@ func (self *DiffFunctions) DiffFunctions(stage1 output.OutputFileSegmenter, stag
for _, newFunction := range newSchema.Functions {
oldFunction := oldSchema.TryGetFunctionMatching(newFunction)
if oldFunction == nil || !oldFunction.Equals(newFunction, ir.SqlFormatPgsql8) {
stage1.WriteSql(GlobalFunction.GetCreationSql(newSchema, newFunction)...)
stage1.WriteSql(getFunctionCreationSql(newSchema, newFunction)...)
} else if newFunction.ForceRedefine {
stage1.WriteSql(sql.NewComment("Function %s.%s has forceRedefine set to true", newSchema.Name, newFunction.Name))
stage1.WriteSql(GlobalFunction.GetCreationSql(newSchema, newFunction)...)
stage1.WriteSql(getFunctionCreationSql(newSchema, newFunction)...)
} else {
oldReturnType := oldSchema.TryGetTypeNamed(newFunction.Returns)
newReturnType := newSchema.TryGetTypeNamed(newFunction.Returns)
if oldReturnType != nil && newReturnType != nil && !oldReturnType.Equals(newReturnType) {
stage1.WriteSql(sql.NewComment("Function %s.%s return type %s has changed", newSchema.Name, newFunction.Name, newReturnType.Name))
stage1.WriteSql(GlobalFunction.GetCreationSql(newSchema, newFunction)...)
stage1.WriteSql(getFunctionCreationSql(newSchema, newFunction)...)
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions lib/format/pgsql8/diff_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (self *DiffTypes) DiffTypes(ofs output.OutputFileSegmenter, oldSchema *ir.S
"Type migration of %s.%s requires recreating dependent function %s.%s",
newSchema.Name, newType.Name, oldSchema.Name, oldFunc.Name,
))
ofs.WriteSql(GlobalFunction.GetDropSql(oldSchema, oldFunc)...)
ofs.WriteSql(getFunctionDropSql(oldSchema, oldFunc)...)
}

columns, sql := GlobalDataType.AlterColumnTypePlaceholder(oldSchema, oldType)
Expand All @@ -59,7 +59,7 @@ func (self *DiffTypes) DiffTypes(ofs output.OutputFileSegmenter, oldSchema *ir.S

// functions are only recreated if they changed elsewise, so need to create them here
for _, newFunc := range GlobalSchema.GetFunctionsDependingOnType(newSchema, newType) {
ofs.WriteSql(GlobalFunction.GetCreationSql(newSchema, newFunc)...)
ofs.WriteSql(getFunctionCreationSql(newSchema, newFunc)...)
}

ofs.WriteSql(GlobalDataType.AlterColumnTypeRestore(columns, newSchema, newType)...)
Expand Down
33 changes: 15 additions & 18 deletions lib/format/pgsql8/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,7 @@ import (
"github.com/dbsteward/dbsteward/lib/util"
)

type Function struct {
IncludeColumnDefaultNextvalInCreateSql bool
}

func NewFunction() *Function {
return &Function{}
}

func (self *Function) DefinitionReferencesTable(definition *ir.FunctionDefinition) *lib.QualifiedTable {
func functionDefinitionReferencesTable(definition *ir.FunctionDefinition) *lib.QualifiedTable {
// TODO(feat) a function could reference many tables, but this only returns the first; make it understand many tables
// TODO(feat) this won't detect quoted table names
// TODO(go,pgsql) test this
Expand All @@ -42,8 +34,8 @@ func (self *Function) DefinitionReferencesTable(definition *ir.FunctionDefinitio
return &parsed
}

func (self *Function) GetCreationSql(schema *ir.Schema, function *ir.Function) []output.ToSql {
ref := sql.FunctionRef{schema.Name, function.Name, function.ParamSigs()}
func getFunctionCreationSql(schema *ir.Schema, function *ir.Function) []output.ToSql {
ref := sql.FunctionRef{Schema: schema.Name, Function: function.Name, Params: function.ParamSigs()}
def := function.TryGetDefinition(ir.SqlFormatPgsql8)
out := []output.ToSql{
&sql.FunctionCreate{
Expand Down Expand Up @@ -72,27 +64,32 @@ func (self *Function) GetCreationSql(schema *ir.Schema, function *ir.Function) [
return out
}

func (self *Function) GetDropSql(schema *ir.Schema, function *ir.Function) []output.ToSql {
func getFunctionDropSql(schema *ir.Schema, function *ir.Function) []output.ToSql {
types := function.ParamTypes()
for i, paramType := range types {
// TODO(feat) there's evidence in get_drop_sql that postgres only recognizes the normalized typenames here.
// we should look for other cases and validate behavior
types[i] = self.normalizeParameterType(paramType)
types[i] = normalizeFunctionParameterType(paramType)
}

return []output.ToSql{
&sql.FunctionDrop{sql.FunctionRef{schema.Name, function.Name, types}},
&sql.FunctionDrop{
Function: sql.FunctionRef{
Schema: schema.Name,
Function: function.Name,
Params: types,
}},
}
}

func (self *Function) normalizeParameterType(paramType string) string {
func normalizeFunctionParameterType(paramType string) string {
if strings.EqualFold(paramType, "character varying") || strings.EqualFold(paramType, "varying") {
return "varchar"
}
return paramType
}

func (self *Function) GetGrantSql(doc *ir.Definition, schema *ir.Schema, fn *ir.Function, grant *ir.Grant) []output.ToSql {
func getFunctionGrantSql(schema *ir.Schema, fn *ir.Function, grant *ir.Grant) []output.ToSql {
roles := make([]string, len(grant.Roles))
for i, role := range grant.Roles {
roles[i] = lib.GlobalXmlParser.RoleEnum(lib.GlobalDBSteward.NewDatabase, role)
Expand All @@ -109,7 +106,7 @@ func (self *Function) GetGrantSql(doc *ir.Definition, schema *ir.Schema, fn *ir.

ddl := []output.ToSql{
&sql.FunctionGrant{
Function: sql.FunctionRef{schema.Name, fn.Name, fn.ParamTypes()},
Function: sql.FunctionRef{Schema: schema.Name, Function: fn.Name, Params: fn.ParamTypes()},
Perms: []string(grant.Permissions),
Roles: roles,
CanGrant: grant.CanGrant(),
Expand All @@ -122,7 +119,7 @@ func (self *Function) GetGrantSql(doc *ir.Definition, schema *ir.Schema, fn *ir.
}

// TODO(go,3) move this to model
func (self *Function) FunctionDependsOnType(fn *ir.Function, typeSchema *ir.Schema, datatype *ir.DataType) bool {
func functionDependsOnType(fn *ir.Function, typeSchema *ir.Schema, datatype *ir.DataType) bool {
// TODO(feat) what about composite/domain types that are also dependent on the type? further refinement needed
qualifiedName := typeSchema.Name + "." + datatype.Name
returns := strings.TrimRight(fn.Returns, "[] ") // allow for arrays
Expand Down
6 changes: 3 additions & 3 deletions lib/format/pgsql8/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ outer:
for _, function := range schema.Functions {
if definition := function.TryGetDefinition(ir.SqlFormatPgsql8); definition != nil {
if strings.EqualFold(definition.Language, "sql") {
referenced := GlobalFunction.DefinitionReferencesTable(definition)
referenced := functionDefinitionReferencesTable(definition)
if referenced == nil {
continue
}
Expand Down Expand Up @@ -918,12 +918,12 @@ func buildSchema(doc *ir.Definition, ofs output.OutputFileSegmenter, tableDep []
for _, schema := range doc.Schemas {
for _, function := range schema.Functions {
if function.HasDefinition(ir.SqlFormatPgsql8) {
ofs.WriteSql(GlobalFunction.GetCreationSql(schema, function)...)
ofs.WriteSql(getFunctionCreationSql(schema, function)...)
// when pg:build_schema() is doing its thing for straight builds, include function permissions
// they are not included in pg_function::get_creation_sql()

for _, grant := range function.Grants {
ofs.WriteSql(GlobalFunction.GetGrantSql(doc, schema, function, grant)...)
ofs.WriteSql(getFunctionGrantSql(schema, function, grant)...)
}
}
}
Expand Down
1 change: 0 additions & 1 deletion lib/format/pgsql8/pgsql8.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package pgsql8
import "github.com/dbsteward/dbsteward/lib/format"

var GlobalOperations = NewOperations()
var GlobalFunction = NewFunction()
var GlobalIndex = NewIndex()
var GlobalLanguage = NewLanguage()
var GlobalPermission = NewPermission()
Expand Down
2 changes: 1 addition & 1 deletion lib/format/pgsql8/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (self *Schema) GetGrantSql(doc *ir.Definition, schema *ir.Schema, grant *ir
func (self *Schema) GetFunctionsDependingOnType(schema *ir.Schema, datatype *ir.DataType) []*ir.Function {
out := []*ir.Function{}
for _, fn := range schema.Functions {
if GlobalFunction.FunctionDependsOnType(fn, schema, datatype) {
if functionDependsOnType(fn, schema, datatype) {
out = append(out, fn)
}
}
Expand Down

0 comments on commit ae512a2

Please sign in to comment.