Fix ShimVectorizedColumnReader construction for recent Spark 3.3.0 changes #5124

Merged: 4 commits, Apr 4, 2022
pom.xml (7 changes: 5 additions & 2 deletions)
@@ -464,6 +464,7 @@
<source>${project.basedir}/src/main/320+-nondb/scala</source>
<source>${project.basedir}/src/main/320until330-all/scala</source>
<source>${project.basedir}/src/main/321+/scala</source>
<source>${project.basedir}/src/main/321until330-all/scala</source>
<source>${project.basedir}/src/main/post320-treenode/scala</source>
</sources>
</configuration>
@@ -524,8 +525,9 @@
<source>${project.basedir}/src/main/311until330-nondb/scala</source>
<source>${project.basedir}/src/main/320+/scala</source>
<source>${project.basedir}/src/main/320+-nondb/scala</source>
<source>${project.basedir}/src/main/321+/scala</source>
<source>${project.basedir}/src/main/320until330-all/scala</source>
<source>${project.basedir}/src/main/321+/scala</source>
<source>${project.basedir}/src/main/321until330-all/scala</source>
<source>${project.basedir}/src/main/post320-treenode/scala</source>
</sources>
</configuration>
@@ -597,8 +599,9 @@
<source>${project.basedir}/src/main/311until330-all/scala</source>
<source>${project.basedir}/src/main/311+-db/scala</source>
<source>${project.basedir}/src/main/320+/scala</source>
<source>${project.basedir}/src/main/321+/scala</source>
<source>${project.basedir}/src/main/320until330-all/scala</source>
<source>${project.basedir}/src/main/321+/scala</source>
<source>${project.basedir}/src/main/321until330-all/scala</source>
<source>${project.basedir}/src/main/post320-treenode/scala</source>
</sources>
</configuration>
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet.rapids.shims

import java.time.ZoneId

import org.apache.parquet.VersionParser.ParsedVersion
import org.apache.parquet.column.ColumnDescriptor
import org.apache.parquet.column.page.PageReadStore
import org.apache.parquet.schema.{GroupType, Type}
@@ -53,7 +54,8 @@ class ShimVectorizedColumnReader(
convertTz: ZoneId,
datetimeRebaseMode: String,
int96RebaseMode: String,
int96CDPHive3Compatibility: Boolean
int96CDPHive3Compatibility: Boolean,
writerVersion: ParsedVersion
) extends VectorizedColumnReader(
columns.get(index),
types.get(index).getOriginalType,
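The same one-line signature change is repeated for every per-Spark-version copy of ShimVectorizedColumnReader in this PR: each shim now accepts a writerVersion: ParsedVersion argument so the shared caller can pass it unconditionally. Below is a minimal, hypothetical sketch of the pattern; the class names and stand-in parent types are invented, and the assumption (based on this diff) is that pre-3.3.0 shims accept the argument without forwarding it, since only Spark 3.3.0's VectorizedColumnReader constructor takes a writer version.

```scala
import org.apache.parquet.VersionParser.ParsedVersion

// Stand-ins for Spark's VectorizedColumnReader before and after 3.3.0 (illustrative only).
class LegacyParentReader                                // pre-3.3.0: no writerVersion parameter
class NewParentReader(val writerVersion: ParsedVersion) // 3.3.0+: accepts the writer version

// Both shims expose the same constructor, so callers never branch on the Spark version:
// the pre-3.3.0 shim simply drops the argument, the 3.3.0+ shim forwards it to its parent.
class ShimReaderPre330(writerVersion: ParsedVersion) extends LegacyParentReader
class ShimReader330Plus(writerVersion: ParsedVersion) extends NewParentReader(writerVersion)
```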
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet.rapids.shims

import java.time.ZoneId

import org.apache.parquet.VersionParser.ParsedVersion
import org.apache.parquet.column.ColumnDescriptor
import org.apache.parquet.column.page.PageReadStore
import org.apache.parquet.schema.{GroupType, Type}
@@ -52,7 +53,8 @@ class ShimVectorizedColumnReader(
convertTz: ZoneId,
datetimeRebaseMode: String,
int96RebaseMode: String,
int96CDPHive3Compatibility: Boolean
int96CDPHive3Compatibility: Boolean,
writerVersion: ParsedVersion
) extends VectorizedColumnReader(
columns.get(index),
types.get(index).getOriginalType,
@@ -16,14 +16,19 @@

package com.nvidia.spark.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
// Follow the Spark string format before 3.3.0
new ArrayIndexOutOfBoundsException(s"Invalid index: $index, numElements: $numElements")
}

def mapKeyNotExistError(key: String, isElementAtF: Boolean = false): NoSuchElementException = {
def mapKeyNotExistError(
key: String,
isElementAtFunction: Boolean = false,
origin: Origin): NoSuchElementException = {
// Follow the Spark string format before 3.3.0
new NoSuchElementException(s"Key $key does not exist.")
}
@@ -16,14 +16,19 @@

package com.nvidia.spark.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
// Follow the Spark string format before 3.3.0
new ArrayIndexOutOfBoundsException(s"Invalid index: $index, numElements: $numElements")
}

def mapKeyNotExistError(key: String, isElementAtF: Boolean = false): NoSuchElementException = {
def mapKeyNotExistError(
key: String,
isElementAtFunction: Boolean,
origin: Origin): NoSuchElementException = {
// Follow the Spark string format before 3.3.0
new NoSuchElementException(s"Key $key does not exist.")
}
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet.rapids.shims

import java.time.ZoneId

import org.apache.parquet.VersionParser.ParsedVersion
import org.apache.parquet.column.ColumnDescriptor
import org.apache.parquet.column.page.PageReadStore
import org.apache.parquet.schema.{GroupType, Type}
@@ -52,7 +53,8 @@ class ShimVectorizedColumnReader(
convertTz: ZoneId,
datetimeRebaseMode: String,
int96RebaseMode: String,
int96CDPHive3Compatibility: Boolean
int96CDPHive3Compatibility: Boolean,
writerVersion: ParsedVersion
) extends VectorizedColumnReader(
columns.get(index),
types.get(index).getLogicalTypeAnnotation,
@@ -16,6 +16,7 @@

package com.nvidia.spark.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors

object RapidsErrorUtils {
@@ -28,7 +29,10 @@ object RapidsErrorUtils {
}
}

def mapKeyNotExistError(key: String, isElementAtF: Boolean = false): NoSuchElementException = {
def mapKeyNotExistError(
key: String,
isElementAtFunction: Boolean,
origin: Origin): NoSuchElementException = {
// For now, the default argument is false. The caller sets the correct value accordingly.
QueryExecutionErrors.mapKeyNotExistError(key)
}
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet.rapids.shims
import java.time.ZoneId
import java.util.TimeZone

import org.apache.parquet.VersionParser.ParsedVersion
import org.apache.parquet.column.ColumnDescriptor
import org.apache.parquet.column.page.PageReadStore
import org.apache.parquet.schema.{GroupType, Type}
@@ -54,7 +55,8 @@ class ShimVectorizedColumnReader(
convertTz: ZoneId,
datetimeRebaseMode: String, // always LegacyBehaviorPolicy.CORRECTED
int96RebaseMode: String, // always LegacyBehaviorPolicy.EXCEPTION
int96CDPHive3Compatibility: Boolean
int96CDPHive3Compatibility: Boolean,
writerVersion: ParsedVersion
) extends VectorizedColumnReader(
columns.get(index),
types.get(index).getLogicalTypeAnnotation,
@@ -16,6 +16,7 @@

package com.nvidia.spark.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors

object RapidsErrorUtils {
@@ -28,9 +29,12 @@ object RapidsErrorUtils {
}
}

def mapKeyNotExistError(key: String, isElementAtF: Boolean = false): NoSuchElementException = {
def mapKeyNotExistError(
key: String,
isElementAtFunction: Boolean,
origin: Origin): NoSuchElementException = {
// For now, the default argument is false. The caller sets the correct value accordingly.
QueryExecutionErrors.mapKeyNotExistError(key, isElementAtF)
QueryExecutionErrors.mapKeyNotExistError(key, isElementAtFunction, origin.context)
}

def sqlArrayIndexNotStartAtOneError(): ArrayIndexOutOfBoundsException = {
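All of the RapidsErrorUtils shims now share the three-argument mapKeyNotExistError signature; the variant directly above (presumably the Spark 3.3.0 shim) forwards origin.context to QueryExecutionErrors so the error message can carry the SQL source context, while the other shims accept the Origin and ignore it. Below is a hedged sketch of a call site with a made-up helper name; the PR's real callers, GpuElementAt and GpuGetMapValue, appear at the end of this diff.

```scala
import com.nvidia.spark.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}

object MapKeyErrorExample {
  // Hypothetical helper: a Catalyst expression already carries an Origin, so the
  // GPU expressions shown later in this diff simply pass their own `origin` field.
  def failOnMissingKey(key: String, origin: Origin): Nothing =
    throw RapidsErrorUtils.mapKeyNotExistError(key, isElementAtFunction = false, origin = origin)

  // Outside an expression, the thread-local current origin could be used instead:
  // failOnMissingKey("someKey", CurrentOrigin.get)
}
```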
@@ -0,0 +1,69 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.parquet.rapids.shims

import java.time.ZoneId
import java.util.TimeZone

import org.apache.parquet.VersionParser.ParsedVersion
import org.apache.parquet.column.ColumnDescriptor
import org.apache.parquet.column.page.PageReadStore
import org.apache.parquet.schema.{GroupType, Type}

import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.execution.datasources.parquet.{ParentContainerUpdater, ParquetRowConverter, ParquetToSparkSchemaConverter, VectorizedColumnReader}
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.types.StructType

class ShimParquetRowConverter(
schemaConverter: ParquetToSparkSchemaConverter,
parquetType: GroupType,
catalystType: StructType,
convertTz: Option[ZoneId],
datetimeRebaseMode: LegacyBehaviorPolicy.Value, // always LegacyBehaviorPolicy.CORRECTED
int96RebaseMode: LegacyBehaviorPolicy.Value, // always LegacyBehaviorPolicy.EXCEPTION
int96CDPHive3Compatibility: Boolean,
updater: ParentContainerUpdater
) extends ParquetRowConverter(
schemaConverter,
parquetType,
catalystType,
convertTz,
RebaseSpec(datetimeRebaseMode), // no need to rebase, so set originTimeZone as default
RebaseSpec(int96RebaseMode), // no need to rebase, so set originTimeZone as default
updater)

class ShimVectorizedColumnReader(
index: Int,
columns: java.util.List[ColumnDescriptor],
types: java.util.List[Type],
pageReadStore: PageReadStore,
convertTz: ZoneId,
datetimeRebaseMode: String, // always LegacyBehaviorPolicy.CORRECTED
int96RebaseMode: String, // always LegacyBehaviorPolicy.EXCEPTION
int96CDPHive3Compatibility: Boolean,
writerVersion: ParsedVersion
) extends VectorizedColumnReader(
columns.get(index),
true,
pageReadStore,
convertTz,
datetimeRebaseMode,
TimeZone.getDefault.getID, // use the default zone because no rebase is needed
int96RebaseMode,
TimeZone.getDefault.getID, // use the default zone; a rebase would throw an exception anyway
writerVersion)
@@ -34,7 +34,8 @@ import com.nvidia.spark.rapids.shims.{ParquetFieldIdShims, SparkShimImpl}
import org.apache.commons.io.output.ByteArrayOutputStream
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.RecordWriter
import org.apache.parquet.{HadoopReadOptions, ParquetReadOptions}
import org.apache.parquet.{HadoopReadOptions, ParquetReadOptions, VersionParser}
import org.apache.parquet.VersionParser.ParsedVersion
import org.apache.parquet.column.{ColumnDescriptor, ParquetProperties}
import org.apache.parquet.hadoop.{CodecFactory, MemoryManager, ParquetFileReader, ParquetFileWriter, ParquetInputFormat, ParquetOutputFormat, ParquetRecordWriter, ParquetWriter}
import org.apache.parquet.hadoop.ParquetFileWriter.Mode
@@ -860,8 +861,17 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
var numBatched = 0
var batchIdx = 0
var totalCountLoadedSoFar: Long = 0
val parquetFileReader =
val parquetFileReader = {
ParquetFileReader.open(new ByteArrayInputFile(parquetCachedBatch.buffer), options)
}
val writerVersion: ParsedVersion = try {
VersionParser.parse(parquetFileReader.getFileMetaData.getCreatedBy)
} catch {
case _: Exception =>
// If any problems occur trying to parse the writer version, fallback to sequential reads
// if the column is a delta byte array encoding (due to PARQUET-246).
null
}
val (totalRowCount, columnsRequested, cacheSchemaToReqSchemaMap, missingColumns,
columnsInCache, typesInCache) = {
val parquetToSparkSchemaConverter = new ParquetToSparkSchemaConverter(hadoopConf)
@@ -954,9 +964,11 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
columnsInCache,
typesInCache,
pages,
null /*convertTz*/ ,
convertTz = null,
LegacyBehaviorPolicy.CORRECTED.toString,
LegacyBehaviorPolicy.EXCEPTION.toString, false)
LegacyBehaviorPolicy.EXCEPTION.toString,
int96CDPHive3Compatibility = false,
writerVersion)
}
}
totalCountLoadedSoFar += pages.getRowCount
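The writer version parsed above comes from the Parquet footer's created_by field. Below is a hedged, standalone sketch of that lookup; the sample created_by string is invented, and real values follow parquet-mr's "<application> version <semver> (build <hash>)" convention.

```scala
import org.apache.parquet.VersionParser
import org.apache.parquet.VersionParser.ParsedVersion

object WriterVersionExample {
  // Illustrative created_by value; the real one is read from the file footer,
  // e.g. parquetFileReader.getFileMetaData.getCreatedBy as in the change above.
  val createdBy = "parquet-mr version 1.12.2 (build 0000000000000000000000000000000000000000)"

  // VersionParser.parse throws if the string is malformed; the serializer above
  // falls back to null, which (per its comment) means sequential reads for
  // DELTA_BYTE_ARRAY columns to avoid PARQUET-246.
  val writerVersion: ParsedVersion =
    try VersionParser.parse(createdBy)
    catch { case _: Exception => null }
}
```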
@@ -231,7 +231,8 @@ case class GpuElementAt(left: Expression, right: Expression, failOnError: Boolea
if (!exist.isValid || exist.getBoolean) {
map.getMapValue(key)
} else {
throw RapidsErrorUtils.mapKeyNotExistError(keyS.getValue.toString, true)
throw RapidsErrorUtils.mapKeyNotExistError(keyS.getValue.toString,
isElementAtFunction = true, origin)
}
}
}
@@ -212,7 +212,8 @@ case class GpuGetMapValue(child: Expression, key: Expression, failOnError: Boole
withResource(lhs.getBase.getMapKeyExistence(rhs.getBase)) { keyExistenceColumn =>
withResource(keyExistenceColumn.all) { exist =>
if (exist.isValid && !exist.getBoolean) {
throw RapidsErrorUtils.mapKeyNotExistError(rhs.getValue.toString)
throw RapidsErrorUtils.mapKeyNotExistError(rhs.getValue.toString,
isElementAtFunction = false, origin)
}
}
}