-
Notifications
You must be signed in to change notification settings - Fork 240
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support running CPU based UDF efficiently [databricks] (#3897)
* Support running CPU based UDF. Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Address the comments Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Add more tests Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Add log for the time of running the UDF. along with some small refactors. Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Add comments for disabling the input encoder Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Comment update Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Fix a build error on DB Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Address the new comments Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Update comments Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala Update the config doc Co-authored-by: Jason Lowe <jlowe@nvidia.com> * Use the code of 3.0 for input type conversion Signed-off-by: Firestarman <firestarmanllc@gmail.com> * comment update Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Doc update Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Address the new comments Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Add shims for GpuRowBasedScalaUDF Signed-off-by: Firestarman <firestarmanllc@gmail.com> * Update the test file name Signed-off-by: Firestarman <firestarmanllc@gmail.com> Co-authored-by: Jason Lowe <jlowe@nvidia.com>
- Loading branch information
1 parent
bfaf73b
commit 96ca730
Showing
17 changed files
with
963 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
...src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRowBasedScalaUDF.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.spark.rapids.shims.v2 | ||
|
||
import com.nvidia.spark.rapids.{ExprChecks, ExprRule, GpuOverrides, GpuUserDefinedFunction, RepeatingParamCheck, TypeSig} | ||
|
||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} | ||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder | ||
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} | ||
import org.apache.spark.sql.rapids.{GpuRowBasedScalaUDFBase, ScalaUDFMetaBase} | ||
import org.apache.spark.sql.types.DataType | ||
|
||
/**
 * Runs a row-based Scala UDF as part of a GPU operation.
 *
 * This variant carries no output encoder and passes `None` for it to the base class.
 */
