diff --git a/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/shims/spark311db/Spark311dbShims.scala b/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/shims/spark311db/Spark311dbShims.scala
index 6c98240db3f..8fc0bf4d9ae 100644
--- a/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/shims/spark311db/Spark311dbShims.scala
+++ b/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/shims/spark311db/Spark311dbShims.scala
@@ -20,6 +20,7 @@ import com.databricks.sql.execution.window.RunningWindowFunctionExec
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.shims.spark311.Spark311Shims
 import org.apache.hadoop.fs.Path
+import org.apache.parquet.schema.MessageType
 
 import org.apache.spark.sql.rapids.shims.spark311db._
 import org.apache.spark.rdd.RDD
@@ -34,11 +35,13 @@ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
 import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
 import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 import org.apache.spark.sql.execution.python.{AggregateInPandasExec, ArrowEvalPythonExec, FlatMapGroupsInPandasExec, MapInPandasExec, WindowInPandasExec}
 import org.apache.spark.sql.execution.window.WindowExecBase
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids.GpuFileSourceScanExec
 import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
 import org.apache.spark.sql.rapids.execution.python.{GpuPythonUDF, GpuWindowInPandasExecMetaBase}
@@ -47,6 +50,18 @@ import org.apache.spark.sql.types._
 
 class Spark311dbShims extends Spark311Shims {
 
+  override def getParquetFilters(
+      schema: MessageType,
+      pushDownDate: Boolean,
+      pushDownTimestamp: Boolean,
+      pushDownDecimal: Boolean,
+      pushDownStartWith: Boolean,
+      pushDownInFilterThreshold: Int,
+      caseSensitive: Boolean,
+      datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters =
+    new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
+      pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode)
+
   override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION
 
   override def getGpuBroadcastNestedLoopJoinShim(
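
For context: the override above wires the Databricks 3.1.x shim to construct Spark's `ParquetFilters` with the eight-argument constructor this Spark version exposes, so predicate push-down settings flow through the shim layer instead of being hard-coded per version. Below is a minimal sketch of what a call site might look like; the `SparkShims` handle and the surrounding object are assumptions for illustration, and only `getParquetFilters` and its signature come from this diff.

```scala
import org.apache.parquet.schema.MessageType
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.internal.SQLConf

import com.nvidia.spark.rapids.SparkShims

// Hypothetical call site, not part of this change: builds a
// version-appropriate ParquetFilters from the active SQLConf.
object ParquetFilterPushdownSketch {
  def buildFilters(shims: SparkShims, fileSchema: MessageType): ParquetFilters = {
    val conf = SQLConf.get
    shims.getParquetFilters(
      fileSchema,
      conf.parquetFilterPushDownDate,
      conf.parquetFilterPushDownTimestamp,
      conf.parquetFilterPushDownDecimal,
      conf.parquetFilterPushDownStringStartWith,
      conf.parquetFilterPushDownInFilterThreshold,
      conf.caseSensitiveAnalysis,
      // Placeholder rebase mode; a real caller derives this from the
      // session config or the file's metadata.
      SQLConf.LegacyBehaviorPolicy.CORRECTED)
  }
}
```

The returned `ParquetFilters` can then convert data-source `Filter`s into Parquet `FilterPredicate`s (via its `createFilter` method in Spark 3.1.x) before the scan is handed to the GPU reader.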