diff --git a/integration_tests/src/main/python/asserts.py b/integration_tests/src/main/python/asserts.py index 585e1fd5d34..c7ef15cfdea 100644 --- a/integration_tests/src/main/python/asserts.py +++ b/integration_tests/src/main/python/asserts.py @@ -186,6 +186,12 @@ def with_limit(spark): elif mode == 'COUNT': bring_back = lambda spark: limit_func(spark).count() collect_type = 'COUNT' + elif mode == 'COLLECT_WITH_DATAFRAME': + def bring_back(spark): + df = limit_func(spark) + return (df.collect(), df) + collect_type = 'COLLECT' + return (bring_back, collect_type) else: bring_back = lambda spark: limit_func(spark).toLocalIterator() collect_type = 'ITERATOR' @@ -292,21 +298,20 @@ def assert_gpu_fallback_write(write_func, def assert_gpu_fallback_collect(func, cpu_fallback_class_name, conf={}): - (bring_back, collect_type) = _prep_func_for_compare(func, 'COLLECT') + (bring_back, collect_type) = _prep_func_for_compare(func, 'COLLECT_WITH_DATAFRAME') + conf = _prep_incompat_conf(conf) print('### CPU RUN ###') cpu_start = time.time() - from_cpu = with_cpu_session(bring_back, conf=conf) + from_cpu, cpu_df = with_cpu_session(bring_back, conf=conf) cpu_end = time.time() print('### GPU RUN ###') - jvm = spark_jvm() - jvm.com.nvidia.spark.rapids.ExecutionPlanCaptureCallback.startCapture() gpu_start = time.time() - from_gpu = with_gpu_session(bring_back, - conf=conf) + from_gpu, gpu_df = with_gpu_session(bring_back, conf=conf) gpu_end = time.time() - jvm.com.nvidia.spark.rapids.ExecutionPlanCaptureCallback.assertCapturedAndGpuFellBack(cpu_fallback_class_name, 2000) + jvm = spark_jvm() + jvm.com.nvidia.spark.rapids.ExecutionPlanCaptureCallback.assertDidFallBack(gpu_df._jdf, cpu_fallback_class_name) print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type, gpu_end - gpu_start, cpu_end - cpu_start)) if should_sort_locally(): diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index 4eaa03676c6..a94bc03a5cc 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -29,7 +29,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext} import org.apache.spark.internal.Logging import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.{DataFrame, SparkSessionExtensions} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ @@ -322,6 +322,11 @@ object ExecutionPlanCaptureCallback { s"Could not find $fallbackCpuClass in the GPU plan\n$executedPlan") } + def assertDidFallBack(df: DataFrame, fallbackCpuClass: String): Unit = { + val executedPlan = df.queryExecution.executedPlan + assertDidFallBack(executedPlan, fallbackCpuClass) + } + private def didFallBack(exp: Expression, fallbackCpuClass: String): Boolean = { !exp.isInstanceOf[GpuExpression] && PlanUtils.getBaseNameFromClass(exp.getClass.getName) == fallbackCpuClass ||