[VL] Support inline function (#4847)
As the explode function is already supported, we can also enable the inline function by adding a post-project that extracts the inner fields from the struct.
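
For context, a minimal Scala sketch of the inline semantics being offloaded (assumes a running SparkSession named spark; not part of this change):

spark.sql(
  "SELECT inline(array(named_struct('c1', 0, 'c2', 1), named_struct('c1', 2, 'c2', 3)))"
).show()
// Returns two rows with columns (c1, c2): (0, 1) and (2, 3), i.e. one row per array
// element and one column per struct field. Per this commit, the generator produces the
// struct rows and a post-project then extracts the inner fields.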
marin-ma authored Mar 7, 2024
1 parent f9daf49 commit b7a6329
Showing 9 changed files with 104 additions and 8 deletions.
@@ -136,6 +136,14 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, Seq(child), expr)
}

/** Transform inline to Substrait. */
override def genInlineTransformer(
substraitExprName: String,
child: ExpressionTransformer,
expr: Expression): ExpressionTransformer = {
GenericExpressionTransformer(substraitExprName, Seq(child), expr)
}

/**
* * Plans.
*/
@@ -732,6 +732,30 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
}
}

test("test inline function") {

withTempView("t1") {
sql("""select * from values
| array(
| named_struct('c1', 0, 'c2', 1),
| null,
| named_struct('c1', 2, 'c2', 3)
| ),
| array(
| null,
| named_struct('c1', 0, 'c2', 1),
| named_struct('c1', 2, 'c2', 3)
| )
|as tbl(a)
""".stripMargin).createOrReplaceTempView("t1")
runQueryAndCompare("""
|SELECT inline(a) from t1;
|""".stripMargin) {
checkOperatorMatch[GenerateExecTransformer]
}
}
}

test("test array functions") {
withTable("t") {
sql("CREATE TABLE t (c1 ARRAY<INT>, c2 ARRAY<INT>, c3 STRING) using parquet")
1 change: 0 additions & 1 deletion cpp/velox/substrait/SubstraitToVeloxExpr.cc
@@ -580,7 +580,6 @@ core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr(
core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr(
const ::substrait::Expression& substraitExpr,
const RowTypePtr& inputType) {
core::TypedExprPtr veloxExpr;
auto typeCase = substraitExpr.rex_type_case();
switch (typeCase) {
case ::substrait::Expression::RexTypeCase::kLiteral:
4 changes: 2 additions & 2 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -724,8 +724,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
replicated.reserve(requiredChildOutput.size());
for (const auto& output : requiredChildOutput) {
auto expression = exprConverter_->toVeloxExpr(output, inputType);
auto expr_field = dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
VELOX_CHECK(expr_field != nullptr, " the output in Generate Operator only support field")
auto exprField = dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
VELOX_CHECK(exprField != nullptr, "the output in Generate Operator only supports field")

replicated.emplace_back(std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(expression));
}
@@ -207,6 +207,14 @@ trait SparkPlanExecApi {
throw new UnsupportedOperationException("map_entries is not supported")
}

/** Transform inline to Substrait. */
def genInlineTransformer(
substraitExprName: String,
child: ExpressionTransformer,
expr: Expression): ExpressionTransformer = {
throw new UnsupportedOperationException("inline is not supported")
}

/**
* Generate ShuffleDependency for ColumnarShuffleExchangeExec.
*
@@ -120,7 +120,9 @@ case class GenerateExecTransformer(
.asJava
projectExpressions.addAll(childOutputNodes)
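// Explode, PosExplode and Inline all extend UnaryExpression in Spark, so casting the
// generator to UnaryExpression retrieves its child expression for any of them.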
val projectExprNode = ExpressionConverter
.replaceWithExpressionTransformer(generator.asInstanceOf[Explode].child, child.output)
.replaceWithExpressionTransformer(
generator.asInstanceOf[UnaryExpression].child,
child.output)
.doTransform(args)

projectExpressions.add(projectExprNode)
@@ -152,19 +154,67 @@ case class GenerateExecTransformer(
operatorId: Long,
inputAttributes: Seq[Attribute],
input: RelNode,
generator: ExpressionNode,
generatorNode: ExpressionNode,
childOutput: JList[ExpressionNode],
validation: Boolean): RelNode = {
if (!validation) {
RelBuilder.makeGenerateRel(input, generator, childOutput, context, operatorId)
val generateRel = if (!validation) {
RelBuilder.makeGenerateRel(input, generatorNode, childOutput, context, operatorId)
} else {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList =
inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeGenerateRel(input, generator, childOutput, extensionNode, context, operatorId)
RelBuilder.makeGenerateRel(
input,
generatorNode,
childOutput,
extensionNode,
context,
operatorId)
}
applyPostProjectOnGenerator(generateRel, context, operatorId, childOutput, validation)
}

// There are 3 types of CollectionGenerator in Spark: Explode, PosExplode and Inline.
// Only Inline needs the post projection.
private def applyPostProjectOnGenerator(
generateRel: RelNode,
context: SubstraitContext,
operatorId: Long,
childOutput: JList[ExpressionNode],
validation: Boolean): RelNode = {
generator match {
case Inline(inlineChild) =>
inlineChild match {
case _: AttributeReference =>
case _ =>
throw new UnsupportedOperationException("Child of Inline is not AttributeReference.")
}
val requiredOutput = (0 until childOutput.size).map {
ExpressionBuilder.makeSelection(_)
}
val flattenStruct: Seq[ExpressionNode] = generatorOutput.indices.map {
i =>
val selectionNode = ExpressionBuilder.makeSelection(requiredOutput.size)
selectionNode.addNestedChildIdx(i)
}
val postProjectRel = RelBuilder.makeProjectRel(
generateRel,
(requiredOutput ++ flattenStruct).asJava,
context,
operatorId,
1 + requiredOutput.size // 1 stands for the inner struct field from array.
)
if (validation) {
// No need to validate the project rel on the native side as
// it only flattens the generator's output.
generateRel
} else {
postProjectRel
}
case _ => generateRel
}
}
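
To make the index bookkeeping above concrete, a hedged sketch of the column layout for inline(a), assuming a: ARRAY<STRUCT<c1: INT, c2: INT>> and a single replicated child column b (names are illustrative only):

// GenerateRel output (replicated childOutput first, then the generator output):
//   0 -> b                           (replicated column)
//   1 -> STRUCT<c1, c2>              (one struct per exploded array element)
// Post-project expressions (requiredOutput ++ flattenStruct):
//   selection(0)                     -> b
//   selection(1).nestedChildIdx(0)   -> c1
//   selection(1).nestedChildIdx(1)   -> c2
// emitStartIndex = 1 + requiredOutput.size (2 here), so the project is expected to emit
// only these expressions, skipping the pass-through generate output, including the raw
// struct column.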

@@ -173,6 +173,11 @@ object ExpressionConverter extends SQLConfHelper with Logging {
replaceWithExpressionTransformerInternal(p.child, attributeSeq, expressionsMap),
p,
attributeSeq)
case i: Inline =>
BackendsApiManager.getSparkPlanExecApiInstance.genInlineTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(i.child, attributeSeq, expressionsMap),
i)
case a: Alias =>
BackendsApiManager.getSparkPlanExecApiInstance.genAliasTransformer(
substraitExprName,
@@ -196,6 +196,7 @@ object ExpressionMappings {
Sig[Sequence](SEQUENCE),
Sig[CreateArray](CREATE_ARRAY),
Sig[Explode](EXPLODE),
Sig[Inline](INLINE),
Sig[ArrayAggregate](AGGREGATE),
Sig[LambdaFunction](LAMBDAFUNCTION),
Sig[NamedLambdaVariable](NAMED_LAMBDA_VARIABLE),
@@ -255,6 +255,7 @@ object ExpressionNames {
final val AGGREGATE = "aggregate"
final val LAMBDAFUNCTION = "lambdafunction"
final val EXPLODE = "explode"
final val INLINE = "inline"
final val POSEXPLODE = "posexplode"
final val CHECK_OVERFLOW = "check_overflow"
final val MAKE_DECIMAL = "make_decimal"