Fix regression in cost-based optimizer when calculating cost for Window operations (NVIDIA#2491)
andygrove authored May 25, 2021
1 parent b61389c commit 8100500
Showing 2 changed files with 48 additions and 1 deletion.
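
The root cause, as the new test's comment below spells out, is that window frame and window spec expressions are Unevaluable, so asking them for a dataType (which the cost models do when estimating data sizes) throws instead of returning a type. A minimal repro sketch, not part of this commit and assuming the Spark 3.x catalyst API:

import scala.util.Try
import org.apache.spark.sql.catalyst.expressions.{CurrentRow, RowFrame, SpecifiedWindowFrame, UnboundedPreceding}

object WindowFrameDataTypeRepro {
  def main(args: Array[String]): Unit = {
    // Build a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW frame expression.
    val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)
    // WindowFrame is Unevaluable; per the commit, dataType throws UnsupportedOperationException.
    println(Try(frame.dataType)) // expected: Failure(java.lang.UnsupportedOperationException: ...)
  }
}

The guard added to both cost models (MemoryCostHelper.isExcludedFromCost) simply skips such expressions before any size estimate is attempted.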
@@ -19,7 +19,7 @@ package com.nvidia.spark.rapids
 import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, GetStructField}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, GetStructField, WindowFrame, WindowSpecDefinition}
 import org.apache.spark.sql.catalyst.plans.{JoinType, LeftAnti, LeftSemi}
 import org.apache.spark.sql.execution.{GlobalLimitExec, LocalLimitExec, SparkPlan, TakeOrderedAndProjectExec, UnionExec}
 import org.apache.spark.sql.execution.adaptive.{CustomShuffleReaderExec, QueryStageExec}
@@ -262,6 +262,9 @@ class CpuCostModel(conf: RapidsConf) extends CostModel {
   }
 
   private def exprCost[INPUT <: Expression](expr: BaseExprMeta[INPUT], rowCount: Double): Double = {
+    if (MemoryCostHelper.isExcludedFromCost(expr)) {
+      return 0d
+    }
 
     val memoryReadCost = expr.wrapped match {
       case _: Alias =>
@@ -308,6 +311,9 @@ class GpuCostModel(conf: RapidsConf) extends CostModel {
   }
 
   private def exprCost[INPUT <: Expression](expr: BaseExprMeta[INPUT], rowCount: Double): Double = {
+    if (MemoryCostHelper.isExcludedFromCost(expr)) {
+      return 0d
+    }
 
     var memoryReadCost = 0d
     var memoryWriteCost = 0d
@@ -351,6 +357,15 @@ object MemoryCostHelper {
   def calculateCost(dataSize: Long, memorySpeed: Double): Double = {
     (dataSize / GIGABYTE) / memorySpeed
   }
+
+  def isExcludedFromCost[INPUT <: Expression](expr: BaseExprMeta[INPUT]) = {
+    expr.wrapped match {
+      case _: WindowSpecDefinition | _: WindowFrame =>
+        // Window expressions are Unevaluable and accessing dataType causes an exception
+        true
+      case _ => false
+    }
+  }
 }
 
 /**
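For context only (this helper is unchanged by the commit), calculateCost above converts a byte count into a relative cost. A rough worked example, assuming GIGABYTE is the constant 1024 * 1024 * 1024 and memorySpeed is configured in GB/s:

// Hypothetical standalone copy of MemoryCostHelper.calculateCost, for illustration only.
object CalculateCostSketch {
  private val GIGABYTE: Double = 1024d * 1024d * 1024d // assumed value of the constant

  def calculateCost(dataSize: Long, memorySpeed: Double): Double =
    (dataSize / GIGABYTE) / memorySpeed

  def main(args: Array[String]): Unit = {
    // Moving 2 GiB at 20 GB/s yields a relative cost of 0.1.
    println(calculateCost(2L * 1024 * 1024 * 1024, 20.0))
  }
}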
@@ -514,6 +514,38 @@ class CostBasedOptimizerSuite extends SparkQueryCompareTestSuite
     assert(broadcastExchanges.forall(_._2.toLong > 0))
   }
 
+  // The purpose of this test is simply to make sure that we do not try and calculate the cost
+  // of expressions that cannot be evaluated. This can cause exceptions when estimating data
+  // sizes. For example, WindowFrame.dataType throws UnsupportedOperationException.
+  test("Window expression (unevaluable)") {
+    logError("Window expression (unevaluable)")
+
+    val conf = new SparkConf()
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+      .set(RapidsConf.OPTIMIZER_ENABLED.key, "true")
+      .set(RapidsConf.TEST_ALLOWED_NONGPU.key,
+        "ProjectExec,WindowExec,ShuffleExchangeExec,WindowSpecDefinition," +
+          "SpecifiedWindowFrame,WindowExpression,Alias,Rank,HashPartitioning," +
+          "UnboundedPreceding$,CurrentRow$,SortExec")
+
+    withGpuSparkSession(spark => {
+      employeeDf(spark).createOrReplaceTempView("employees")
+      val df: DataFrame = spark.sql("SELECT name, dept, RANK() " +
+        "OVER (PARTITION BY dept ORDER BY salary) AS rank FROM employees")
+      df.collect()
+    }, conf)
+  }
+
+  private def employeeDf(session: SparkSession): DataFrame = {
+    import session.sqlContext.implicits._
+    Seq(
+      ("A", "IT", 1234),
+      ("B", "IT", 4321),
+      ("C", "Sales", 4321),
+      ("D", "Sales", 1234)
+    ).toDF("name", "dept", "salary")
+  }
+
   private def collectPlansWithRowCount(
       plan: SparkPlanMeta[_],
       accum: ListBuffer[SparkPlanMeta[_]]): Unit = {
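For reference, the RANK() query in the new test partitions employees by dept and ranks by salary; on the four rows in employeeDf it would return the following, in some order (the test does not assert on the output, only that planning and collect() complete without an exception):

name  dept   rank
A     IT     1
B     IT     2
D     Sales  1
C     Sales  2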
