diff --git a/plan/builders.go b/plan/builders.go index f90574d..dbba96a 100644 --- a/plan/builders.go +++ b/plan/builders.go @@ -188,7 +188,7 @@ func (b *builder) ProjectRemap(input Rel, remap []int32, exprs ...expr.Expressio return nil, fmt.Errorf("%w: must provide at least one expression for project relation", substraitgo.ErrInvalidRel) } - noutput := int32(len(input.Remap(input.RecordType()).Types)) + noutput := int32(len(input.Remap(input.RecordType()).Types) + len(exprs)) for _, idx := range remap { if idx < 0 || idx >= noutput { return nil, errOutputMappingOutOfRange @@ -228,7 +228,7 @@ func (b *builder) AggregateColumnsRemap(input Rel, remap []int32, measures []Agg exprs[i] = []expr.Expression{ref} } - noutput := int32(len(input.Remap(input.RecordType()).Types)) + noutput := int32(len(measures) + len(groupByCols)) for _, idx := range remap { if idx < 0 || idx >= noutput { return nil, errOutputMappingOutOfRange @@ -263,7 +263,7 @@ func (b *builder) AggregateExprsRemap(input Rel, remap []int32, measures []AggRe return nil, fmt.Errorf("%w: groupings cannot contain empty expression list or nil expression", substraitgo.ErrInvalidRel) } - noutput := int32(len(input.Remap(input.RecordType()).Types)) + noutput := int32(len(measures) + len(groups)) for _, idx := range remap { if idx < 0 || idx >= noutput { return nil, errOutputMappingOutOfRange diff --git a/plan/plan_builder_test.go b/plan/plan_builder_test.go index 5cd0a58..d5c5ff9 100644 --- a/plan/plan_builder_test.go +++ b/plan/plan_builder_test.go @@ -163,8 +163,8 @@ func TestAggregateRelPlan(t *testing.T) { "groupings": [ { "groupingExpressions": [ - { - "selection": { + { + "selection": { "rootReference": {}, "directReference": { "structField": { "field": 0 }} } @@ -185,7 +185,7 @@ func TestAggregateRelPlan(t *testing.T) { "invocation": "AGGREGATION_INVOCATION_ALL" } } - ] + ] } }, "names": ["val", "cnt"] @@ -264,6 +264,16 @@ func TestAggregateRelErrors(t *testing.T) { _, err = b.AggregateExprsRemap(scan, []int32{5, -1}, nil, []expr.Expression{ref}) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + + _, err = b.AggregateExprsRemap(scan, []int32{1}, nil, []expr.Expression{ref}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + + _, err = b.AggregateExprsRemap(scan, []int32{0}, nil, []expr.Expression{ref}) + assert.NoError(t, err) + _, err = b.AggregateColumnsRemap(scan, []int32{0}, nil, 0) + assert.NoError(t, err) + } func TestCrossRel(t *testing.T) { @@ -357,10 +367,14 @@ func TestCrossRelErrors(t *testing.T) { _, err = b.CrossRemap(left, right, []int32{5}) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + + // Output is length 2 + 2 + _, err = b.CrossRemap(left, right, []int32{2, 3}) + assert.NoError(t, err) } func TestFetchRel(t *testing.T) { - const expectedJSON = `{ + const expectedJSON = `{ ` + versionStruct + `, "relations": [ { @@ -895,7 +909,7 @@ func TestSortRelationKeyEqual(t *testing.T) { func TestSortRelationMultiple(t *testing.T) { const expectedJSON = `{ - ` + versionStruct + `, + ` + versionStruct + `, "relations": [ { "root": { @@ -1135,6 +1149,9 @@ func TestProjectErrors(t *testing.T) { _, err = b.ProjectRemap(scan, []int32{3}, ref) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + + _, err = b.ProjectRemap(scan, []int32{2}, ref) + assert.NoError(t, err, "Expected expression mapping to be in-bounds") } func TestSetRelations(t *testing.T) {