Remove schema utils, case class copying, file partition, and legacy statistical aggregate shims [databricks] #5007

Merged · 5 commits · Mar 23, 2022
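The methods removed in this diff are thin wrappers — direct Spark API calls or plain case-class copies — that no longer need to differ across the supported Spark versions. A rough caller-side sketch of the resulting change, assuming an Int index and an Array[PartitionedFile] files are in scope; the ShimLoader accessor named in the comment is an assumption about the old call sites, not something shown in this diff:

import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}

// before: indirect call through the version-specific shim layer
//   val part = ShimLoader.getSparkShims.createFilePartition(index, files)
// after: direct use of the Spark case class, which is now identical on all supported versions
val part: FilePartition = FilePartition(index, files)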
@@ -35,7 +35,6 @@ import org.apache.spark.rapids.shims.GpuShuffleExchangeExec
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.errors.attachTree
@@ -62,7 +61,7 @@ import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.execution.{GpuCustomShuffleReaderExec, GpuShuffleExchangeExecBase, SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch}
import org.apache.spark.sql.rapids.execution.python._
import org.apache.spark.sql.rapids.execution.python.shims._
import org.apache.spark.sql.rapids.shims.{GpuColumnarToRowTransitionExec, GpuSchemaUtils, HadoopFSUtilsShim}
import org.apache.spark.sql.rapids.shims.{GpuColumnarToRowTransitionExec, HadoopFSUtilsShim}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.apache.spark.storage.{BlockId, BlockManagerId}
@@ -174,47 +173,6 @@ abstract class Spark31XShims extends SparkShims with Spark31Xuntil33XShims with
(a, conf, p, r) => new RapidsOrcScanMeta(a, conf, p, r))
).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap

override def getPartitionFileNames(
partitions: Seq[PartitionDirectory]): Seq[String] = {
val files = partitions.flatMap(partition => partition.files)
files.map(_.getPath.getName)
}

override def getPartitionFileStatusSize(partitions: Seq[PartitionDirectory]): Long = {
partitions.map(_.files.map(_.getLen).sum).sum
}

override def getPartitionedFiles(
partitions: Array[PartitionDirectory]): Array[PartitionedFile] = {
partitions.flatMap { p =>
p.files.map { f =>
PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values)
}
}
}

override def getPartitionSplitFiles(
partitions: Array[PartitionDirectory],
maxSplitBytes: Long,
relation: HadoopFsRelation): Array[PartitionedFile] = {
partitions.flatMap { partition =>
partition.files.flatMap { file =>
// getPath() is very expensive so we only want to call it once in this block:
val filePath = file.getPath
val isSplitable = relation.fileFormat.isSplitable(
relation.sparkSession, relation.options, filePath)
PartitionedFileUtil.splitFiles(
sparkSession = relation.sparkSession,
file = file,
filePath = filePath,
isSplitable = isSplitable,
maxSplitBytes = maxSplitBytes,
partitionValues = partition.values
)
}
}
}

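The partition-file helpers deleted above forwarded straight to Spark's PartitionedFileUtil. A minimal standalone sketch of the equivalent direct calls on Spark 3.1.x, using only the API calls visible in the removed bodies (the object and method names here are illustrative, not part of the codebase):

import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory, PartitionedFile}

object PartitionFileSketch {
  // One PartitionedFile per input file, no splitting (as in the removed getPartitionedFiles).
  def partitionedFiles(partitions: Array[PartitionDirectory]): Array[PartitionedFile] =
    partitions.flatMap { p =>
      p.files.map(f => PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values))
    }

  // Split files at maxSplitBytes where the format allows it (as in the removed getPartitionSplitFiles).
  def splitFiles(
      partitions: Array[PartitionDirectory],
      maxSplitBytes: Long,
      relation: HadoopFsRelation): Array[PartitionedFile] =
    partitions.flatMap { partition =>
      partition.files.flatMap { file =>
        val filePath = file.getPath // getPath() is expensive, so call it only once per file
        val isSplitable = relation.fileFormat.isSplitable(
          relation.sparkSession, relation.options, filePath)
        PartitionedFileUtil.splitFiles(
          sparkSession = relation.sparkSession,
          file = file,
          filePath = filePath,
          isSplitable = isSplitable,
          maxSplitBytes = maxSplitBytes,
          partitionValues = partition.values)
      }
    }
}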
override def getFileScanRDD(
sparkSession: SparkSession,
readFunction: PartitionedFile => Iterator[InternalRow],
@@ -224,43 +182,6 @@ abstract class Spark31XShims extends SparkShims with Spark31Xuntil33XShims with
new FileScanRDD(sparkSession, readFunction, filePartitions)
}

override def createFilePartition(index: Int, files: Array[PartitionedFile]): FilePartition = {
FilePartition(index, files)
}

override def copyBatchScanExec(
batchScanExec: GpuBatchScanExec,
queryUsesInputFile: Boolean): GpuBatchScanExec = {
val scanCopy = batchScanExec.scan match {
case parquetScan: GpuParquetScan =>
parquetScan.copy(queryUsesInputFile = queryUsesInputFile)
case orcScan: GpuOrcScan =>
orcScan.copy(queryUsesInputFile = queryUsesInputFile)
case _ => throw new RuntimeException("Wrong format") // never reach here
}
batchScanExec.copy(scan = scanCopy)
}

override def copyFileSourceScanExec(
scanExec: GpuFileSourceScanExec,
queryUsesInputFile: Boolean): GpuFileSourceScanExec = {
scanExec.copy(queryUsesInputFile = queryUsesInputFile)(scanExec.rapidsConf)
}

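The two copy* shims deleted above are plain case-class copies, so the same effect is a one-liner at the call site. A fragment sketch, assuming a GpuFileSourceScanExec named scanExec and a GpuBatchScanExec named batchScanExec are already in scope (imports for these plugin classes are omitted):

// Equivalent of the removed copyFileSourceScanExec: copy the exec, preserving its conf.
val newScanExec = scanExec.copy(queryUsesInputFile = true)(scanExec.rapidsConf)

// Equivalent of the removed copyBatchScanExec: copy the wrapped scan, then the exec.
val newBatchScanExec = batchScanExec.scan match {
  case p: GpuParquetScan => batchScanExec.copy(scan = p.copy(queryUsesInputFile = true))
  case o: GpuOrcScan => batchScanExec.copy(scan = o.copy(queryUsesInputFile = true))
  case _ => throw new RuntimeException("Wrong format") // unreachable for the supported scan types
}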
override def checkColumnNameDuplication(
schema: StructType,
colType: String,
resolver: Resolver): Unit = {
GpuSchemaUtils.checkColumnNameDuplication(schema, colType, resolver)
}

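The deleted checkColumnNameDuplication override delegated to the now-removed GpuSchemaUtils wrapper. A sketch of calling Spark directly instead — this assumes org.apache.spark.sql.util.SchemaUtils on Spark 3.1.x exposes a checkColumnNameDuplication(names, colType, resolver) overload, which is an assumption, not something shown in this diff:

import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.SchemaUtils

def checkNoDuplicateColumns(schema: StructType, colType: String, resolver: Resolver): Unit =
  // StructType is a Seq[StructField], so map out the column names for the check
  SchemaUtils.checkColumnNameDuplication(schema.map(_.name), colType, resolver)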
override def alias(child: Expression, name: String)(
exprId: ExprId,
qualifier: Seq[String],
explicitMetadata: Option[Metadata]): Alias = {
Alias(child, name)(exprId, qualifier, explicitMetadata)
}

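The deleted alias override is a pure pass-through to Catalyst's Alias constructor, shown here only with the imports a direct caller needs — a minimal sketch:

import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId, Expression}
import org.apache.spark.sql.types.Metadata

// Same shape as the removed shim body: curried constructor with explicit id, qualifier, and metadata.
def makeAlias(child: Expression, name: String)(
    exprId: ExprId,
    qualifier: Seq[String],
    explicitMetadata: Option[Metadata]): Alias =
  Alias(child, name)(exprId, qualifier, explicitMetadata)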
override def getArrowValidityBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = {
val arrowBuf = vec.getValidityBuffer
(arrowBuf.nioBuffer(), arrowBuf.getReferenceManager)
@@ -892,9 +813,6 @@ abstract class Spark31XShims extends SparkShims with Spark31Xuntil33XShims with
adaptivePlan.inputPlan
}

override def getLegacyStatisticalAggregate(): Boolean =
SQLConf.get.legacyStatisticalAggregate

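The deleted getter read one SQL conf flag and nothing else, so callers can consult SQLConf directly. A minimal sketch (the spark.sql.legacy.statisticalAggregate key name comes from Spark's SQLConf, not from this diff):

import org.apache.spark.sql.internal.SQLConf

// The legacy NaN-vs-null behavior of statistical aggregates on divide-by-zero,
// i.e. the value behind spark.sql.legacy.statisticalAggregate.
val legacyStatisticalAggregate: Boolean = SQLConf.get.legacyStatisticalAggregate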
override def hasCastFloatTimestampUpcast: Boolean = false

override def isCastingStringToNegDecimalScaleSupported: Boolean = false
@@ -34,7 +34,6 @@ import org.apache.spark.rapids.shims.GpuShuffleExchangeExec
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.errors.attachTree
@@ -523,47 +522,6 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
(a, conf, p, r) => new RapidsOrcScanMeta(a, conf, p, r))
).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap

override def getPartitionFileNames(
partitions: Seq[PartitionDirectory]): Seq[String] = {
val files = partitions.flatMap(partition => partition.files)
files.map(_.getPath.getName)
}

override def getPartitionFileStatusSize(partitions: Seq[PartitionDirectory]): Long = {
partitions.map(_.files.map(_.getLen).sum).sum
}

override def getPartitionedFiles(
partitions: Array[PartitionDirectory]): Array[PartitionedFile] = {
partitions.flatMap { p =>
p.files.map { f =>
PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values)
}
}
}

override def getPartitionSplitFiles(
partitions: Array[PartitionDirectory],
maxSplitBytes: Long,
relation: HadoopFsRelation): Array[PartitionedFile] = {
partitions.flatMap { partition =>
partition.files.flatMap { file =>
// getPath() is very expensive so we only want to call it once in this block:
val filePath = file.getPath
val isSplitable = relation.fileFormat.isSplitable(
relation.sparkSession, relation.options, filePath)
PartitionedFileUtil.splitFiles(
sparkSession = relation.sparkSession,
file = file,
filePath = filePath,
isSplitable = isSplitable,
maxSplitBytes = maxSplitBytes,
partitionValues = partition.values
)
}
}
}

override def getFileScanRDD(
sparkSession: SparkSession,
readFunction: PartitionedFile => Iterator[InternalRow],
@@ -573,29 +531,6 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
new GpuFileScanRDD(sparkSession, readFunction, filePartitions)
}

override def createFilePartition(index: Int, files: Array[PartitionedFile]): FilePartition = {
FilePartition(index, files)
}

override def copyBatchScanExec(
batchScanExec: GpuBatchScanExec,
queryUsesInputFile: Boolean): GpuBatchScanExec = {
val scanCopy = batchScanExec.scan match {
case parquetScan: GpuParquetScan =>
parquetScan.copy(queryUsesInputFile=queryUsesInputFile)
case orcScan: GpuOrcScan =>
orcScan.copy(queryUsesInputFile=queryUsesInputFile)
case _ => throw new RuntimeException("Wrong format") // never reach here
}
batchScanExec.copy(scan=scanCopy)
}

override def copyFileSourceScanExec(
scanExec: GpuFileSourceScanExec,
queryUsesInputFile: Boolean): GpuFileSourceScanExec = {
scanExec.copy(queryUsesInputFile=queryUsesInputFile)(scanExec.rapidsConf)
}

override def getGpuColumnarToRowTransition(plan: SparkPlan,
exportColumnRdd: Boolean): GpuColumnarToRowExecParent = {
val serName = plan.conf.getConf(StaticSQLConf.SPARK_CACHE_SERIALIZER)
@@ -607,13 +542,6 @@ }
}
}

override def checkColumnNameDuplication(
schema: StructType,
colType: String,
resolver: Resolver): Unit = {
GpuSchemaUtils.checkColumnNameDuplication(schema, colType, resolver)
}

override def getGpuShuffleExchangeExec(
gpuOutputPartitioning: GpuPartitioning,
child: SparkPlan,
@@ -637,13 +565,6 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
s.copy(child = child)
}

override def alias(child: Expression, name: String)(
exprId: ExprId,
qualifier: Seq[String],
explicitMetadata: Option[Metadata]): Alias = {
Alias(child, name)(exprId, qualifier, explicitMetadata)
}

override def shouldIgnorePath(path: String): Boolean = {
HadoopFSUtilsShim.shouldIgnorePath(path)
}
@@ -817,9 +738,6 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
adaptivePlan.inputPlan
}

override def getLegacyStatisticalAggregate(): Boolean =
SQLConf.get.legacyStatisticalAggregate

override def supportsColumnarAdaptivePlans: Boolean = false

override def columnarAdaptivePlan(a: AdaptiveSparkPlanExec, goal: CoalesceSizeGoal): SparkPlan = {