Skip to content

Commit

Permalink
chore(extensions): Minor refactoring in extension_mgr.go (#45)
Browse files Browse the repository at this point in the history
* Minor refactoring in extension_mgr.go
  • Loading branch information
scgkiran committed Aug 19, 2024
1 parent 5556c23 commit cbd28cb
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 15 deletions.
28 changes: 13 additions & 15 deletions extensions/extension_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ func init() {
if err != nil {
panic(err)
}
_ = f.Close()
if err := f.Close(); err != nil {
panic(err)
}
}
}

Expand Down Expand Up @@ -203,27 +205,23 @@ func (c *Collection) URILoaded(uri string) bool {
}

func (c *Collection) GetAllScalarFunctions() []*ScalarFunctionVariant {
scalarFunctions := make([]*ScalarFunctionVariant, 0, len(c.scalarMap))
for _, v := range c.scalarMap {
scalarFunctions = append(scalarFunctions, v)
}
return scalarFunctions
return getValues(c.scalarMap)
}

func (c *Collection) GetAllAggregateFunctions() []*AggregateFunctionVariant {
aggregateFunctions := make([]*AggregateFunctionVariant, 0, len(c.aggregateMap))
for _, v := range c.aggregateMap {
aggregateFunctions = append(aggregateFunctions, v)
}
return aggregateFunctions
return getValues(c.aggregateMap)
}

func (c *Collection) GetAllWindowFunctions() []*WindowFunctionVariant {
windowFunctions := make([]*WindowFunctionVariant, 0, len(c.windowMap))
for _, v := range c.windowMap {
windowFunctions = append(windowFunctions, v)
return getValues(c.windowMap)
}

func getValues[M ~map[K]V, K comparable, V any](m M) []V {
result := make([]V, 0, len(m))
for _, v := range m {
result = append(result, v)
}
return windowFunctions
return result
}

type Set interface {
Expand Down
41 changes: 41 additions & 0 deletions extensions/extension_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,44 @@ func TestDefaultCollection(t *testing.T) {
assert.Equal(t, "point", et.Name)
assert.Equal(t, map[string]interface{}{"latitude": "i32", "longitude": "i32"}, et.Structure)
}

func TestCollection_GetAllScalarFunctions(t *testing.T) {
scalarFunctions := extensions.DefaultCollection.GetAllScalarFunctions()
aggregateFunctions := extensions.DefaultCollection.GetAllAggregateFunctions()
windowFunctions := extensions.DefaultCollection.GetAllWindowFunctions()
assert.GreaterOrEqual(t, len(scalarFunctions), 309)
assert.GreaterOrEqual(t, len(aggregateFunctions), 62)
assert.GreaterOrEqual(t, len(windowFunctions), 7)
tests := []struct {
uri string
signature string
isScalar bool
isAggregate bool
isWindow bool
}{
{extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", "add:i32_i32", true, false, false},
{extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", "variance:fp64", false, true, false},
{extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", "dense_rank:", false, false, true},
}
for _, tt := range tests {
t.Run(tt.signature, func(t *testing.T) {
assert.True(t, tt.isScalar || tt.isAggregate || tt.isWindow)
c := extensions.DefaultCollection
if tt.isScalar {
sf, ok := c.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, scalarFunctions, sf)
}
if tt.isAggregate {
af, ok := c.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, aggregateFunctions, af)
}
if tt.isWindow {
wf, ok := c.GetWindowFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, windowFunctions, wf)
}
})
}
}

0 comments on commit cbd28cb

Please sign in to comment.