From d182a043d20d62ae4eaa54c30d98d538fdcb44c5 Mon Sep 17 00:00:00 2001
From: Kristin Cowlacijk
Date: Wed, 30 Oct 2024 10:34:15 +0800
Subject: [PATCH] Fix timestamp_ntz for geoparquet format

---
 .../parquet/GeoParquetRowConverter.scala      | 29 ++++++++++++++++++
 .../parquet/GeoParquetSchemaConverter.scala   | 12 ++++++++
 .../parquet/GeoParquetWriteSupport.scala      |  5 ++++
 .../apache/sedona/sql/geoparquetIOTests.scala | 30 +++++++++++++++++++
 4 files changed, 76 insertions(+)

diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
index 07fc77e2c6..97319d9daf 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.parquet
 
 import org.apache.parquet.column.Dictionary
 import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
 import org.apache.parquet.schema.OriginalType.LIST
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
 import org.apache.parquet.schema.{GroupType, OriginalType, Type}
@@ -315,6 +317,23 @@ private[parquet] class GeoParquetRowConverter(
           }
         }
 
+      case TimestampNTZType
+          if canReadAsTimestampNTZ(parquetType) &&
+            parquetType.getLogicalTypeAnnotation
+              .asInstanceOf[TimestampLogicalTypeAnnotation].getUnit == TimeUnit.MICROS =>
+        new ParquetPrimitiveConverter(updater)
+
+      case TimestampNTZType
+          if canReadAsTimestampNTZ(parquetType) &&
+            parquetType.getLogicalTypeAnnotation
+              .asInstanceOf[TimestampLogicalTypeAnnotation].getUnit == TimeUnit.MILLIS =>
+        new ParquetPrimitiveConverter(updater) {
+          override def addLong(value: Long): Unit = {
+            val micros = DateTimeUtils.millisToMicros(value)
+            updater.setLong(micros)
+          }
+        }
+
       case DateType =>
         new ParquetPrimitiveConverter(updater) {
           override def addInt(value: Int): Unit = {
@@ -382,6 +401,16 @@ private[parquet] class GeoParquetRowConverter(
     }
   }
 
+  // Only INT64 columns with a Timestamp logical annotation and `isAdjustedToUTC=false`
+  // can be read as Spark's TimestampNTZ type. This is to avoid mistakes in reading the
+  // timestamp values.
+  private def canReadAsTimestampNTZ(parquetType: Type): Boolean =
+    schemaConverter.isTimestampNTZEnabled() &&
+      parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 &&
+      parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] &&
+      !parquetType.getLogicalTypeAnnotation
+        .asInstanceOf[TimestampLogicalTypeAnnotation].isAdjustedToUTC
+
   /**
    * Parquet converter for strings. A dictionary is used to minimize string decoding cost.
   */
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
index eab20875a6..10dd9e01d2 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
@@ -42,6 +42,8 @@ import org.apache.spark.sql.types._
  *   Whether unannotated BINARY fields should be assumed to be Spark SQL [[StringType]] fields.
  * @param assumeInt96IsTimestamp
  *   Whether unannotated INT96 fields should be assumed to be Spark SQL [[TimestampType]] fields.
+ * @param inferTimestampNTZ
+ *   Whether the TimestampNTZ type is enabled.
  * @param parameters
  *   Options for reading GeoParquet files.
  */
@@ -49,6 +51,7 @@ class GeoParquetToSparkSchemaConverter(
     keyValueMetaData: java.util.Map[String, String],
     assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
     assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get,
+    inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
     parameters: Map[String, String]) {
 
   private val geoParquetMetaData: GeoParquetMetaData =
@@ -61,6 +64,7 @@ class GeoParquetToSparkSchemaConverter(
       keyValueMetaData = keyValueMetaData,
       assumeBinaryIsString = conf.isParquetBinaryAsString,
       assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp,
+      inferTimestampNTZ = conf.parquetInferTimestampNTZEnabled,
       parameters = parameters)
 
   def this(
@@ -70,8 +74,16 @@ class GeoParquetToSparkSchemaConverter(
       keyValueMetaData = keyValueMetaData,
       assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
       assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
+      inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
       parameters = parameters)
 
+  /**
+   * Returns true if the TIMESTAMP_NTZ type is enabled in this GeoParquetToSparkSchemaConverter.
+   */
+  def isTimestampNTZEnabled(): Boolean = {
+    inferTimestampNTZ
+  }
+
   /**
    * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
    */
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
index fb5c921636..18f9f4f5c2 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
@@ -309,6 +309,11 @@ class GeoParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
           recordConsumer.addLong(millis)
         }
 
+      case TimestampNTZType =>
+        // For TimestampNTZType columns, Spark always outputs INT64 with a Timestamp annotation
+        // in the MICROS time unit.
+        (row: SpecializedGetters, ordinal: Int) => recordConsumer.addLong(row.getLong(ordinal))
+
       case BinaryType =>
         (row: SpecializedGetters, ordinal: Int) =>
           recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal)))
diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
index ccfd560c84..a8c16c127b 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.sedona_sql.expressions.st_predicates.ST_Intersects
 import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.sql.types.StructField
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.TimestampNTZType
 import org.json4s.jackson.parseJson
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.io.WKTReader
@@ -43,6 +44,8 @@ import org.scalatest.BeforeAndAfterAll
 import java.io.File
 import java.util.Collections
 import java.util.concurrent.atomic.AtomicLong
+import java.time.LocalDateTime
+import java.time.format.DateTimeFormatter
 import scala.collection.JavaConverters._
 
 class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
@@ -732,6 +735,33 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
     }
   }
 
+  describe("Spark types tests") {
+    it("should support timestamp_ntz") {
+      // Write geoparquet files with a TimestampNTZ column
+      val schema = StructType(Seq(
+        StructField("id", IntegerType, nullable = false),
+        StructField("timestamp_ntz", TimestampNTZType, nullable = false)
+      ))
+      val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
+      val data = Seq(
+        Row(1, LocalDateTime.parse("2024-10-04 12:34:56", formatter)),
+        Row(2, LocalDateTime.parse("2024-10-04 15:30:00", formatter))
+      )
+      val df = sparkSession.createDataFrame(
+        sparkSession.sparkContext.parallelize(data),
+        schema
+      ).withColumn("geom", expr("ST_Point(id, id)"))
+      df.write.format("geoparquet").save(geoparquetoutputlocation)
+
+      // Read it back
+      val df2 = sparkSession.read.format("geoparquet").load(geoparquetoutputlocation).sort(col("id"))
+      assert(df2.schema.fields(1).dataType == TimestampNTZType)
+      val data1 = df.sort(col("id")).collect()
+      val data2 = df2.collect()
+      assert(data1 sameElements data2)
+    }
+  }
+
   def validateGeoParquetMetadata(path: String)(body: org.json4s.JValue => Unit): Unit = {
     val parquetFiles = new File(path).listFiles().filter(_.getName.endsWith(".parquet"))
     parquetFiles.foreach { filePath =>
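
For context, a minimal usage sketch of what this patch enables: a TIMESTAMP_NTZ column
round-tripping through the "geoparquet" data source, mirroring the new test above. It assumes
the SedonaContext entry point (Sedona 1.4.1+), a local Spark 3.5 session, and a /tmp output
path; those are illustrative choices, not taken from this patch.

    import java.time.LocalDateTime

    import org.apache.sedona.spark.SedonaContext
    import org.apache.spark.sql.functions.expr
    import org.apache.spark.sql.types.TimestampNTZType

    // Illustrative local session; adjust the master URL and packages to your environment.
    val spark = SedonaContext.create(SedonaContext.builder().master("local[*]").getOrCreate())
    import spark.implicits._

    // LocalDateTime values are encoded as Spark's TimestampNTZType (no time zone).
    val df = Seq(
      (1, LocalDateTime.parse("2024-10-04T12:34:56")),
      (2, LocalDateTime.parse("2024-10-04T15:30:00")))
      .toDF("id", "timestamp_ntz")
      .withColumn("geom", expr("ST_Point(id, id)"))

    // Round-trip through the GeoParquet data source; the column keeps its NTZ type.
    df.write.format("geoparquet").mode("overwrite").save("/tmp/geoparquet_ntz")
    val df2 = spark.read.format("geoparquet").load("/tmp/geoparquet_ntz")
    assert(df2.schema("timestamp_ntz").dataType == TimestampNTZType)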