Fix Databricks 3.0.1 for ParquetFilters API change #3118

Merged · 2 commits · Aug 4, 2021
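At a glance: the common SparkShims trait gains a getParquetFilters factory that takes a datetimeRebaseMode, and each version-specific shim constructs ParquetFilters with whatever signature its Spark build exposes. SparkBaseShims (Spark 3.0.x) drops the rebase argument, the Databricks 3.0.1 (Spark301dbShims) and Spark 3.1.3 (Spark313Shims) shims forward it, and Spark320Shims passes the mode twice, apparently because its constructor takes an additional rebase argument (presumably for INT96 timestamps). GpuParquetScan then resolves the per-file mode via DataSourceUtils.datetimeRebaseMode and calls through the shim instead of constructing ParquetFilters directly. The sketch below is a minimal, self-contained illustration of that pattern in plain Scala; the class and method names are illustrative only, not the plugin's real API.

// Minimal, self-contained sketch of the per-version factory pattern (illustrative
// names; the real plugin uses SparkShims, ShimLoader and ParquetFilters).
object ShimSketch {
  trait FiltersShim {
    // Mirrors the new SparkShims.getParquetFilters: common code always passes the
    // rebase mode; each shim decides whether the underlying constructor wants it.
    def makeFilters(caseSensitive: Boolean, datetimeRebaseMode: String): String
  }

  // Stands in for the Spark 3.0.x shim, whose ParquetFilters has no rebase parameter.
  class OldApiShim extends FiltersShim {
    def makeFilters(caseSensitive: Boolean, datetimeRebaseMode: String): String =
      s"ParquetFilters(caseSensitive=$caseSensitive)"
  }

  // Stands in for the Databricks 3.0.1 / Spark 3.1.x shims, whose constructor takes it.
  class NewApiShim extends FiltersShim {
    def makeFilters(caseSensitive: Boolean, datetimeRebaseMode: String): String =
      s"ParquetFilters(caseSensitive=$caseSensitive, rebaseMode=$datetimeRebaseMode)"
  }

  def main(args: Array[String]): Unit = {
    // The calling code is identical for every version; only the shim instance differs.
    val shims: Seq[FiltersShim] = Seq(new OldApiShim, new NewApiShim)
    shims.foreach(s => println(s.makeFilters(caseSensitive = true, datetimeRebaseMode = "CORRECTED")))
  }
}
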
@@ -23,6 +23,7 @@ import com.nvidia.spark.rapids._
import org.apache.arrow.memory.ReferenceManager
import org.apache.arrow.vector.ValueVector
import org.apache.hadoop.fs.Path
import org.apache.parquet.schema.MessageType

import org.apache.spark.SparkEnv
import org.apache.spark.rdd.RDD
@@ -41,6 +42,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, RunnableCommand}
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, FileScanRDD, HadoopFsRelation, InMemoryFileIndex, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -80,6 +82,19 @@ abstract class SparkBaseShims extends SparkShims {
  override def parquetRebaseWrite(conf: SQLConf): String =
    conf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE)

  override def getParquetFilters(
      schema: MessageType,
      pushDownDate: Boolean,
      pushDownTimestamp: Boolean,
      pushDownDecimal: Boolean,
      pushDownStartWith: Boolean,
      pushDownInFilterThreshold: Int,
      caseSensitive: Boolean,
      datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters = {
    new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
      pushDownInFilterThreshold, caseSensitive)
  }

  override def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand =
    AlterTableRecoverPartitionsCommand(tableName)

@@ -20,6 +20,7 @@ import com.databricks.sql.execution.window.RunningWindowFunctionExec
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.spark301.Spark301Shims
import org.apache.hadoop.fs.Path
import org.apache.parquet.schema.MessageType

import org.apache.spark.sql.rapids.shims.spark301db._
import org.apache.spark.rdd.RDD
@@ -33,11 +34,13 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.python.{AggregateInPandasExec, ArrowEvalPythonExec, FlatMapGroupsInPandasExec, MapInPandasExec, WindowInPandasExec}
import org.apache.spark.sql.execution.window.WindowExecBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.GpuFileSourceScanExec
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.rapids.execution.python.{GpuAggregateInPandasExecMeta, GpuArrowEvalPythonExec, GpuFlatMapGroupsInPandasExecMeta, GpuMapInPandasExecMeta, GpuPythonUDF, GpuWindowInPandasExecMetaBase}
@@ -47,6 +50,18 @@ class Spark301dbShims extends Spark301Shims {

  override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION

  override def getParquetFilters(
      schema: MessageType,
      pushDownDate: Boolean,
      pushDownTimestamp: Boolean,
      pushDownDecimal: Boolean,
      pushDownStartWith: Boolean,
      pushDownInFilterThreshold: Int,
      caseSensitive: Boolean,
      datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters =
    new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
      pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode)

  override def getGpuBroadcastNestedLoopJoinShim(
      left: SparkPlan,
      right: SparkPlan,
@@ -19,6 +19,10 @@ package com.nvidia.spark.rapids.shims.spark313
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.spark312.Spark312Shims
import com.nvidia.spark.rapids.spark313.RapidsShuffleManager
import org.apache.parquet.schema.MessageType

import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.internal.SQLConf

class Spark313Shims extends Spark312Shims {

@@ -27,4 +31,16 @@ class Spark313Shims extends Spark312Shims {
  override def getRapidsShuffleManagerClass: String = {
    classOf[RapidsShuffleManager].getCanonicalName
  }

  override def getParquetFilters(
      schema: MessageType,
      pushDownDate: Boolean,
      pushDownTimestamp: Boolean,
      pushDownDecimal: Boolean,
      pushDownStartWith: Boolean,
      pushDownInFilterThreshold: Int,
      caseSensitive: Boolean,
      datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters =
    new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
      pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode)
}
@@ -44,6 +44,18 @@ class Spark320Shims extends Spark311Shims {
  override def parquetRebaseWrite(conf: SQLConf): String =
    conf.getConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE)

  override def getParquetFilters(
      schema: MessageType,
      pushDownDate: Boolean,
      pushDownTimestamp: Boolean,
      pushDownDecimal: Boolean,
      pushDownStartWith: Boolean,
      pushDownInFilterThreshold: Int,
      caseSensitive: Boolean,
      datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters =
    new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
      pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode, datetimeRebaseMode)

  override def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand =
    RepairTableCommand(tableName,
      // These match the one place that this is called, if we start to call this in more places
@@ -54,7 +54,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, FilePartition, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, ParquetReadSupport}
import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -257,8 +257,8 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
  private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
  private val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
  private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
  private val isCorrectedRebase =
    "CORRECTED" == ShimLoader.getSparkShims.parquetRebaseRead(sqlConf)
  private val rebaseMode = ShimLoader.getSparkShims.parquetRebaseRead(sqlConf)
  private val isCorrectedRebase = "CORRECTED" == rebaseMode

  def filterBlocks(
      file: PartitionedFile,
@@ -272,8 +272,11 @@
      ParquetMetadataConverter.range(file.start, file.start + file.length))
    val fileSchema = footer.getFileMetaData.getSchema
    val pushedFilters = if (enableParquetFilterPushDown) {
      val parquetFilters = new ParquetFilters(fileSchema, pushDownDate, pushDownTimestamp,
        pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive)
      val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode(
        footer.getFileMetaData.getKeyValueMetaData.get, rebaseMode)
      val parquetFilters = ShimLoader.getSparkShims.getParquetFilters(fileSchema, pushDownDate,
        pushDownTimestamp, pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold,
        isCaseSensitive, datetimeRebaseMode)
      filters.flatMap(parquetFilters.createFilter).reduceOption(FilterApi.and)
    } else {
      None
sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala (13 additions, 0 deletions)
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
import org.apache.arrow.memory.ReferenceManager
import org.apache.arrow.vector.ValueVector
import org.apache.hadoop.fs.Path
import org.apache.parquet.schema.MessageType

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
@@ -40,9 +41,11 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.sources.BaseRelation
@@ -88,6 +91,16 @@ trait SparkShims {
  def parquetRebaseWrite(conf: SQLConf): String
  def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand

  def getParquetFilters(
      schema: MessageType,
      pushDownDate: Boolean,
      pushDownTimestamp: Boolean,
      pushDownDecimal: Boolean,
      pushDownStartWith: Boolean,
      pushDownInFilterThreshold: Int,
      caseSensitive: Boolean,
      datetimeRebaseMode: LegacyBehaviorPolicy.Value): ParquetFilters

  def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean
  def isGpuShuffledHashJoin(plan: SparkPlan): Boolean
  def isWindowFunctionExec(plan: SparkPlan): Boolean