Add explain Plugin API for CPU plan #3850

Merged · 67 commits · Oct 21, 2021

Commits
413872a
Start making the Qualification tool programmatically callable
tgravescs Sep 27, 2021
2062e66
remove unneeded numRows
tgravescs Sep 27, 2021
c67daff
add test
tgravescs Sep 27, 2021
030a167
handle the listener rest of events
tgravescs Sep 27, 2021
1cd2a65
create RunningQualApp
tgravescs Sep 27, 2021
8943e54
update test and start looking at api for explain
tgravescs Sep 28, 2021
b4147d4
copyright
tgravescs Sep 30, 2021
2f9b9dc
refactor writer so we can get output strings
tgravescs Oct 1, 2021
e14a844
refactor applyOverrides and add config to explain only
tgravescs Oct 1, 2021
d6d458e
fix param
tgravescs Oct 1, 2021
6ffe9a0
fix param headerCSV
tgravescs Oct 1, 2021
aea7a4c
Add explain function
tgravescs Oct 5, 2021
8d28954
refactor
tgravescs Oct 5, 2021
3a9f237
try shimloader
tgravescs Oct 5, 2021
a55d9dd
expose GpuExplainPlan
tgravescs Oct 5, 2021
98ab5bb
rename ExplainPlan
tgravescs Oct 5, 2021
302dd7a
debug
tgravescs Oct 5, 2021
748fee4
working except subquery
tgravescs Oct 6, 2021
3379c00
add expressions subqueries
tgravescs Oct 6, 2021
ff0f284
get subquery plans
tgravescs Oct 6, 2021
aea57ff
fix
tgravescs Oct 6, 2021
e116aba
find
tgravescs Oct 6, 2021
e3e24d4
handle children
tgravescs Oct 6, 2021
83e1541
print
tgravescs Oct 6, 2021
e314883
fix plans
tgravescs Oct 6, 2021
44d12ce
working subqueries
tgravescs Oct 7, 2021
6d83bcb
Merge remote-tracking branch 'origin/branch-21.12' into qualCallable
tgravescs Oct 7, 2021
a57ca73
comments
tgravescs Oct 7, 2021
b9e7e8a
cleanup
tgravescs Oct 7, 2021
f9c2022
rework overrides
tgravescs Oct 7, 2021
8755a86
fix missing conf
tgravescs Oct 7, 2021
4472898
Add separate file explain gpu plan
tgravescs Oct 7, 2021
f6e0f0b
Merge remote-tracking branch 'origin/branch-21.12' into qualCallable2
tgravescs Oct 7, 2021
5306a81
put stuff back in overrides
tgravescs Oct 7, 2021
2772ebb
start shim aqe get init plan
tgravescs Oct 7, 2021
a681645
add shims
tgravescs Oct 8, 2021
7b5d5bf
Merge remote-tracking branch 'origin/branch-21.12' into qualCallable2
tgravescs Oct 8, 2021
c46a33f
remove config
tgravescs Oct 8, 2021
da4fa3e
remove qualification tool changes
tgravescs Oct 8, 2021
c9ecd10
add javadoc
tgravescs Oct 8, 2021
394e17d
remove file not needed
tgravescs Oct 8, 2021
0729fc7
updates to explain parameter
tgravescs Oct 8, 2021
ac1513a
change to string
tgravescs Oct 8, 2021
6ef2e1b
rework
tgravescs Oct 8, 2021
bf12fac
rename ExplainGPUPlan
tgravescs Oct 11, 2021
874e4f7
Update name
tgravescs Oct 11, 2021
52f3500
Merge remote-tracking branch 'origin/branch-21.12' into qualExplain
tgravescs Oct 14, 2021
4fdb4e9
Merge remote-tracking branch 'origin/branch-21.12' into qualExplain
tgravescs Oct 18, 2021
bd4bbeb
update javadoc
tgravescs Oct 18, 2021
f51be6d
revert docs
tgravescs Oct 18, 2021
9499fcb
update doc
tgravescs Oct 18, 2021
f9901c6
Add pytest for explain api
tgravescs Oct 18, 2021
268502c
more tests
tgravescs Oct 18, 2021
f64a3e0
Merge remote-tracking branch 'origin/branch-21.12' into qualExplain
tgravescs Oct 18, 2021
d4f9718
add python api example
tgravescs Oct 18, 2021
b4466ee
remove import
tgravescs Oct 18, 2021
e14fb76
change comment
tgravescs Oct 18, 2021
edf9253
revert docs
tgravescs Oct 18, 2021
f85c530
change names
tgravescs Oct 18, 2021
0565cf0
change func name
tgravescs Oct 18, 2021
a09dc0e
revert docs
tgravescs Oct 18, 2021
a3087e7
Merge remote-tracking branch 'origin/branch-21.12' into qualExplain
tgravescs Oct 18, 2021
437abb8
revert docs
tgravescs Oct 18, 2021
64475af
Add a test for setting rapids conf before calling explain
tgravescs Oct 19, 2021
456201e
review comments
tgravescs Oct 19, 2021
3640aff
update docs
tgravescs Oct 19, 2021
4f86cba
Define what we throw and update to not use reflection
tgravescs Oct 20, 2021
1 change: 1 addition & 0 deletions dist/unshimmed-common-from-spark301.txt
@@ -7,6 +7,7 @@ com/nvidia/spark/RapidsUDF*
 com/nvidia/spark/SQLPlugin*
 com/nvidia/spark/rapids/ColumnarRdd*
 com/nvidia/spark/rapids/ExecutionPlanCaptureCallback*
+com/nvidia/spark/rapids/ExplainPlan*
 com/nvidia/spark/rapids/GpuKryoRegistrator*
 com/nvidia/spark/rapids/PlanUtils*
 com/nvidia/spark/rapids/RapidsExecutorHeartbeatMsg*
93 changes: 93 additions & 0 deletions integration_tests/src/main/python/explain_test.py
@@ -0,0 +1,93 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from data_gen import *
from marks import *
from pyspark.sql.functions import *
from pyspark.sql.types import *
from spark_session import with_cpu_session

def create_df(spark, data_gen, left_length, right_length):
    left = binary_op_df(spark, data_gen, length=left_length)
    right = binary_op_df(spark, data_gen, length=right_length).withColumnRenamed("a", "r_a")\
        .withColumnRenamed("b", "r_b")
    return left, right


@pytest.mark.parametrize('data_gen', [StringGen()], ids=idfn)
def test_explain_join(spark_tmp_path, data_gen):
    data_path1 = spark_tmp_path + '/PARQUET_DATA1'
    data_path2 = spark_tmp_path + '/PARQUET_DATA2'

    def do_join_explain(spark):
        left, right = create_df(spark, data_gen, 500, 500)
        left.write.parquet(data_path1)
        right.write.parquet(data_path2)
        df1 = spark.read.parquet(data_path1)
        df2 = spark.read.parquet(data_path2)
        df3 = df1.join(df2, df1.a == df2.r_a, "inner")
        explain_str = spark.sparkContext._jvm.com.nvidia.spark.rapids.ExplainPlan.explainPotentialGpuPlan(df3._jdf, "ALL")
        remove_isnotnull = explain_str.replace("isnotnull", "")
        # everything should be on GPU
        assert "not" not in remove_isnotnull

    with_cpu_session(do_join_explain)

def test_explain_set_config():
    conf = {'spark.rapids.sql.hasExtendedYearValues': 'false',
            'spark.rapids.sql.castStringToTimestamp.enabled': 'true'}

    def do_explain(spark):
        df = unary_op_df(spark, StringGen('[0-9]{1,4}-[0-9]{1,2}-[0-9]{1,2}')).select(col('a').cast(TimestampType()))
        # a bit brittle if these get turned on by default
        spark.conf.set('spark.rapids.sql.hasExtendedYearValues', 'false')
        spark.conf.set('spark.rapids.sql.castStringToTimestamp.enabled', 'true')
        explain_str = spark.sparkContext._jvm.com.nvidia.spark.rapids.ExplainPlan.explainPotentialGpuPlan(df._jdf, "ALL")
        print(explain_str)
        assert "timestamp) will run on GPU" in explain_str
        spark.conf.set('spark.rapids.sql.castStringToTimestamp.enabled', 'false')
        explain_str_cast_off = spark.sparkContext._jvm.com.nvidia.spark.rapids.ExplainPlan.explainPotentialGpuPlan(df._jdf, "ALL")
        print(explain_str_cast_off)
        assert "timestamp) cannot run on GPU" in explain_str_cast_off

    with_cpu_session(do_explain)

def test_explain_udf():
    slen = udf(lambda s: len(s), IntegerType())

    @udf
    def to_upper(s):
        if s is not None:
            return s.upper()

    @udf(returnType=IntegerType())
    def add_one(x):
        if x is not None:
            return x + 1

    def do_explain(spark):
        df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
        df2 = df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age"))
        explain_str = spark.sparkContext._jvm.com.nvidia.spark.rapids.ExplainPlan.explainPotentialGpuPlan(df2._jdf, "ALL")
        # udf shouldn't be on GPU
        udf_str_not = 'cannot run on GPU because no GPU enabled version of operator class org.apache.spark.sql.execution.python.BatchEvalPythonExec'
        assert udf_str_not in explain_str
        not_on_gpu_str = spark.sparkContext._jvm.com.nvidia.spark.rapids.ExplainPlan.explainPotentialGpuPlan(df2._jdf, "NOT")
        assert udf_str_not in not_on_gpu_str
        assert "will run on GPU" not in not_on_gpu_str

    with_cpu_session(do_explain)
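
These tests reach the API through the py4j gateway (spark.sparkContext._jvm...) because ExplainPlan lives on the JVM side; from Scala the call is direct. A minimal sketch of the same two modes the tests exercise, assuming an active SparkSession `spark` with the spark-rapids and cudf jars on the classpath and the plugin itself disabled:

    import com.nvidia.spark.rapids.ExplainPlan

    val df = spark.range(100).selectExpr("id", "id % 10 as key")

    // "ALL" reports every operator, whether or not it could run on the GPU
    println(ExplainPlan.explainPotentialGpuPlan(df, "ALL"))

    // "NOT" reports only what could not run on the GPU,
    // mirroring the "NOT" assertion in test_explain_udf above
    println(ExplainPlan.explainPotentialGpuPlan(df, "NOT"))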

@@ -1004,6 +1004,10 @@ class Spark320Shims extends Spark32XShims {
       new KryoJavaSerializer())
   }

+  override def getAdaptiveInputPlan(adaptivePlan: AdaptiveSparkPlanExec): SparkPlan = {
+    adaptivePlan.initialPlan
+  }
+
   override def getLegacyStatisticalAggregate(): Boolean =
     SQLConf.get.legacyStatisticalAggregate
 }
@@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, RunnableCommand}
 import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, FileScanRDD, HadoopFsRelation, InMemoryFileIndex, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
@@ -688,5 +688,9 @@ abstract class SparkBaseShims extends Spark30XShims {
       new KryoJavaSerializer())
   }

+  override def getAdaptiveInputPlan(adaptivePlan: AdaptiveSparkPlanExec): SparkPlan = {
+    adaptivePlan.initialPlan
+  }
+
   override def getLegacyStatisticalAggregate(): Boolean = true
 }
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, RunnableCommand}
 import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, FileScanRDD, HadoopFsRelation, InMemoryFileIndex, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils
@@ -647,4 +647,8 @@ abstract class SparkBaseShims extends Spark30XShims {
     kryo.register(classOf[SerializeBatchDeserializeHostBuffer],
       new KryoJavaSerializer())
   }
+
+  override def getAdaptiveInputPlan(adaptivePlan: AdaptiveSparkPlanExec): SparkPlan = {
+    adaptivePlan.initialPlan
+  }
 }
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, RunnableCommand}
 import org.apache.spark.sql.execution.datasources._
@@ -874,6 +874,10 @@ abstract class SparkBaseShims extends Spark30XShims {

   override def shouldFallbackOnAnsiTimestamp(): Boolean = SQLConf.get.ansiEnabled

+  override def getAdaptiveInputPlan(adaptivePlan: AdaptiveSparkPlanExec): SparkPlan = {
+    adaptivePlan.inputPlan
+  }
+
   override def getLegacyStatisticalAggregate(): Boolean =
     SQLConf.get.legacyStatisticalAggregate
 }
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, RunnableCommand}
 import org.apache.spark.sql.execution.datasources._
@@ -845,6 +845,10 @@ abstract class SparkBaseShims extends Spark31XShims {

   override def shouldFallbackOnAnsiTimestamp(): Boolean = SQLConf.get.ansiEnabled

+  override def getAdaptiveInputPlan(adaptivePlan: AdaptiveSparkPlanExec): SparkPlan = {
+    adaptivePlan.inputPlan
+  }
+
   override def getLegacyStatisticalAggregate(): Boolean =
     SQLConf.get.legacyStatisticalAggregate
 }
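
Each Spark version shim overrides getAdaptiveInputPlan with whichever accessor that version exposes for the plan as it existed before adaptive query execution re-optimizes it: initialPlan on some versions, inputPlan on others. A hypothetical sketch (not part of this diff) of how the explain path can use the shim to unwrap AQE; `ShimLoader.getSparkShims` and the `PlanForExplain` helper are assumptions here, and the PR's actual call sites may differ:

    import com.nvidia.spark.rapids.ShimLoader
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

    object PlanForExplain {
      // AQE wraps the whole query in AdaptiveSparkPlanExec; ask the shim for the
      // plan it started from so the explain output reflects the initial CPU plan.
      def apply(df: DataFrame): SparkPlan = df.queryExecution.executedPlan match {
        case adaptive: AdaptiveSparkPlanExec =>
          ShimLoader.getSparkShims.getAdaptiveInputPlan(adaptive)
        case plan => plan // non-AQE plans are used as-is
      }
    }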
@@ -0,0 +1,74 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids

import scala.util.control.NonFatal

import org.apache.spark.sql.DataFrame

// Base trait that is visible publicly outside of the parallel world packaging.
// It can't be named the same as the ExplainPlan object so that it can still be
// called from PySpark.
trait ExplainPlanBase {
  def explainPotentialGpuPlan(df: DataFrame, explain: String = "ALL"): String
}

object ExplainPlan {
  /**
   * Looks at the CPU plan associated with the dataframe and outputs information
   * about which parts of the query the RAPIDS Accelerator for Apache Spark
   * could place on the GPU. This only applies to the initial plan, so if running
   * with adaptive query execution enabled, it will not be able to show any changes
   * made to the plan at runtime due to that.
   *
   * This is very similar to the output you would get by running the query with the
   * RAPIDS Accelerator enabled and with the config `spark.rapids.sql.explain` enabled.
   *
   * Requires the RAPIDS Accelerator for Apache Spark jar and the RAPIDS cudf jar to be
   * included on the classpath, but the RAPIDS Accelerator for Apache Spark should be
   * disabled.
   *
   * {{{
   *   val output = com.nvidia.spark.rapids.ExplainPlan.explainPotentialGpuPlan(df)
   * }}}
   *
   * Calling from PySpark:
   *
   * {{{
   *   output = sc._jvm.com.nvidia.spark.rapids.ExplainPlan.explainPotentialGpuPlan(df._jdf, "ALL")
   * }}}
   *
   * @param df The Spark DataFrame to get the query plan from
   * @param explain If "ALL", returns all of the explain data; otherwise returns only what
   *                does not work on the GPU. Default is ALL.
   * @return String containing the explained plan.
   * @throws IllegalArgumentException if an argument is invalid or it is unable to determine
   *                                  the Spark version
   * @throws IllegalStateException if the plugin gets into an invalid state while trying
   *                               to process the plan or there is an unexpected exception.
   */
  @throws[IllegalArgumentException]
  @throws[IllegalStateException]
  def explainPotentialGpuPlan(df: DataFrame, explain: String = "ALL"): String = {
    try {
      ShimLoader.newExplainPlan.explainPotentialGpuPlan(df, explain)
    } catch {
      case ia: IllegalArgumentException => throw ia
      case is: IllegalStateException => throw is
      case NonFatal(e) =>
        val msg = "Unexpected exception trying to run explain on the plan!"
        throw new IllegalStateException(msg, e)
    }
  }
}
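
The entry point above declares its two failure modes, so a caller can handle them explicitly instead of matching on plugin internals. A minimal caller sketch, assuming an existing DataFrame `df` and the setup described in the Scaladoc:

    import com.nvidia.spark.rapids.ExplainPlan

    val result: Either[String, String] =
      try {
        Right(ExplainPlan.explainPotentialGpuPlan(df, "ALL"))
      } catch {
        // bad `explain` argument, or the Spark version could not be determined
        case e: IllegalArgumentException => Left(s"invalid request: ${e.getMessage}")
        // the plugin hit an invalid state or an unexpected error processing the plan
        case e: IllegalStateException => Left(s"explain failed: ${e.getMessage}")
      }

    result match {
      case Right(out) => println(out)
      case Left(err)  => println(s"ERROR: $err")
    }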