Skip to content

Commit

Permalink
convert expand op to expand exec in velox
Browse files Browse the repository at this point in the history
  • Loading branch information
zhli1142015 committed Apr 14, 2023
1 parent ab94b64 commit 78159d0
Show file tree
Hide file tree
Showing 12 changed files with 511 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ object VeloxBackendSettings extends BackendSettings {

override def supportExpandExec(): Boolean = true
override def needProjectExpandOutput: Boolean = true

override def useExpandInsteadGroupId(): Boolean = true
override def supportSortExec(): Boolean = true

override def supportWindowExec(windowFunctions: Seq[NamedExpression]): Boolean = {
Expand Down
10 changes: 10 additions & 0 deletions cpp/velox/compute/VeloxBackend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ void VeloxBackend::setInputPlanNode(const ::substrait::ExpandRel& sexpand) {
}
}

void VeloxBackend::setInputPlanNode(const ::substrait::GroupIdRel& sGroupId) {
if (sGroupId.has_input()) {
setInputPlanNode(sGroupId.input());
} else {
throw std::runtime_error("Child expected");
}
}

void VeloxBackend::setInputPlanNode(const ::substrait::SortRel& ssort) {
if (ssort.has_input()) {
setInputPlanNode(ssort.input());
Expand Down Expand Up @@ -205,6 +213,8 @@ void VeloxBackend::setInputPlanNode(const ::substrait::Rel& srel) {
setInputPlanNode(srel.sort());
} else if (srel.has_expand()) {
setInputPlanNode(srel.expand());
} else if (srel.has_group_id()) {
setInputPlanNode(srel.group_id());
} else if (srel.has_fetch()) {
setInputPlanNode(srel.fetch());
} else if (srel.has_window()) {
Expand Down
2 changes: 2 additions & 0 deletions cpp/velox/compute/VeloxBackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class VeloxBackend final : public Backend {

void setInputPlanNode(const ::substrait::ExpandRel& sExpand);

void setInputPlanNode(const ::substrait::GroupIdRel& sGroupId);

void setInputPlanNode(const ::substrait::SortRel& sSort);

void setInputPlanNode(const ::substrait::WindowRel& s);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,22 @@

public class ExpandRelNode implements RelNode, Serializable {
private final RelNode input;
private final String groupName;
private final ArrayList<ArrayList<ExpressionNode>> groupings = new ArrayList<>();

private final ArrayList<ExpressionNode> aggExpressions = new ArrayList<>();
private final ArrayList<ArrayList<ExpressionNode>> projections = new ArrayList<>();

private final AdvancedExtensionNode extensionNode;

public ExpandRelNode(RelNode input, String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions,
AdvancedExtensionNode extensionNode) {
public ExpandRelNode(RelNode input,
ArrayList<ArrayList<ExpressionNode>> projections,
AdvancedExtensionNode extensionNode) {
this.input = input;
this.groupName = groupName;
this.groupings.addAll(groupings);
this.aggExpressions.addAll(aggExpressions);
this.projections.addAll(projections);
this.extensionNode = extensionNode;
}

public ExpandRelNode(RelNode input, String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions) {
public ExpandRelNode(RelNode input,
ArrayList<ArrayList<ExpressionNode>> projections) {
this.input = input;
this.groupName = groupName;
this.groupings.addAll(groupings);
this.aggExpressions.addAll(aggExpressions);
this.projections.addAll(projections);
this.extensionNode = null;
}

Expand All @@ -68,25 +59,19 @@ public Rel toProtobuf() {
expandBuilder.setInput(input.toProtobuf());
}

for (ArrayList<ExpressionNode> groupList: groupings) {
ExpandRel.GroupSets.Builder groupingBuilder =
ExpandRel.GroupSets.newBuilder();
for (ExpressionNode exprNode : groupList) {
groupingBuilder.addGroupSetsExpressions(exprNode.toProtobuf());
}
expandBuilder.addGroupings(groupingBuilder.build());
}

for (ExpressionNode aggExpression: aggExpressions) {
expandBuilder.addAggregateExpressions(aggExpression.toProtobuf());
}
for (ArrayList<ExpressionNode> projectList: projections) {
ExpandRel.ProjectionSets.Builder projectsBuilder =
ExpandRel.ProjectionSets.newBuilder();
for (ExpressionNode exprNode : projectList) {
projectsBuilder.addProjectionSetsExpressions(exprNode.toProtobuf());
}
expandBuilder.addProjections(projectsBuilder.build());
}

if (extensionNode != null) {
expandBuilder.setAdvancedExtension(extensionNode.toProtobuf());
}

expandBuilder.setGroupName(groupName);

Rel.Builder builder = Rel.newBuilder();
builder.setExpand(expandBuilder.build());
return builder.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.glutenproject.substrait.rel;

import io.glutenproject.substrait.expression.ExpressionNode;
import io.glutenproject.substrait.extensions.AdvancedExtensionNode;
import io.substrait.proto.GroupIdRel;
import io.substrait.proto.Rel;
import io.substrait.proto.RelCommon;

import java.io.Serializable;
import java.util.ArrayList;

public class GroupIdRelNode implements RelNode, Serializable {
private final RelNode input;
private final String groupName;
private final ArrayList<ArrayList<ExpressionNode>> groupings = new ArrayList<>();

private final ArrayList<ExpressionNode> aggExpressions = new ArrayList<>();

private final AdvancedExtensionNode extensionNode;

public GroupIdRelNode(RelNode input, String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions,
AdvancedExtensionNode extensionNode) {
this.input = input;
this.groupName = groupName;
this.groupings.addAll(groupings);
this.aggExpressions.addAll(aggExpressions);
this.extensionNode = extensionNode;
}

public GroupIdRelNode(RelNode input, String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions) {
this.input = input;
this.groupName = groupName;
this.groupings.addAll(groupings);
this.aggExpressions.addAll(aggExpressions);
this.extensionNode = null;
}

@Override
public Rel toProtobuf() {
RelCommon.Builder relCommonBuilder = RelCommon.newBuilder();
relCommonBuilder.setDirect(RelCommon.Direct.newBuilder());

GroupIdRel.Builder groupIdBuilder = GroupIdRel.newBuilder();
groupIdBuilder.setCommon(relCommonBuilder.build());

if (input != null) {
groupIdBuilder.setInput(input.toProtobuf());
}

for (ArrayList<ExpressionNode> groupList: groupings) {
GroupIdRel.GroupSets.Builder groupingBuilder =
GroupIdRel.GroupSets.newBuilder();
for (ExpressionNode exprNode : groupList) {
groupingBuilder.addGroupSetsExpressions(exprNode.toProtobuf());
}
groupIdBuilder.addGroupings(groupingBuilder.build());
}

for (ExpressionNode aggExpression: aggExpressions) {
groupIdBuilder.addAggregateExpressions(aggExpression.toProtobuf());
}

if (extensionNode != null) {
groupIdBuilder.setAdvancedExtension(extensionNode.toProtobuf());
}

groupIdBuilder.setGroupName(groupName);

Rel.Builder builder = Rel.newBuilder();
builder.setGroupId(groupIdBuilder.build());
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -198,29 +198,46 @@ public static RelNode makeJoinRel(RelNode left,
return new JoinRelNode(left, right, joinType, expression, postJoinFilter, extensionNode);
}

public static RelNode makeExpandRel(RelNode input,
public static RelNode makeGroupIdRel(RelNode input,
String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions,
AdvancedExtensionNode extensionNode,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return new ExpandRelNode(input, groupName,
return new GroupIdRelNode(input, groupName,
groupings, aggExpressions, extensionNode);
}

public static RelNode makeExpandRel(RelNode input,
public static RelNode makeGroupIdRel(RelNode input,
String groupName,
ArrayList<ArrayList<ExpressionNode>> groupings,
ArrayList<ExpressionNode> aggExpressions,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return new ExpandRelNode(input, groupName,
return new GroupIdRelNode(input, groupName,
groupings, aggExpressions);
}

public static RelNode makeExpandRel(RelNode input,
ArrayList<ArrayList<ExpressionNode>> projections,
AdvancedExtensionNode extensionNode,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return new ExpandRelNode(input, projections, extensionNode);
}

public static RelNode makeExpandRel(RelNode input,
ArrayList<ArrayList<ExpressionNode>> projections,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return new ExpandRelNode(input, projections);
}

public static RelNode makeSortRel(RelNode input,
ArrayList<SortField> sorts,
AdvancedExtensionNode extensionNode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,14 +355,32 @@ message ExchangeRel {
}
}

message GroupIdRel {
RelCommon common = 1;
Rel input = 2;

repeated Expression aggregate_expressions = 3;

// A list of expression grouping that the aggregation measured should be calculated for.
repeated GroupSets groupings = 4;

message GroupSets {
repeated Expression groupSets_expressions = 1;
}

string group_name = 5;

substrait.extensions.AdvancedExtension advanced_extension = 10;
}

message ExpandRel {
RelCommon common = 1;
Rel input = 2;
.
repeated Projections projections = 3;

message Projections {
repeated Expression project_expressions = 1;
repeated ProjectionSets projections = 3;

message ProjectionSets {
repeated Expression projectionSets_expressions = 1;
}

substrait.extensions.AdvancedExtension advanced_extension = 10;
Expand Down Expand Up @@ -396,9 +414,10 @@ message Rel {
//Physical relations
HashJoinRel hash_join = 13;
MergeJoinRel merge_join = 14;
ExpandRel expand = 15;
GroupIdRel group_id = 15;
WindowRel window = 16;
GenerateRel generate = 17;
ExpandRel expand = 18;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ trait BackendSettings {
paths: Seq[String]): Boolean = false
def supportExpandExec(): Boolean = false
def needProjectExpandOutput: Boolean = false
def useExpandInsteadGroupId(): Boolean = false
def supportSortExec(): Boolean = false
def supportWindowExec(windowFunctions: Seq[NamedExpression]): Boolean = {
false
Expand Down
Loading

0 comments on commit 78159d0

Please sign in to comment.