Skip to content

Commit

Permalink
Avoid listener race collecting wrong plan in assert_gpu_fallback_coll…
Browse files Browse the repository at this point in the history
…ect (NVIDIA#2516)

Signed-off-by: Jason Lowe <jlowe@nvidia.com>
  • Loading branch information
jlowe authored May 27, 2021
1 parent 6d14296 commit 12adfb5
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
19 changes: 12 additions & 7 deletions integration_tests/src/main/python/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,12 @@ def with_limit(spark):
elif mode == 'COUNT':
bring_back = lambda spark: limit_func(spark).count()
collect_type = 'COUNT'
elif mode == 'COLLECT_WITH_DATAFRAME':
def bring_back(spark):
df = limit_func(spark)
return (df.collect(), df)
collect_type = 'COLLECT'
return (bring_back, collect_type)
else:
bring_back = lambda spark: limit_func(spark).toLocalIterator()
collect_type = 'ITERATOR'
Expand Down Expand Up @@ -292,21 +298,20 @@ def assert_gpu_fallback_write(write_func,
def assert_gpu_fallback_collect(func,
cpu_fallback_class_name,
conf={}):
(bring_back, collect_type) = _prep_func_for_compare(func, 'COLLECT')
(bring_back, collect_type) = _prep_func_for_compare(func, 'COLLECT_WITH_DATAFRAME')

conf = _prep_incompat_conf(conf)

print('### CPU RUN ###')
cpu_start = time.time()
from_cpu = with_cpu_session(bring_back, conf=conf)
from_cpu, cpu_df = with_cpu_session(bring_back, conf=conf)
cpu_end = time.time()
print('### GPU RUN ###')
jvm = spark_jvm()
jvm.com.nvidia.spark.rapids.ExecutionPlanCaptureCallback.startCapture()
gpu_start = time.time()
from_gpu = with_gpu_session(bring_back,
conf=conf)
from_gpu, gpu_df = with_gpu_session(bring_back, conf=conf)
gpu_end = time.time()
jvm.com.nvidia.spark.rapids.ExecutionPlanCaptureCallback.assertCapturedAndGpuFellBack(cpu_fallback_class_name, 2000)
jvm = spark_jvm()
jvm.com.nvidia.spark.rapids.ExecutionPlanCaptureCallback.assertDidFallBack(gpu_df._jdf, cpu_fallback_class_name)
print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type,
gpu_end - gpu_start, cpu_end - cpu_start))
if should_sort_locally():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.{DataFrame, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
Expand Down Expand Up @@ -322,6 +322,11 @@ object ExecutionPlanCaptureCallback {
s"Could not find $fallbackCpuClass in the GPU plan\n$executedPlan")
}

def assertDidFallBack(df: DataFrame, fallbackCpuClass: String): Unit = {
val executedPlan = df.queryExecution.executedPlan
assertDidFallBack(executedPlan, fallbackCpuClass)
}

private def didFallBack(exp: Expression, fallbackCpuClass: String): Boolean = {
!exp.isInstanceOf[GpuExpression] &&
PlanUtils.getBaseNameFromClass(exp.getClass.getName) == fallbackCpuClass ||
Expand Down

0 comments on commit 12adfb5

Please sign in to comment.