Accelerate RunningWindow queries on GPU (#2600)
* Tentative fix

Signed-off-by: Mithun RK <mythrocks@gmail.com>

* Delegate recognition of window-function exec to Shim.

* Tentative support for Databricks 7.x

* Fix formatting.

* Moved commonality to base class for Gpu window execs.

* Delegate to GpuBaseWindowExecMeta, from Databricks Shims.

* Renamed test for RunningWindowFunctionExec.

* Updated copyright date.

* Fix copyright to include full year.

* Switch from explicit cast to match() condition.
mythrocks authored Jun 9, 2021
1 parent 7d91ae2 commit 2b75173
Showing 9 changed files with 222 additions and 32 deletions.
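Taken together, the changes below route the question "is this plan a window-function exec?" through the version-specific shim layer, so the Databricks-only RunningWindowFunctionExec can be recognized (and replaced with GpuWindowExec) without common code referencing a class that exists only on Databricks builds. A condensed, self-contained sketch of that pattern follows; the class names are simplified stand-ins, and the real signatures appear in the diffs below.

// Sketch of the shim-delegation pattern this commit introduces.
// Stand-in plan classes, so the sketch compiles on its own:
class SparkPlan
class WindowExecBase extends SparkPlan              // Apache Spark window execs
class RunningWindowFunctionExec extends SparkPlan   // Databricks-only; NOT a WindowExecBase

trait Shims {
  // Each supported Spark build answers this for its own plan classes.
  def isWindowFunctionExec(plan: SparkPlan): Boolean
}

// Apache Spark builds: every window exec derives from WindowExecBase.
class BaseShims extends Shims {
  override def isWindowFunctionExec(plan: SparkPlan): Boolean =
    plan.isInstanceOf[WindowExecBase]
}

// Databricks builds additionally recognize RunningWindowFunctionExec.
class DatabricksShims extends BaseShims {
  override def isWindowFunctionExec(plan: SparkPlan): Boolean =
    super.isWindowFunctionExec(plan) || plan.isInstanceOf[RunningWindowFunctionExec]
}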
31 changes: 31 additions & 0 deletions integration_tests/src/main/python/window_function_test.py
@@ -455,3 +455,34 @@ def test_window_aggs_for_rows_collect_list():
(partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_struct
from window_collect_table
''')


# SortExec does not support array type, so sort the result locally.
@ignore_order(local=True)
def test_running_window_function_exec_for_all_aggs():
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, _gen_data_for_collect),
"window_collect_table",
'''
select
sum(c_int) over
(partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as sum_int,
min(c_long) over
(partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as min_long,
max(c_time) over
(partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as max_time,
count(1) over
(partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as count_1,
count(*) over
(partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as count_star,
row_number() over
(partition by a order by b,c_int) as row_num,
collect_list(c_float) over
(partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as collect_float,
collect_list(c_decimal) over
(partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as collect_decimal,
collect_list(c_struct) over
(partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as collect_struct
from window_collect_table
''')
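Every aggregation in this new test runs over the same "running" frame: UNBOUNDED PRECEDING to CURRENT ROW. For reference, here is a standalone sketch of that frame through Spark's Scala DataFrame API (not part of this commit; column names mirror the test's):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

object RunningWindowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("running-window-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(("x", 1, 10), ("x", 2, 20), ("y", 1, 5)).toDF("a", "b", "c_int")

    // A "running" frame: UNBOUNDED PRECEDING to CURRENT ROW, per partition.
    val w = Window.partitionBy("a").orderBy("b", "c_int")
      .rowsBetween(Window.unboundedPreceding, Window.currentRow)

    // Equivalent to: sum(c_int) over (partition by a order by b, c_int
    //                rows between unbounded preceding and current row)
    df.withColumn("sum_int", sum($"c_int").over(w)).show()
  }
}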

@@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.python.{AggregateInPandasExec, ArrowEvalPythonExec, FlatMapGroupsInPandasExec, MapInPandasExec, WindowInPandasExec}
import org.apache.spark.sql.execution.window.WindowExecBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuStringReplace, GpuTimeSub, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
@@ -141,6 +142,8 @@ abstract class SparkBaseShims extends SparkShims {
}
}

override def isWindowFunctionExec(plan: SparkPlan): Boolean = plan.isInstanceOf[WindowExecBase]

override def isGpuShuffledHashJoin(plan: SparkPlan): Boolean = {
plan match {
case _: GpuShuffledHashJoinExec => true
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims.spark301db

import com.databricks.sql.execution.window.RunningWindowFunctionExec
import com.nvidia.spark.rapids.{BaseExprMeta, DataFromReplacementRule, GpuBaseWindowExecMeta, GpuExec, GpuOverrides, GpuWindowExec, RapidsConf, RapidsMeta, SparkPlanMeta}

import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, SortOrder}

/**
* GPU-based window-exec implementation, analogous to RunningWindowFunctionExec.
*/
class GpuRunningWindowExecMeta(runningWindowFunctionExec: RunningWindowFunctionExec,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends GpuBaseWindowExecMeta[RunningWindowFunctionExec](runningWindowFunctionExec, conf, parent, rule) {

override def getInputWindowExpressions: Seq[NamedExpression] = runningWindowFunctionExec.windowExpressionList
override def getPartitionSpecs: Seq[Expression] = runningWindowFunctionExec.partitionSpec
override def getOrderSpecs: Seq[SortOrder] = runningWindowFunctionExec.orderSpec
override def getResultColumnsOnly: Boolean = true
}
@@ -16,6 +16,7 @@

package com.nvidia.spark.rapids.shims.spark301db

import com.databricks.sql.execution.window.RunningWindowFunctionExec
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.spark301.Spark301Shims
import org.apache.hadoop.fs.Path
@@ -36,6 +37,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.python.{AggregateInPandasExec, ArrowEvalPythonExec, FlatMapGroupsInPandasExec, MapInPandasExec, WindowInPandasExec}
import org.apache.spark.sql.execution.window.WindowExecBase
import org.apache.spark.sql.rapids.GpuFileSourceScanExec
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.rapids.execution.python.{GpuAggregateInPandasExecMeta, GpuArrowEvalPythonExec, GpuFlatMapGroupsInPandasExecMeta, GpuMapInPandasExecMeta, GpuPythonUDF, GpuWindowInPandasExecMetaBase}
@@ -75,6 +77,9 @@ class Spark301dbShims extends Spark301Shims {
}
}

override def isWindowFunctionExec(plan: SparkPlan): Boolean =
plan.isInstanceOf[WindowExecBase] || plan.isInstanceOf[RunningWindowFunctionExec]

override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = {
Seq(
GpuOverrides.exec[WindowInPandasExec](
@@ -96,6 +101,17 @@ class Spark301dbShims extends Spark301Shims {
)
}
}).disabledByDefault("it only supports row based frame for now"),
GpuOverrides.exec[RunningWindowFunctionExec](
"Databricks-specific window function exec, for \"running\" windows, " +
"i.e. (UNBOUNDED PRECEDING TO CURRENT ROW)",
ExecChecks(
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL) +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT
+ TypeSig.ARRAY),
TypeSig.all),
(runningWindowFunctionExec, conf, p, r) => new GpuRunningWindowExecMeta(runningWindowFunctionExec, conf, p, r)
),
GpuOverrides.exec[FileSourceScanExec](
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims.spark311db

import com.databricks.sql.execution.window.RunningWindowFunctionExec
import com.nvidia.spark.rapids.{BaseExprMeta, DataFromReplacementRule, GpuBaseWindowExecMeta, GpuExec, GpuOverrides, GpuWindowExec, RapidsConf, RapidsMeta, SparkPlanMeta}

import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, SortOrder}

/**
* GPU-based window-exec implementation, analogous to RunningWindowFunctionExec.
*/
class GpuRunningWindowExecMeta(runningWindowFunctionExec: RunningWindowFunctionExec,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends GpuBaseWindowExecMeta[RunningWindowFunctionExec](runningWindowFunctionExec, conf, parent, rule) {

override def getInputWindowExpressions: Seq[NamedExpression] = runningWindowFunctionExec.windowExpressionList
override def getPartitionSpecs: Seq[Expression] = runningWindowFunctionExec.partitionSpec
override def getOrderSpecs: Seq[SortOrder] = runningWindowFunctionExec.orderSpec
override def getResultColumnsOnly: Boolean = true
}
@@ -16,6 +16,7 @@

package com.nvidia.spark.rapids.shims.spark311db

import com.databricks.sql.execution.window.RunningWindowFunctionExec
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.spark311.Spark311Shims
import org.apache.hadoop.fs.Path
@@ -37,6 +38,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.python.{AggregateInPandasExec, ArrowEvalPythonExec, FlatMapGroupsInPandasExec, MapInPandasExec, WindowInPandasExec}
import org.apache.spark.sql.execution.window.WindowExecBase
import org.apache.spark.sql.rapids.GpuFileSourceScanExec
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.rapids.execution.python.{GpuPythonUDF, GpuWindowInPandasExecMetaBase}
@@ -77,6 +79,9 @@ class Spark311dbShims extends Spark311Shims {
}
}

override def isWindowFunctionExec(plan: SparkPlan): Boolean =
plan.isInstanceOf[WindowExecBase] || plan.isInstanceOf[RunningWindowFunctionExec]

override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = {
Seq(
GpuOverrides.exec[WindowInPandasExec](
@@ -99,6 +104,17 @@ class Spark311dbShims extends Spark311Shims {
)
}
}).disabledByDefault("it only supports row based frame for now"),
GpuOverrides.exec[RunningWindowFunctionExec](
"Databricks-specific window function exec, for \"running\" windows, " +
"i.e. (UNBOUNDED PRECEDING TO CURRENT ROW)",
ExecChecks(
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL) +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT
+ TypeSig.ARRAY),
TypeSig.all),
(runningWindowFunctionExec, conf, p, r) => new GpuRunningWindowExecMeta(runningWindowFunctionExec, conf, p, r)
),
GpuOverrides.exec[FileSourceScanExec](
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
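The ExecChecks blocks in both Databricks shims above encode which input schemas the GPU replacement accepts: any common cuDF type or decimal at the top level, structs nesting those, and arrays nesting all of the above (including structs and nested arrays). To make the `+`/`nested` reading concrete, here is a toy model of composable type signatures — this is not the plugin's real TypeSig API, just an illustration of the idea:

object TypeSigSketch {
  sealed trait SqlType
  case object IntType extends SqlType
  case object DecimalType extends SqlType
  case class StructType(fields: List[SqlType]) extends SqlType
  case class ArrayType(element: SqlType) extends SqlType

  // A signature is a predicate over types; `+` unions two signatures.
  case class Sig(accepts: SqlType => Boolean) {
    def +(other: Sig): Sig = Sig(t => accepts(t) || other.accepts(t))
  }
  // A struct passes when every field matches the inner signature.
  def struct(inner: Sig): Sig = Sig {
    case StructType(fields) => fields.forall(inner.accepts)
    case _                  => false
  }
  // An array passes when its element type matches the inner signature.
  def array(inner: Sig): Sig = Sig {
    case ArrayType(elem) => inner.accepts(elem)
    case _               => false
  }

  val scalars = Sig(t => t == IntType || t == DecimalType)

  // Mirrors the shape of the ExecChecks above (one level of nesting shown;
  // the real TypeSig nesting is recursive):
  val allowed: Sig =
    scalars + struct(scalars) + array(scalars + struct(scalars))

  def main(args: Array[String]): Unit = {
    println(allowed.accepts(ArrayType(StructType(List(IntType))))) // true: array of struct
    println(allowed.accepts(StructType(List(ArrayType(IntType))))) // false: struct of array is not allowed
  }
}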
110 changes: 79 additions & 31 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,11 +24,82 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.vectorized.ColumnarBatch

-class GpuWindowExecMeta(windowExec: WindowExec,
/**
* Base class for GPU Execs that implement window functions. This abstracts the method
* by which the window function's input expressions, partition specs, order-by specs, etc.
* are extracted from the specific WindowExecType.
*
* @tparam WindowExecType The Exec class that implements window functions
* (E.g. o.a.s.sql.execution.window.WindowExec.)
*/
abstract class GpuBaseWindowExecMeta[WindowExecType <: SparkPlan] (windowExec: WindowExecType,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
-  extends SparkPlanMeta[WindowExec](windowExec, conf, parent, rule) {
+  extends SparkPlanMeta[WindowExecType](windowExec, conf, parent, rule) {

/**
* Extracts window-expression from WindowExecType.
* The implementation varies, depending on the WindowExecType class.
*/
def getInputWindowExpressions: Seq[NamedExpression]

/**
* Extracts partition-spec from WindowExecType.
* The implementation varies, depending on the WindowExecType class.
*/
def getPartitionSpecs: Seq[Expression]

/**
* Extracts order-by spec from WindowExecType.
* The implementation varies, depending on the WindowExecType class.
*/
def getOrderSpecs: Seq[SortOrder]

/**
* Indicates the output column semantics for the WindowExecType,
* i.e. whether to only return the window-expression result columns (as in some Spark
* distributions) or also include the input columns (as in Apache Spark).
*/
def getResultColumnsOnly: Boolean

val windowExpressions: Seq[BaseExprMeta[NamedExpression]] =
getInputWindowExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
val partitionSpec: Seq[BaseExprMeta[Expression]] =
getPartitionSpecs.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
val orderSpec: Seq[BaseExprMeta[SortOrder]] =
getOrderSpecs.map(GpuOverrides.wrapExpr(_, conf, Some(this)))

override def tagPlanForGpu(): Unit = {
// Implementation depends on receiving a `NamedExpression` wrapped WindowExpression.
windowExpressions.map(meta => meta.wrapped)
.filter(expr => !expr.isInstanceOf[NamedExpression])
.foreach(_ => willNotWorkOnGpu(because = "Unexpected query plan with Windowing functions; " +
"cannot convert for GPU execution. " +
"(Detail: WindowExpression not wrapped in `NamedExpression`.)"))
}

override def convertToGpu(): GpuExec = {
GpuWindowExec(
windowExpressions.map(_.convertToGpu()),
partitionSpec.map(_.convertToGpu()),
orderSpec.map(_.convertToGpu().asInstanceOf[SortOrder]),
childPlans.head.convertIfNeeded(),
getResultColumnsOnly
)
}
}

/**
* Specialization of GpuBaseWindowExecMeta for org.apache.spark.sql.execution.window.WindowExec.
* This class implements methods to extract the window-expressions, partition columns,
* order-by columns, etc. from WindowExec.
*/
class GpuWindowExecMeta(windowExec: WindowExec,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends GpuBaseWindowExecMeta[WindowExec](windowExec, conf, parent, rule) {

/**
* Fetches WindowExpressions in input `windowExec`, via reflection.
@@ -58,35 +129,12 @@ class GpuWindowExecMeta(windowExec: WindowExec,
(expr, resultColumnsOnly)
}

-  private val (inputWindowExpressions, resultColumnsOnly) = getWindowExpression
-
-  val windowExpressions: Seq[BaseExprMeta[NamedExpression]] =
-    inputWindowExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
-
-  val partitionSpec: Seq[BaseExprMeta[Expression]] =
-    windowExec.partitionSpec.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
-  val orderSpec: Seq[BaseExprMeta[SortOrder]] =
-    windowExec.orderSpec.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
-
-  override def tagPlanForGpu(): Unit = {
-    // Implementation depends on receiving a `NamedExpression` wrapped WindowExpression.
-    windowExpressions.map(meta => meta.wrapped)
-      .filter(expr => !expr.isInstanceOf[NamedExpression])
-      .foreach(_ => willNotWorkOnGpu(because = "Unexpected query plan with Windowing functions; " +
-        "cannot convert for GPU execution. " +
-        "(Detail: WindowExpression not wrapped in `NamedExpression`.)"))
-
-  }
+  private lazy val (inputWindowExpressions, resultColumnsOnly) = getWindowExpression

-  override def convertToGpu(): GpuExec = {
-    GpuWindowExec(
-      windowExpressions.map(_.convertToGpu()),
-      partitionSpec.map(_.convertToGpu()),
-      orderSpec.map(_.convertToGpu().asInstanceOf[SortOrder]),
-      childPlans.head.convertIfNeeded(),
-      resultColumnsOnly
-    )
-  }
+  override def getInputWindowExpressions: Seq[NamedExpression] = inputWindowExpressions
+  override def getPartitionSpecs: Seq[Expression] = windowExec.partitionSpec
+  override def getOrderSpecs: Seq[SortOrder] = windowExec.orderSpec
+  override def getResultColumnsOnly: Boolean = resultColumnsOnly
}

case class GpuWindowExec(
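The refactor above is the extension point the Databricks shims use: a new window-exec type only needs a small GpuBaseWindowExecMeta subclass that says where to find its expressions and specs. Here is a sketch of wiring up a hypothetical vendor exec — MyWindowExec and its members are assumptions, not a real Spark class; the plugin types are the ones introduced in this commit:

import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuBaseWindowExecMeta, RapidsConf, RapidsMeta}
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical vendor-specific window exec; only the members the meta
// reads are sketched (a real one implements the full SparkPlan contract).
trait MyWindowExec extends SparkPlan {
  def windowExpressionList: Seq[NamedExpression]
  def partitionSpec: Seq[Expression]
  def orderSpec: Seq[SortOrder]
}

class GpuMyWindowExecMeta(exec: MyWindowExec,
    conf: RapidsConf,
    parent: Option[RapidsMeta[_, _, _]],
    rule: DataFromReplacementRule)
  extends GpuBaseWindowExecMeta[MyWindowExec](exec, conf, parent, rule) {

  override def getInputWindowExpressions: Seq[NamedExpression] = exec.windowExpressionList
  override def getPartitionSpecs: Seq[Expression] = exec.partitionSpec
  override def getOrderSpecs: Seq[SortOrder] = exec.orderSpec
  // `true` mirrors Databricks' RunningWindowFunctionExec, which emits only the
  // window-aggregate columns; Apache Spark's WindowExec also carries its input columns.
  override def getResultColumnsOnly: Boolean = true
}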
@@ -712,7 +712,8 @@ object ExpressionContext {
val parent = findParentPlanMeta(meta)
assert(parent.isDefined, "It is expected that an aggregate function is a child of a SparkPlan")
parent.get.wrapped match {
-      case _: WindowExecBase => WindowAggExprContext
+      case agg: SparkPlan if ShimLoader.getSparkShims.isWindowFunctionExec(agg) =>
+        WindowAggExprContext
case agg: BaseAggregateExec =>
if (agg.groupingExpressions.isEmpty) {
ReductionAggExprContext
@@ -90,6 +90,7 @@ trait SparkShims {

def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean
def isGpuShuffledHashJoin(plan: SparkPlan): Boolean
def isWindowFunctionExec(plan: SparkPlan): Boolean
def getRapidsShuffleManagerClass: String
def getBuildSide(join: HashJoin): GpuBuildSide
def getBuildSide(join: BroadcastNestedLoopJoinExec): GpuBuildSide