Skip to content

Commit

Permalink
fix(expressions): Use compound names in protobufs (#24)
Browse files (browse the repository at this point in the history)
* Add expression plan roundtrip test

* Use compound names for variants, and keep simple and compound names separate

* Do not add the same URI twice

* Update tests - protobufs should refer to compound names

* Output extensions in anchor order

* Use type parser to parse types
  • Loading branch information
wackywendell committed Aug 17, 2023
1 parent 082cc2b commit f120601
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 62 deletions.
2 changes: 0 additions & 2 deletions expr/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ func ExprFromProto(e *proto.Expression, baseSchema types.Type, reg ExtensionRegi
return &ScalarFunction{
funcRef: et.ScalarFunction.FunctionReference,
declaration: decl,
id: id,
args: args,
options: et.ScalarFunction.Options,
outputType: types.TypeFromProto(et.ScalarFunction.OutputType),
Expand Down Expand Up @@ -126,7 +125,6 @@ func ExprFromProto(e *proto.Expression, baseSchema types.Type, reg ExtensionRegi

return &WindowFunction{
funcRef: et.WindowFunction.FunctionReference,
id: id,
declaration: decl,
args: args,
options: et.WindowFunction.Options,
Expand Down
20 changes: 10 additions & 10 deletions expr/expressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func ExampleExpression_scalarFunction() {
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 2,
"name": "add"
"name": "add:i32_i32"
}
}
],
Expand Down Expand Up @@ -105,7 +105,7 @@ func ExampleExpression_scalarFunction() {
// having to construct the protobuf
const substraitext = `https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml`

var addVariant = ext.NewScalarFuncVariant(ext.ID{URI: substraitext, Name: "add"})
var addVariant = ext.NewScalarFuncVariant(ext.ID{URI: substraitext, Name: "add:i32_i32"})

var ex expr.Expression
refArg, _ := expr.NewRootFieldRef(expr.NewStructFieldRef(0), &types.StructType{Types: []types.Type{&types.Int32Type{}}})
Expand Down Expand Up @@ -186,28 +186,28 @@ func TestExpressionsRoundtrip(t *testing.T) {
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 2,
"name": "add"
"name": "add:fp64_fp64"
}
},
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 3,
"name": "subtract"
"name": "subtract:fp32_fp32"
}
},
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 4,
"name": "multiply"
"name": "multiply:i64_i64"
}
},
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 5,
"name": "ntile"
"name": "ntile:"
}
}
],
Expand Down Expand Up @@ -292,28 +292,28 @@ func TestRoundTripUsingTestData(t *testing.T) {
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 2,
"name": "add"
"name": "add:fp64_fp64"
}
},
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 3,
"name": "subtract"
"name": "subtract:fp64_fp64"
}
},
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 4,
"name": "multiply"
"name": "multiply:fp64_fp64"
}
},
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 5,
"name": "ntile"
"name": "ntile:i32"
}
}
],
Expand Down
58 changes: 30 additions & 28 deletions expr/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ func BoundFromProto(b *proto.Expression_WindowFunction_Bound) Bound {

type ScalarFunction struct {
funcRef uint32
id extensions.ID
declaration *extensions.ScalarFunctionVariant

args []types.FuncArg
Expand All @@ -163,10 +162,8 @@ func NewCustomScalarFunc(reg ExtensionRegistry, v *extensions.ScalarFunctionVari
return nil, fmt.Errorf("%w: must provide non-nil output type", substraitgo.ErrInvalidType)
}

id := extensions.ID{URI: v.URI(), Name: v.Name()}
return &ScalarFunction{
funcRef: reg.GetFuncAnchor(id),
id: id,
funcRef: reg.GetFuncAnchor(v.ID()),
declaration: v,
options: opts,
args: args,
Expand Down Expand Up @@ -246,18 +243,22 @@ func NewScalarFunc(reg ExtensionRegistry, id extensions.ID, opts []*types.Functi
return nil, err
}

// We use the fully qualified ID for resolving an anchor, to make sure we
// are using the correct compound name
ref := reg.GetFuncAnchor(decl.ID())

return &ScalarFunction{
funcRef: reg.GetFuncAnchor(id),
id: id,
funcRef: ref,
declaration: decl,
outputType: outType,
options: opts,
args: args,
}, nil
}

func (s *ScalarFunction) Name() string { return s.declaration.CompoundName() }
func (s *ScalarFunction) ID() extensions.ID { return s.id }
func (s *ScalarFunction) Name() string { return s.declaration.Name() }
func (s *ScalarFunction) CompoundName() string { return s.declaration.CompoundName() }
func (s *ScalarFunction) ID() extensions.ID { return s.declaration.ID() }
func (s *ScalarFunction) Variadic() *extensions.VariadicBehavior { return s.declaration.Variadic() }
func (s *ScalarFunction) SessionDependant() bool { return s.declaration.SessionDependent() }
func (s *ScalarFunction) Deterministic() bool { return s.declaration.Deterministic() }
Expand All @@ -280,7 +281,7 @@ func (*ScalarFunction) isRootRef() {}
func (s *ScalarFunction) String() string {
var b strings.Builder

b.WriteString(s.id.Name)
b.WriteString(s.Name())
b.WriteByte('(')

for i, arg := range s.args {
Expand Down Expand Up @@ -408,7 +409,6 @@ func (s *ScalarFunction) Visit(visit VisitFunc) Expression {

type WindowFunction struct {
funcRef uint32
id extensions.ID
declaration *extensions.WindowFunctionVariant

args []types.FuncArg
Expand All @@ -428,11 +428,9 @@ func NewCustomWindowFunc(reg ExtensionRegistry, v *extensions.WindowFunctionVari
return nil, fmt.Errorf("%w: must provide non-nil output type", substraitgo.ErrInvalidExpr)
}

id := extensions.ID{URI: v.URI(), Name: v.Name()}
return &WindowFunction{
funcRef: reg.GetFuncAnchor(id),
funcRef: reg.GetFuncAnchor(v.ID()),
declaration: v,
id: id,
outputType: outputType,
options: opts,
args: args,
Expand All @@ -452,9 +450,12 @@ func NewWindowFunc(reg ExtensionRegistry, id extensions.ID, opts []*types.Functi
substraitgo.ErrInvalidExpr, id)
}

// We use the fully qualified ID for resolving an anchor, to make sure we
// are using the correct compound name
ref := reg.GetFuncAnchor(decl.ID())

return &WindowFunction{
funcRef: reg.GetFuncAnchor(id),
id: id,
funcRef: ref,
declaration: decl,
outputType: outType,
options: opts,
Expand All @@ -464,8 +465,9 @@ func NewWindowFunc(reg ExtensionRegistry, id extensions.ID, opts []*types.Functi
}, nil
}

func (w *WindowFunction) Name() string { return w.declaration.CompoundName() }
func (w *WindowFunction) ID() extensions.ID { return w.id }
func (w *WindowFunction) Name() string { return w.declaration.Name() }
func (w *WindowFunction) CompoundName() string { return w.declaration.CompoundName() }
func (w *WindowFunction) ID() extensions.ID { return w.declaration.ID() }
func (w *WindowFunction) Variadic() *extensions.VariadicBehavior { return w.declaration.Variadic() }
func (w *WindowFunction) SessionDependant() bool { return w.declaration.SessionDependent() }
func (w *WindowFunction) Deterministic() bool { return w.declaration.Deterministic() }
Expand All @@ -487,7 +489,7 @@ func (*WindowFunction) isRootRef() {}
func (w *WindowFunction) String() string {
var b strings.Builder

b.WriteString(w.id.Name)
b.WriteString(w.declaration.Name())
b.WriteByte('(')

for i, arg := range w.args {
Expand Down Expand Up @@ -667,7 +669,6 @@ func (w *WindowFunction) Visit(visit VisitFunc) Expression {

type AggregateFunction struct {
funcRef uint32
id extensions.ID
declaration *extensions.AggregateFunctionVariant

args []types.FuncArg
Expand All @@ -684,9 +685,12 @@ func NewAggregateFunc(reg ExtensionRegistry, id extensions.ID, opts []*types.Fun
return nil, err
}

// We use the fully qualified ID for resolving an anchor, to make sure we
// are using the correct compound name
ref := reg.GetFuncAnchor(decl.ID())

return &AggregateFunction{
funcRef: reg.GetFuncAnchor(id),
id: id,
funcRef: ref,
declaration: decl,
outputType: outType,
options: opts,
Expand All @@ -702,10 +706,8 @@ func NewCustomAggregateFunc(reg ExtensionRegistry, v *extensions.AggregateFuncti
return nil, fmt.Errorf("%w: must provide non-nil output type", substraitgo.ErrInvalidExpr)
}

id := extensions.ID{URI: v.URI(), Name: v.Name()}
return &AggregateFunction{
funcRef: reg.GetFuncAnchor(id),
id: id,
funcRef: reg.GetFuncAnchor(v.ID()),
outputType: outputType,
options: opts,
args: args,
Expand Down Expand Up @@ -746,7 +748,6 @@ func NewAggregateFunctionFromProto(agg *proto.AggregateFunction, baseSchema type

return &AggregateFunction{
funcRef: agg.FunctionReference,
id: id,
declaration: decl,
args: args,
options: agg.Options,
Expand All @@ -757,8 +758,9 @@ func NewAggregateFunctionFromProto(agg *proto.AggregateFunction, baseSchema type
}, nil
}

func (a *AggregateFunction) Name() string { return a.declaration.CompoundName() }
func (a *AggregateFunction) ID() extensions.ID { return a.id }
func (a *AggregateFunction) Name() string { return a.declaration.Name() }
func (a *AggregateFunction) CompoundName() string { return a.declaration.CompoundName() }
func (a *AggregateFunction) ID() extensions.ID { return a.declaration.ID() }
func (a *AggregateFunction) Variadic() *extensions.VariadicBehavior { return a.declaration.Variadic() }
func (a *AggregateFunction) SessionDependant() bool { return a.declaration.SessionDependent() }
func (a *AggregateFunction) Deterministic() bool { return a.declaration.Deterministic() }
Expand All @@ -778,7 +780,7 @@ func (a *AggregateFunction) IntermediateType() (types.Type, error) {
func (a *AggregateFunction) String() string {
var b strings.Builder

b.WriteString(a.id.Name)
b.WriteString(a.declaration.Name())
b.WriteByte('(')

for i, arg := range a.args {
Expand Down
14 changes: 7 additions & 7 deletions expr/testdata/extended_exprs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@ tests:
- extensionFunction:
extensionUriReference: 1
functionAnchor: 2
name: add
name: add:i64_i64
- extensionFunction:
extensionUriReference: 1
functionAnchor: 3
name: subtract
name: subtract:i64_i64
- extensionFunction:
extensionUriReference: 1
functionAnchor: 4
name: multiply
name: multiply:i64_i64
- extensionFunction:
extensionUriReference: 1
functionAnchor: 5
name: ntile
name: ntile:i64_i64
- extensionFunction:
extensionUriReference: 1
functionAnchor: 6
name: sum
name: sum:i64
baseSchema:
names: [a, b, c, d]
struct:
Expand All @@ -50,13 +50,13 @@ tests:
expression:
scalarFunction:
functionReference: 2
arguments:
arguments:
- value:
selection:
rootReference: {}
directReference: { structField: { field: 1 }}
- value:
selection:
rootReference: {}
directReference: { structField: { field: 2 }}
directReference: { structField: { field: 2 }}
outputType: { i32: {} }
Loading

0 comments on commit f120601

Please sign in to comment.