Fix ParquetFilters issue (#2961)
* Fix ParquetFilters issue

Spark has added a datetime rebase mode parameter to the constructor of
ParquetFilters (see SPARK-36034), which causes the build to fail.

This change adds getParquetFilters to the shim layers to match the
different Spark versions.

Signed-off-by: Bobby Wang <wbo4958@gmail.com>

* Fix compiling error
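
Below is a minimal, self-contained sketch of the shim approach the commit message describes, under stated assumptions: FiltersV1 and FiltersV2 are hypothetical stand-ins that only mimic the shape of the ParquetFilters constructor before and after SPARK-36034, and the Shims trait and shim class names are illustrative, not the plugin's or Spark's actual API.

// Hypothetical sketch only: FiltersV1/FiltersV2 mimic the shape of the
// ParquetFilters constructor before and after SPARK-36034; none of these
// names are the plugin's or Spark's actual classes.
object ShimSketch {

  object LegacyBehaviorPolicy extends Enumeration {
    val EXCEPTION, LEGACY, CORRECTED = Value
  }

  // Older constructor shape: no rebase-mode argument.
  class FiltersV1(val caseSensitive: Boolean)

  // Newer constructor shape: also takes the datetime rebase mode.
  class FiltersV2(
      val caseSensitive: Boolean,
      val datetimeRebaseMode: LegacyBehaviorPolicy.Value)

  // The shim trait exposes the superset of parameters so call sites compile once.
  trait Shims {
    def getParquetFilters(
        caseSensitive: Boolean,
        datetimeRebaseMode: LegacyBehaviorPolicy.Value): AnyRef
  }

  // Shim for the older constructor: the rebase mode is accepted but dropped.
  class OlderSparkShims extends Shims {
    override def getParquetFilters(
        caseSensitive: Boolean,
        datetimeRebaseMode: LegacyBehaviorPolicy.Value): AnyRef =
      new FiltersV1(caseSensitive)
  }

  // Shim for the newer constructor: the rebase mode is forwarded.
  class NewerSparkShims extends Shims {
    override def getParquetFilters(
        caseSensitive: Boolean,
        datetimeRebaseMode: LegacyBehaviorPolicy.Value): AnyRef =
      new FiltersV2(caseSensitive, datetimeRebaseMode)
  }

  def main(args: Array[String]): Unit = {
    // A call site only sees the trait; the chosen shim knows which constructor exists.
    val shims: Shims = new NewerSparkShims
    val filters = shims.getParquetFilters(caseSensitive = true, LegacyBehaviorPolicy.CORRECTED)
    println(filters.getClass.getSimpleName) // prints FiltersV2
  }
}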
wbo4958 authored Jul 19, 2021
1 parent fbc5506 commit 9bec142
Showing 5 changed files with 64 additions and 5 deletions.
@@ -23,6 +23,7 @@ import com.nvidia.spark.rapids._
import org.apache.arrow.memory.ReferenceManager
import org.apache.arrow.vector.ValueVector
import org.apache.hadoop.fs.Path
import org.apache.parquet.schema.MessageType

import org.apache.spark.SparkEnv
import org.apache.spark.rdd.RDD
@@ -41,6 +42,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, RunnableCommand}
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, FileScanRDD, HadoopFsRelation, InMemoryFileIndex, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -80,6 +82,19 @@ abstract class SparkBaseShims extends SparkShims {
override def parquetRebaseWrite(conf: SQLConf): String =
conf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE)

override def getParquetFilters(
schema: MessageType,
pushDownDate: Boolean,
pushDownTimestamp: Boolean,
pushDownDecimal: Boolean,
pushDownStartWith: Boolean,
pushDownInFilterThreshold: Int,
caseSensitive: Boolean,
datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters = {
new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
pushDownInFilterThreshold, caseSensitive)
}

override def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand =
AlterTableRecoverPartitionsCommand(tableName)

@@ -19,6 +19,10 @@ package com.nvidia.spark.rapids.shims.spark313
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.spark312.Spark312Shims
import com.nvidia.spark.rapids.spark313.RapidsShuffleManager
import org.apache.parquet.schema.MessageType

import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.internal.SQLConf

class Spark313Shims extends Spark312Shims {

@@ -27,4 +31,16 @@ class Spark313Shims extends Spark312Shims {
override def getRapidsShuffleManagerClass: String = {
classOf[RapidsShuffleManager].getCanonicalName
}

override def getParquetFilters(
schema: MessageType,
pushDownDate: Boolean,
pushDownTimestamp: Boolean,
pushDownDecimal: Boolean,
pushDownStartWith: Boolean,
pushDownInFilterThreshold: Int,
caseSensitive: Boolean,
datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters =
new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode)
}
@@ -44,6 +44,18 @@ class Spark320Shims extends Spark311Shims {
override def parquetRebaseWrite(conf: SQLConf): String =
conf.getConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE)

override def getParquetFilters(
schema: MessageType,
pushDownDate: Boolean,
pushDownTimestamp: Boolean,
pushDownDecimal: Boolean,
pushDownStartWith: Boolean,
pushDownInFilterThreshold: Int,
caseSensitive: Boolean,
datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters =
new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode, datetimeRebaseMode)

override def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand =
RepairTableCommand(tableName,
// These match the one place that this is called, if we start to call this in more places
@@ -51,7 +51,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.datasources.{PartitionedFile}
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, ParquetReadSupport}
import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -250,8 +250,8 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
private val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
private val isCorrectedRebase =
"CORRECTED" == ShimLoader.getSparkShims.parquetRebaseRead(sqlConf)
private val rebaseMode = ShimLoader.getSparkShims.parquetRebaseRead(sqlConf)
private val isCorrectedRebase = "CORRECTED" == rebaseMode

def filterBlocks(
file: PartitionedFile,
@@ -265,8 +265,11 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
ParquetMetadataConverter.range(file.start, file.start + file.length))
val fileSchema = footer.getFileMetaData.getSchema
val pushedFilters = if (enableParquetFilterPushDown) {
val parquetFilters = new ParquetFilters(fileSchema, pushDownDate, pushDownTimestamp,
pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive)
val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode(
footer.getFileMetaData.getKeyValueMetaData.get, rebaseMode)
val parquetFilters = ShimLoader.getSparkShims.getParquetFilters(fileSchema, pushDownDate,
pushDownTimestamp, pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold,
isCaseSensitive, datetimeRebaseMode)
filters.flatMap(parquetFilters.createFilter).reduceOption(FilterApi.and)
} else {
None
13 changes: 13 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
import org.apache.arrow.memory.ReferenceManager
import org.apache.arrow.vector.ValueVector
import org.apache.hadoop.fs.Path
import org.apache.parquet.schema.MessageType

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
@@ -40,9 +41,11 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.sources.BaseRelation
@@ -88,6 +91,16 @@ trait SparkShims {
def parquetRebaseWrite(conf: SQLConf): String
def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand

def getParquetFilters(
schema: MessageType,
pushDownDate: Boolean,
pushDownTimestamp: Boolean,
pushDownDecimal: Boolean,
pushDownStartWith: Boolean,
pushDownInFilterThreshold: Int,
caseSensitive: Boolean,
datetimeRebaseMode: LegacyBehaviorPolicy.Value): ParquetFilters

def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean
def isGpuShuffledHashJoin(plan: SparkPlan): Boolean
def isWindowFunctionExec(plan: SparkPlan): Boolean
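
Reading of the call-site change above (a summary, not an authoritative description of the plugin's internals): GpuParquetFileFilterHandler.filterBlocks now resolves the datetime rebase mode per file via DataSourceUtils.datetimeRebaseMode on the footer's key-value metadata, with the session-level parquetRebaseRead value as the fallback, and hands it to ShimLoader.getSparkShims.getParquetFilters; each shim then forwards the mode only where its Spark version's ParquetFilters constructor accepts it.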
