Skip to content

Commit

Permalink
[GLUTEN-842][VL] convert expand op to expand exec in velox (apache#1361)
Browse files Browse the repository at this point in the history
* init change

* convert expand op to expand exec in velox

* add pre-project & add ut

* minor change

* fix ut

* update algebra.proto

* fix build

* fix build

* fix build

* add ut

* revert velox branch

---------

Co-authored-by: zhli1142015 <zhli@pczhlich.fareast.corp.microsoft.com>
  • Loading branch information
zhli1142015 and zhli1142015 authored Apr 19, 2023
1 parent 8950c95 commit 9a25d75
Show file tree
Hide file tree
Showing 15 changed files with 613 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ object VeloxBackendSettings extends BackendSettings {

override def supportExpandExec(): Boolean = true
override def needProjectExpandOutput: Boolean = true

override def supportNewExpandContract(): Boolean = true
override def supportSortExec(): Boolean = true

override def supportWindowExec(windowFunctions: Seq[NamedExpression]): Boolean = {
Expand Down
6 changes: 3 additions & 3 deletions cpp-ch/local-engine/Parser/ExpandRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ ExpandRelParser::ExpandRelParser(SerializedPlanParser * plan_parser_)
DB::QueryPlanPtr
ExpandRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel*> & rel_stack)
{
const auto & expand_rel = rel.expand();
const auto & expand_rel = rel.group_id();
std::vector<size_t> aggregating_expressions_columns;
std::set<size_t> agg_cols_ref;
const auto & header = query_plan->getCurrentDataStream().header;
Expand Down Expand Up @@ -59,7 +59,7 @@ ExpandRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel,
}


void ExpandRelParser::buildGroupingSets(const substrait::ExpandRel & expand_rel, std::vector<std::set<size_t>> & grouping_sets)
void ExpandRelParser::buildGroupingSets(const substrait::GroupIdRel & expand_rel, std::vector<std::set<size_t>> & grouping_sets)
{
for (int i = 0; i < expand_rel.groupings_size(); ++i)
{
Expand Down Expand Up @@ -87,6 +87,6 @@ void registerExpandRelParser(RelParserFactory & factory)
{
return std::make_shared<ExpandRelParser>(plan_parser);
};
factory.registerBuilder(substrait::Rel::RelTypeCase::kExpand, builder);
factory.registerBuilder(substrait::Rel::RelTypeCase::kGroupId, builder);
}
}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/ExpandRelParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ class ExpandRelParser : public RelParser
DB::QueryPlanPtr
parse(DB::QueryPlanPtr query_plan, const substrait::Rel & sort_rel, std::list<const substrait::Rel *> & rel_stack_) override;
private:
static void buildGroupingSets(const substrait::ExpandRel & expand_rel, std::vector<std::set<size_t>> & grouping_sets);
static void buildGroupingSets(const substrait::GroupIdRel & expand_rel, std::vector<std::set<size_t>> & grouping_sets);
};
}
6 changes: 3 additions & 3 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,12 +894,12 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list
query_plan = win_parser->parse(std::move(query_plan), rel, rel_stack);
break;
}
case substrait::Rel::RelTypeCase::kExpand: {
case substrait::Rel::RelTypeCase::kGroupId: {
rel_stack.push_back(&rel);
const auto & expand_rel = rel.expand();
const auto & expand_rel = rel.group_id();
query_plan = parseOp(expand_rel.input(), rel_stack);
rel_stack.pop_back();
auto epand_parser = RelParserFactory::instance().getBuilder(substrait::Rel::RelTypeCase::kExpand)(this);
auto epand_parser = RelParserFactory::instance().getBuilder(substrait::Rel::RelTypeCase::kGroupId)(this);
query_plan = epand_parser->parse(std::move(query_plan), rel, rel_stack);
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,22 @@

public class ExpandRelNode implements RelNode, Serializable {
private final RelNode input;
private final String groupName;
private final ArrayList<ArrayList<ExpressionNode>> groupings = new ArrayList<>();

private final ArrayList<ExpressionNode> aggExpressions = new ArrayList<>();
private final ArrayList<ArrayList<ExpressionNode>> projections = new ArrayList<>();

private final AdvancedExtensionNode extensionNode;

public ExpandRelNode(RelNode input, String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions,
AdvancedExtensionNode extensionNode) {
public ExpandRelNode(RelNode input,
ArrayList<ArrayList<ExpressionNode>> projections,
AdvancedExtensionNode extensionNode) {
this.input = input;
this.groupName = groupName;
this.groupings.addAll(groupings);
this.aggExpressions.addAll(aggExpressions);
this.projections.addAll(projections);
this.extensionNode = extensionNode;
}

public ExpandRelNode(RelNode input, String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions) {
public ExpandRelNode(RelNode input,
ArrayList<ArrayList<ExpressionNode>> projections) {
this.input = input;
this.groupName = groupName;
this.groupings.addAll(groupings);
this.aggExpressions.addAll(aggExpressions);
this.projections.addAll(projections);
this.extensionNode = null;
}

Expand All @@ -68,25 +59,22 @@ public Rel toProtobuf() {
expandBuilder.setInput(input.toProtobuf());
}

for (ArrayList<ExpressionNode> groupList: groupings) {
ExpandRel.GroupSets.Builder groupingBuilder =
ExpandRel.GroupSets.newBuilder();
for (ExpressionNode exprNode : groupList) {
groupingBuilder.addGroupSetsExpressions(exprNode.toProtobuf());
}
expandBuilder.addGroupings(groupingBuilder.build());
}

for (ExpressionNode aggExpression: aggExpressions) {
expandBuilder.addAggregateExpressions(aggExpression.toProtobuf());
}
for (ArrayList<ExpressionNode> projectList: projections) {
ExpandRel.ExpandField.Builder expandFieldBuilder =
ExpandRel.ExpandField.newBuilder();
ExpandRel.SwitchingField.Builder switchingField =
ExpandRel.SwitchingField.newBuilder();
for (ExpressionNode exprNode : projectList) {
switchingField.addDuplicates(exprNode.toProtobuf());
}
expandFieldBuilder.setSwitchingField(switchingField.build());
expandBuilder.addFields(expandFieldBuilder.build());
}

if (extensionNode != null) {
expandBuilder.setAdvancedExtension(extensionNode.toProtobuf());
}

expandBuilder.setGroupName(groupName);

Rel.Builder builder = Rel.newBuilder();
builder.setExpand(expandBuilder.build());
return builder.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.glutenproject.substrait.rel;

import io.glutenproject.substrait.expression.ExpressionNode;
import io.glutenproject.substrait.extensions.AdvancedExtensionNode;
import io.substrait.proto.GroupIdRel;
import io.substrait.proto.Rel;
import io.substrait.proto.RelCommon;

import java.io.Serializable;
import java.util.ArrayList;

public class GroupIdRelNode implements RelNode, Serializable {
private final RelNode input;
private final String groupName;
private final ArrayList<ArrayList<ExpressionNode>> groupings = new ArrayList<>();

private final ArrayList<ExpressionNode> aggExpressions = new ArrayList<>();

private final AdvancedExtensionNode extensionNode;

public GroupIdRelNode(RelNode input, String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions,
AdvancedExtensionNode extensionNode) {
this.input = input;
this.groupName = groupName;
this.groupings.addAll(groupings);
this.aggExpressions.addAll(aggExpressions);
this.extensionNode = extensionNode;
}

public GroupIdRelNode(RelNode input, String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions) {
this.input = input;
this.groupName = groupName;
this.groupings.addAll(groupings);
this.aggExpressions.addAll(aggExpressions);
this.extensionNode = null;
}

@Override
public Rel toProtobuf() {
RelCommon.Builder relCommonBuilder = RelCommon.newBuilder();
relCommonBuilder.setDirect(RelCommon.Direct.newBuilder());

GroupIdRel.Builder groupIdBuilder = GroupIdRel.newBuilder();
groupIdBuilder.setCommon(relCommonBuilder.build());

if (input != null) {
groupIdBuilder.setInput(input.toProtobuf());
}

for (ArrayList<ExpressionNode> groupList: groupings) {
GroupIdRel.GroupSets.Builder groupingBuilder =
GroupIdRel.GroupSets.newBuilder();
for (ExpressionNode exprNode : groupList) {
groupingBuilder.addGroupSetsExpressions(exprNode.toProtobuf());
}
groupIdBuilder.addGroupings(groupingBuilder.build());
}

for (ExpressionNode aggExpression: aggExpressions) {
groupIdBuilder.addAggregateExpressions(aggExpression.toProtobuf());
}

if (extensionNode != null) {
groupIdBuilder.setAdvancedExtension(extensionNode.toProtobuf());
}

groupIdBuilder.setGroupName(groupName);

Rel.Builder builder = Rel.newBuilder();
builder.setGroupId(groupIdBuilder.build());
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -198,29 +198,46 @@ public static RelNode makeJoinRel(RelNode left,
return new JoinRelNode(left, right, joinType, expression, postJoinFilter, extensionNode);
}

public static RelNode makeExpandRel(RelNode input,
public static RelNode makeGroupIdRel(RelNode input,
String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions,
AdvancedExtensionNode extensionNode,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return new ExpandRelNode(input, groupName,
return new GroupIdRelNode(input, groupName,
groupings, aggExpressions, extensionNode);
}

public static RelNode makeExpandRel(RelNode input,
public static RelNode makeGroupIdRel(RelNode input,
String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return new ExpandRelNode(input, groupName,
return new GroupIdRelNode(input, groupName,
groupings, aggExpressions);
}

public static RelNode makeExpandRel(RelNode input,
ArrayList<ArrayList<ExpressionNode>> projections,
AdvancedExtensionNode extensionNode,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return new ExpandRelNode(input, projections, extensionNode);
}

public static RelNode makeExpandRel(RelNode input,
ArrayList<ArrayList<ExpressionNode>> projections,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return new ExpandRelNode(input, projections);
}

public static RelNode makeSortRel(RelNode input,
ArrayList<SortField> sorts,
AdvancedExtensionNode extensionNode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,38 @@ message ExchangeRel {
}
}

// Duplicates records, possibly switching output expressions between each duplicate.
// Default output is all of the fields declared followed by one int64 field that contains the
// duplicate_id which is a zero-index ordinal of which duplicate of the original record this
// corresponds to.
message ExpandRel {
RelCommon common = 1;
Rel input = 2;
repeated ExpandField fields = 4;
substrait.extensions.AdvancedExtension advanced_extension = 10;

message ExpandField {
oneof field_type {
// Field that switches output based on which duplicate_id we're outputting
SwitchingField switching_field = 2;

// Field that outputs the same value no matter which duplicate_id we're on.
Expression consistent_field = 3;
}
}

message SwitchingField {
// Array that contains an expression to output per duplicate_id
// each `switching_field` must have the same number of expressions
// all expressions within a switching field be the same type class but can differ in nullability.
// this column will be nullable if any of the expressions are nullable.
repeated Expression duplicates = 1;
}
}

message GroupIdRel {
RelCommon common = 1;
Rel input = 2;

repeated Expression aggregate_expressions = 3;

Expand Down Expand Up @@ -404,6 +433,7 @@ message Rel {
ExpandRel expand = 15;
WindowRel window = 16;
GenerateRel generate = 17;
GroupIdRel group_id = 18;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ trait BackendSettings {
paths: Seq[String]): Boolean = false
def supportExpandExec(): Boolean = false
def needProjectExpandOutput: Boolean = false
def supportNewExpandContract(): Boolean = false
def supportSortExec(): Boolean = false
def supportWindowExec(windowFunctions: Seq[NamedExpression]): Boolean = {
false
Expand Down
Loading

0 comments on commit 9a25d75

Please sign in to comment.