Backport the fix to Spark 3.4
Kontinuation committed Oct 30, 2024
1 parent d182a04 commit be98052
Showing 7 changed files with 101 additions and 21 deletions.
@@ -202,6 +202,9 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter])
hadoopConf.setBoolean(
SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
sparkSession.sessionState.conf.isParquetINT96AsTimestamp)
hadoopConf.setBoolean(
SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key,
sparkSession.sessionState.conf.parquetInferTimestampNTZEnabled)

val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
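The added setBoolean call mirrors the surrounding INT96 handling: it copies the session's PARQUET_INFER_TIMESTAMP_NTZ_ENABLED setting into the Hadoop configuration that is broadcast below, so executors resolve the flag the same way the driver does. A minimal sketch of how a reader task could recover the flag from that broadcast configuration (illustrative only; the consuming code sits outside this hunk):

// Illustrative sketch (not part of this commit): reading the flag back on the executor side.
val executorHadoopConf = broadcastedHadoopConf.value.value // SerializableConfiguration -> Configuration
val inferTimestampNTZ = executorHadoopConf.getBoolean(
  SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key,
  SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get)
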
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.column.Dictionary
import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit
import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
import org.apache.parquet.schema.OriginalType.LIST
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
import org.apache.parquet.schema.{GroupType, OriginalType, Type}
@@ -312,6 +314,25 @@ private[parquet] class GeoParquetRowConverter(
}
}

case TimestampNTZType
if canReadAsTimestampNTZ(parquetType) &&
parquetType.getLogicalTypeAnnotation
.asInstanceOf[TimestampLogicalTypeAnnotation]
.getUnit == TimeUnit.MICROS =>
new ParquetPrimitiveConverter(updater)

case TimestampNTZType
if canReadAsTimestampNTZ(parquetType) &&
parquetType.getLogicalTypeAnnotation
.asInstanceOf[TimestampLogicalTypeAnnotation]
.getUnit == TimeUnit.MILLIS =>
new ParquetPrimitiveConverter(updater) {
override def addLong(value: Long): Unit = {
val micros = DateTimeUtils.millisToMicros(value)
updater.setLong(micros)
}
}

case DateType =>
new ParquetPrimitiveConverter(updater) {
override def addInt(value: Int): Unit = {
@@ -379,6 +400,17 @@ private[parquet] class GeoParquetRowConverter(
}
}

// Only an INT64 column with a Timestamp logical annotation where `isAdjustedToUTC=false`
// can be read as Spark's TimestampNTZ type. This guards against misreading timestamp
// values that carry a different annotation.
private def canReadAsTimestampNTZ(parquetType: Type): Boolean =
schemaConverter.isTimestampNTZEnabled() &&
parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 &&
parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] &&
!parquetType.getLogicalTypeAnnotation
.asInstanceOf[TimestampLogicalTypeAnnotation]
.isAdjustedToUTC

/**
* Parquet converter for strings. A dictionary is used to minimize string decoding cost.
*/
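canReadAsTimestampNTZ above accepts only INT64 columns whose Timestamp annotation has isAdjustedToUTC = false, and the two new converter cases cover the MICROS and MILLIS units. A sketch of a qualifying Parquet field built with parquet-mr's Types builder (illustrative; the field name and builder usage are assumptions, not part of the diff):

import org.apache.parquet.schema.{LogicalTypeAnnotation, Types}
import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64

// isAdjustedToUTC = false marks a "local" (zone-less) timestamp, which maps to TimestampNTZType.
val ntzField = Types
  .required(INT64)
  .as(LogicalTypeAnnotation.timestampType(false, TimeUnit.MICROS))
  .named("event_time")
// A MILLIS-annotated column qualifies too; its values are widened by DateTimeUtils.millisToMicros
// (e.g. 1730246096123L millis becomes 1730246096123000L micros) before being set on the row.
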
@@ -42,13 +42,16 @@ import org.apache.spark.sql.types._
* Whether unannotated BINARY fields should be assumed to be Spark SQL [[StringType]] fields.
* @param assumeInt96IsTimestamp
* Whether unannotated INT96 fields should be assumed to be Spark SQL [[TimestampType]] fields.
* @param inferTimestampNTZ
* Whether the TimestampNTZ type is enabled.
* @param parameters
* Options for reading GeoParquet files.
*/
class GeoParquetToSparkSchemaConverter(
keyValueMetaData: java.util.Map[String, String],
assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get,
inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
parameters: Map[String, String]) {

private val geoParquetMetaData: GeoParquetMetaData =
@@ -61,6 +64,7 @@ class GeoParquetToSparkSchemaConverter(
keyValueMetaData = keyValueMetaData,
assumeBinaryIsString = conf.isParquetBinaryAsString,
assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp,
inferTimestampNTZ = conf.parquetInferTimestampNTZEnabled,
parameters = parameters)

def this(
@@ -70,8 +74,16 @@ class GeoParquetToSparkSchemaConverter(
keyValueMetaData = keyValueMetaData,
assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
parameters = parameters)

/**
* Returns true if the TIMESTAMP_NTZ type is enabled in this GeoParquetToSparkSchemaConverter.
*/
def isTimestampNTZEnabled(): Boolean = {
inferTimestampNTZ
}

/**
* Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
*/
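With the extra constructor parameter, callers that build the converter from a SQLConf or a Hadoop Configuration pick up the inference flag automatically. A hedged sketch of constructing it by hand with the new flag (the footer variable and the convert call are assumptions for illustration, not taken from this commit):

// Hypothetical wiring: `footer` stands for a parquet-hadoop ParquetMetadata instance.
val converter = new GeoParquetToSparkSchemaConverter(
  keyValueMetaData = footer.getFileMetaData.getKeyValueMetaData,
  assumeBinaryIsString = false,
  assumeInt96IsTimestamp = true,
  inferTimestampNTZ = true, // the new knob added in this change
  parameters = Map.empty)
// Assuming the converter keeps a convert(MessageType): StructType entry point,
// TIMESTAMP(isAdjustedToUTC = false) columns now surface as TimestampNTZType.
val sparkSchema = converter.convert(footer.getFileMetaData.getSchema)
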
@@ -308,6 +308,11 @@ class GeoParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
recordConsumer.addLong(millis)
}

case TimestampNTZType =>
// For a TimestampNTZType column, Spark always writes INT64 with a Timestamp annotation in
// the MICROS time unit.
(row: SpecializedGetters, ordinal: Int) => recordConsumer.addLong(row.getLong(ordinal))

case BinaryType =>
(row: SpecializedGetters, ordinal: Int) =>
recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal)))
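The write path can pass the raw Long straight through because Spark's internal representation of a TimestampNTZ value is already microseconds since the epoch, interpreted without a time zone. A small illustration of that representation (assuming DateTimeUtils.localDateTimeToMicros is available in this Spark version):

import java.time.LocalDateTime
import org.apache.spark.sql.catalyst.util.DateTimeUtils

// LocalDateTime -> microseconds since 1970-01-01T00:00:00, with no time-zone shift applied.
val micros: Long = DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("2024-10-04T12:34:56"))
// This Long is exactly what row.getLong(ordinal) returns above and what
// recordConsumer.addLong writes into the INT64 TIMESTAMP(MICROS) column.
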
@@ -32,15 +32,15 @@ import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.expressions.st_constructors.{ST_Point, ST_PolygonFromEnvelope}
import org.apache.spark.sql.sedona_sql.expressions.st_predicates.ST_Intersects
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampNTZType}
import org.json4s.jackson.parseJson
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.io.WKTReader
import org.scalatest.BeforeAndAfterAll

import java.io.File
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import java.util.Collections
import java.util.concurrent.atomic.AtomicLong
import scala.collection.JavaConverters._
@@ -732,6 +732,32 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
}
}

describe("Spark types tests") {
it("should support timestamp_ntz") {
// Write geoparquet files with a TimestampNTZ column
val schema = StructType(
Seq(
StructField("id", IntegerType, nullable = false),
StructField("timestamp_ntz", TimestampNTZType, nullable = false)))
val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
val data = Seq(
Row(1, LocalDateTime.parse("2024-10-04 12:34:56", formatter)),
Row(2, LocalDateTime.parse("2024-10-04 15:30:00", formatter)))
val df = sparkSession
.createDataFrame(sparkSession.sparkContext.parallelize(data), schema)
.withColumn("geom", expr("ST_Point(id, id)"))
df.write.format("geoparquet").save(geoparquetoutputlocation)

// Read it back
val df2 =
sparkSession.read.format("geoparquet").load(geoparquetoutputlocation).sort(col("id"))
assert(df2.schema.fields(1).dataType == TimestampNTZType)
val data1 = df.sort(col("id")).collect()
val data2 = df2.collect()
assert(data1 sameElements data2)
}
}

def validateGeoParquetMetadata(path: String)(body: org.json4s.JValue => Unit): Unit = {
val parquetFiles = new File(path).listFiles().filter(_.getName.endsWith(".parquet"))
parquetFiles.foreach { filePath =>
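The round trip above relies on NTZ inference being enabled when the file is read back. As a hedged companion check (not part of the test; the config key is assumed from SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED), disabling inference should make the same column come back as a plain TimestampType:

sparkSession.conf.set("spark.sql.parquet.inferTimestampNTZ.enabled", "false")
val dfLtz = sparkSession.read.format("geoparquet").load(geoparquetoutputlocation)
// Expected fallback when inference is off: the NTZ column is read as TimestampType.
assert(dfLtz.schema("timestamp_ntz").dataType == org.apache.spark.sql.types.TimestampType)
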
@@ -318,15 +318,17 @@ private[parquet] class GeoParquetRowConverter(
}

case TimestampNTZType
if canReadAsTimestampNTZ(parquetType) &&
parquetType.getLogicalTypeAnnotation
.asInstanceOf[TimestampLogicalTypeAnnotation].getUnit == TimeUnit.MICROS =>
if canReadAsTimestampNTZ(parquetType) &&
parquetType.getLogicalTypeAnnotation
.asInstanceOf[TimestampLogicalTypeAnnotation]
.getUnit == TimeUnit.MICROS =>
new ParquetPrimitiveConverter(updater)

case TimestampNTZType
if canReadAsTimestampNTZ(parquetType) &&
parquetType.getLogicalTypeAnnotation
.asInstanceOf[TimestampLogicalTypeAnnotation].getUnit == TimeUnit.MILLIS =>
if canReadAsTimestampNTZ(parquetType) &&
parquetType.getLogicalTypeAnnotation
.asInstanceOf[TimestampLogicalTypeAnnotation]
.getUnit == TimeUnit.MILLIS =>
new ParquetPrimitiveConverter(updater) {
override def addLong(value: Long): Unit = {
val micros = DateTimeUtils.millisToMicros(value)
@@ -409,7 +411,8 @@ private[parquet] class GeoParquetRowConverter(
parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 &&
parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] &&
!parquetType.getLogicalTypeAnnotation
.asInstanceOf[TimestampLogicalTypeAnnotation].isAdjustedToUTC
.asInstanceOf[TimestampLogicalTypeAnnotation]
.isAdjustedToUTC

/**
* Parquet converter for strings. A dictionary is used to minimize string decoding cost.
@@ -738,23 +738,22 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
describe("Spark types tests") {
it("should support timestamp_ntz") {
// Write geoparquet files with a TimestampNTZ column
val schema = StructType(Seq(
StructField("id", IntegerType, nullable = false),
StructField("timestamp_ntz", TimestampNTZType, nullable = false)
))
val schema = StructType(
Seq(
StructField("id", IntegerType, nullable = false),
StructField("timestamp_ntz", TimestampNTZType, nullable = false)))
val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
val data = Seq(
Row(1, LocalDateTime.parse("2024-10-04 12:34:56", formatter)),
Row(2, LocalDateTime.parse("2024-10-04 15:30:00", formatter))
)
val df = sparkSession.createDataFrame(
sparkSession.sparkContext.parallelize(data),
schema
).withColumn("geom", expr("ST_Point(id, id)"))
Row(2, LocalDateTime.parse("2024-10-04 15:30:00", formatter)))
val df = sparkSession
.createDataFrame(sparkSession.sparkContext.parallelize(data), schema)
.withColumn("geom", expr("ST_Point(id, id)"))
df.write.format("geoparquet").save(geoparquetoutputlocation)

// Read it back
val df2 = sparkSession.read.format("geoparquet").load(geoparquetoutputlocation).sort(col("id"))
val df2 =
sparkSession.read.format("geoparquet").load(geoparquetoutputlocation).sort(col("id"))
assert(df2.schema.fields(1).dataType == TimestampNTZType)
val data1 = df.sort(col("id")).collect()
val data2 = df2.collect()