case class GpuRowBasedScalaUDF(
    sparkFunc: AnyRef,
    dataType: DataType,
    children: Seq[Expression],
    inputEncoders: Seq[Option[ExpressionEncoder[_]]],
    udfName: Option[String],
    nullable: Boolean,
    udfDeterministic: Boolean)
  extends GpuRowBasedScalaUDFBase(sparkFunc, dataType, children, inputEncoders, None, udfName) {

  /**
   * Builds the function that converts the i-th input value before it is handed to the UDF.
   *
   * @param i the index of the input
   * @param dataType the data type of the i-th input
   * @return the converter to apply to each value of the i-th input
   */
  override def createInputConverter(i: Int, dataType: DataType): Any => Any = {
    // An encoder is only used when one is defined for this input and it serializes
    // the top-level object as a struct.
    val structEncoder: Option[ExpressionEncoder[_]] =
      if (inputEncoders.isEmpty) {
        // for untyped Scala UDF
        None
      } else {
        inputEncoders(i).filter(_.isSerializedAsStructForTopLevel)
      }

    structEncoder match {
      case Some(enc) =>
        val deserialize = enc.resolveAndBind().createDeserializer()
        (row: Any) => deserialize(row.asInstanceOf[InternalRow])
      case None =>
        CatalystTypeConverters.createToScalaConverter(dataType)
    }
  }
}
|
||
/** Provides the replacement rule for Spark's `ScalaUDF` expression. */
object GpuScalaUDFMeta {
  /**
   * Creates the `ExprRule` for `ScalaUDF`. The meta it produces builds a
   * `GpuRowBasedScalaUDF` from the fields of the wrapped expression.
   */
  def exprMeta: ExprRule[ScalaUDF] = {
    val description =
      "User Defined Function, the UDF can choose to implement a RAPIDS accelerated interface " +
        "to get better performance."
    val paramCheck = RepeatingParamCheck("param", GpuUserDefinedFunction.udfTypeSig, TypeSig.all)
    val checks = ExprChecks.projectOnly(
      GpuUserDefinedFunction.udfTypeSig,
      TypeSig.all,
      repeatingParamCheck = Some(paramCheck))

    GpuOverrides.expr[ScalaUDF](
      description,
      checks,
      (expr, conf, p, r) => new ScalaUDFMetaBase(expr, conf, p, r) {
        override protected def rowBasedScalaUDF: GpuRowBasedScalaUDFBase = {
          val gpuChildren = childExprs.map(_.convertToGpu())
          GpuRowBasedScalaUDF(
            expr.function,
            expr.dataType,
            gpuChildren,
            expr.inputEncoders,
            expr.udfName,
            expr.nullable,
            expr.udfDeterministic)
        }
      })
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
100 changes: 100 additions & 0 deletions
100
...plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRowBasedScalaUDF.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.spark.rapids.shims.v2 | ||
|
||
import com.nvidia.spark.rapids.{ExprChecks, ExprRule, GpuOverrides, GpuUserDefinedFunction, RepeatingParamCheck, TypeSig} | ||
|
||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} | ||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder | ||
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, ScalaUDF} | ||
import org.apache.spark.sql.rapids.{GpuRowBasedScalaUDFBase, ScalaUDFMetaBase} | ||
import org.apache.spark.sql.types.DataType | ||
|
||
/** Runs a row-based Scala UDF as part of a GPU operation. */
case class GpuRowBasedScalaUDF(
    sparkFunc: AnyRef,
    dataType: DataType,
    children: Seq[Expression],
    inputEncoders: Seq[Option[ExpressionEncoder[_]]],
    outputEncoder: Option[ExpressionEncoder[_]],
    udfName: Option[String],
    nullable: Boolean,
    udfDeterministic: Boolean)
  extends GpuRowBasedScalaUDFBase(sparkFunc, dataType, children, inputEncoders, outputEncoder,
    udfName) {

  /**
   * Builds the function that converts the i-th input value before it is handed to the UDF.
   *
   * @param i the index of the child
   * @param dataType the output data type of the i-th child
   * @return the converter to apply to each value of the i-th child
   */
  override def createInputConverter(i: Int, dataType: DataType): Any => Any = {
    val (converter, _) = scalaConverter(i, dataType)
    converter
  }

  /**
   * Creates the converter from the catalyst data type to the Scala data type.
   *
   * `CatalystTypeConverters` supplies the converter for:
   *   - UDFs that provide no inputEncoders, e.g., untyped Scala UDF and Java UDF
   *   - types that are not supported by `ExpressionEncoder`, e.g., Any
   *   - primitive types, in order to use `identity` for better performance
   *
   * Every other case, such as a case class or Option[T], goes through `ExpressionEncoder`
   * because `CatalystTypeConverters` does not support those data types.
   *
   * @param i the index of the child
   * @param dataType the output data type of the i-th child
   * @return the converter, paired with a flag that is true when the converter was
   *         created by using `ExpressionEncoder`.
   */
  private def scalaConverter(i: Int, dataType: DataType): (Any => Any, Boolean) = {
    // Checks run in this order: no encoders at all, no encoder for this input,
    // then whether this input is a primitive.
    val chosenEncoder: Option[ExpressionEncoder[_]] =
      if (inputEncoders.isEmpty) {
        None
      } else {
        inputEncoders(i).filter(_ => !inputPrimitives(i))
      }

    chosenEncoder match {
      case Some(enc) =>
        val deserialize = enc.createDeserializer()
        val converter: Any => Any =
          if (enc.isSerializedAsStructForTopLevel) {
            (row: Any) => deserialize(row.asInstanceOf[InternalRow])
          } else {
            // A single one-column row is allocated here and reused by every call of
            // the converter to wrap the incoming value for the deserializer.
            val holder = new GenericInternalRow(1)
            (value: Any) => {
              holder.update(0, value)
              deserialize(holder)
            }
          }
        (converter, true)
      case None => // use CatalystTypeConverters
        (CatalystTypeConverters.createToScalaConverter(dataType), false)
    }
  }
}
|
||
/** Provides the replacement rule for Spark's `ScalaUDF` expression. */
object GpuScalaUDFMeta {
  /**
   * Creates the `ExprRule` for `ScalaUDF`. The meta it produces builds a
   * `GpuRowBasedScalaUDF` from the fields of the wrapped expression, including
   * its output encoder.
   */
  def exprMeta: ExprRule[ScalaUDF] = {
    val description =
      "User Defined Function, the UDF can choose to implement a RAPIDS accelerated interface " +
        "to get better performance."
    val paramCheck = RepeatingParamCheck("param", GpuUserDefinedFunction.udfTypeSig, TypeSig.all)
    val checks = ExprChecks.projectOnly(
      GpuUserDefinedFunction.udfTypeSig,
      TypeSig.all,
      repeatingParamCheck = Some(paramCheck))

    GpuOverrides.expr[ScalaUDF](
      description,
      checks,
      (expr, conf, p, r) => new ScalaUDFMetaBase(expr, conf, p, r) {
        override protected def rowBasedScalaUDF: GpuRowBasedScalaUDFBase = {
          val gpuChildren = childExprs.map(_.convertToGpu())
          GpuRowBasedScalaUDF(
            expr.function,
            expr.dataType,
            gpuChildren,
            expr.inputEncoders,
            expr.outputEncoder,
            expr.udfName,
            expr.nullable,
            expr.udfDeterministic)
        }
      })
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.