Fix timestamp_ntz for geoparquet format
Kontinuation committed Oct 30, 2024
1 parent 21c3930 commit d182a04
Showing 4 changed files with 76 additions and 0 deletions.
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.column.Dictionary
import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit
import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
import org.apache.parquet.schema.OriginalType.LIST
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
import org.apache.parquet.schema.{GroupType, OriginalType, Type}
@@ -315,6 +317,23 @@ private[parquet] class GeoParquetRowConverter(
}
}

case TimestampNTZType
    if canReadAsTimestampNTZ(parquetType) &&
      parquetType.getLogicalTypeAnnotation
        .asInstanceOf[TimestampLogicalTypeAnnotation].getUnit == TimeUnit.MICROS =>
  // INT64 micros already matches Spark's internal TimestampNTZ representation,
  // so values can be passed through unchanged.
  new ParquetPrimitiveConverter(updater)

case TimestampNTZType
    if canReadAsTimestampNTZ(parquetType) &&
      parquetType.getLogicalTypeAnnotation
        .asInstanceOf[TimestampLogicalTypeAnnotation].getUnit == TimeUnit.MILLIS =>
  new ParquetPrimitiveConverter(updater) {
    override def addLong(value: Long): Unit = {
      // Scale INT64 millis up to the micros representation Spark expects.
      val micros = DateTimeUtils.millisToMicros(value)
      updater.setLong(micros)
    }
  }

case DateType =>
new ParquetPrimitiveConverter(updater) {
override def addInt(value: Int): Unit = {
@@ -382,6 +401,16 @@ }
}
}

// Only INT64 columns annotated with a timestamp logical type where isAdjustedToUTC = false
// can be read as Spark's TimestampNTZ type. This check avoids misinterpreting timestamp values.
private def canReadAsTimestampNTZ(parquetType: Type): Boolean =
  schemaConverter.isTimestampNTZEnabled() &&
    parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 &&
    parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] &&
    !parquetType.getLogicalTypeAnnotation
      .asInstanceOf[TimestampLogicalTypeAnnotation].isAdjustedToUTC

/**
* Parquet converter for strings. A dictionary is used to minimize string decoding cost.
*/
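The eligibility rule in canReadAsTimestampNTZ above is easiest to see against a concrete Parquet schema. A minimal sketch, not part of the commit, with illustrative field names; it assumes parquet-mr's MessageTypeParser accepts the TIMESTAMP(unit, isAdjustedToUTC) annotation syntax:

import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
import org.apache.parquet.schema.MessageTypeParser

// ts_ntz qualifies for TimestampNTZ (INT64, isAdjustedToUTC = false);
// ts_ltz does not (isAdjustedToUTC = true) and maps to TimestampType instead.
val schema = MessageTypeParser.parseMessageType(
  """message spark_schema {
    |  required int64 ts_ntz (TIMESTAMP(MICROS,false));
    |  required int64 ts_ltz (TIMESTAMP(MICROS,true));
    |}""".stripMargin)

val ntzAnnotation = schema.getType("ts_ntz").getLogicalTypeAnnotation
  .asInstanceOf[TimestampLogicalTypeAnnotation]
assert(!ntzAnnotation.isAdjustedToUTC) // eligible for TimestampNTZ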
@@ -42,13 +42,16 @@ import org.apache.spark.sql.types._
* Whether unannotated BINARY fields should be assumed to be Spark SQL [[StringType]] fields.
* @param assumeInt96IsTimestamp
* Whether unannotated INT96 fields should be assumed to be Spark SQL [[TimestampType]] fields.
* @param inferTimestampNTZ
* Whether the TimestampNTZ type is enabled.
* @param parameters
* Options for reading GeoParquet files.
*/
class GeoParquetToSparkSchemaConverter(
keyValueMetaData: java.util.Map[String, String],
assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get,
inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
parameters: Map[String, String]) {

private val geoParquetMetaData: GeoParquetMetaData =
@@ -61,6 +64,7 @@ class GeoParquetToSparkSchemaConverter(
keyValueMetaData = keyValueMetaData,
assumeBinaryIsString = conf.isParquetBinaryAsString,
assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp,
inferTimestampNTZ = conf.parquetInferTimestampNTZEnabled,
parameters = parameters)

def this(
@@ -70,8 +74,16 @@ keyValueMetaData = keyValueMetaData,
keyValueMetaData = keyValueMetaData,
assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
parameters = parameters)

/**
 * Returns true if the TIMESTAMP_NTZ type is enabled in this GeoParquetToSparkSchemaConverter.
 */
def isTimestampNTZEnabled(): Boolean = {
  inferTimestampNTZ
}

/**
* Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
*/
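For reference, the inferTimestampNTZ flag added here tracks Spark's Parquet reader option; a hedged sketch of enabling it per session (key spelling taken from SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED; verify against your Spark version):

// When enabled, Parquet timestamp columns with isAdjustedToUTC = false are
// inferred as TimestampNTZ instead of TimestampType when reading files.
sparkSession.conf.set("spark.sql.parquet.inferTimestampNtz.enabled", "true")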
@@ -309,6 +309,11 @@ class GeoParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
recordConsumer.addLong(millis)
}

case TimestampNTZType =>
  // For TimestampNTZType columns, Spark always writes INT64 with a timestamp
  // annotation in MICROS time unit, so the internal long is emitted as is.
  (row: SpecializedGetters, ordinal: Int) => recordConsumer.addLong(row.getLong(ordinal))

case BinaryType =>
(row: SpecializedGetters, ordinal: Int) =>
recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal)))
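The pass-through writer above relies on Spark's internal TimestampNTZ encoding already being microseconds since the epoch with no time-zone adjustment. A minimal round-trip sketch, assuming Spark's internal DateTimeUtils helpers:

import java.time.LocalDateTime
import org.apache.spark.sql.catalyst.util.DateTimeUtils

// localDateTimeToMicros produces exactly the long the writer emits;
// microsToLocalDateTime inverts it, so no zone conversion is involved.
val ldt = LocalDateTime.of(2024, 10, 4, 12, 34, 56)
val micros = DateTimeUtils.localDateTimeToMicros(ldt)
assert(DateTimeUtils.microsToLocalDateTime(micros) == ldt)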
@@ -35,6 +35,7 @@ import org.apache.spark.sql.sedona_sql.expressions.st_predicates.ST_Intersects
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.TimestampNTZType
import org.json4s.jackson.parseJson
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.io.WKTReader
@@ -43,6 +44,8 @@ import org.scalatest.BeforeAndAfterAll
import java.io.File
import java.util.Collections
import java.util.concurrent.atomic.AtomicLong
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import scala.collection.JavaConverters._

class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
@@ -732,6 +735,33 @@ }
}
}

describe("Spark types tests") {
it("should support timestamp_ntz") {
// Write geoparquet files with a TimestampNTZ column
val schema = StructType(Seq(
StructField("id", IntegerType, nullable = false),
StructField("timestamp_ntz", TimestampNTZType, nullable = false)
))
val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
val data = Seq(
Row(1, LocalDateTime.parse("2024-10-04 12:34:56", formatter)),
Row(2, LocalDateTime.parse("2024-10-04 15:30:00", formatter))
)
val df = sparkSession.createDataFrame(
sparkSession.sparkContext.parallelize(data),
schema
).withColumn("geom", expr("ST_Point(id, id)"))
df.write.format("geoparquet").save(geoparquetoutputlocation)

// Read it back
val df2 = sparkSession.read.format("geoparquet").load(geoparquetoutputlocation).sort(col("id"))
assert(df2.schema.fields(1).dataType == TimestampNTZType)
val data1 = df.sort(col("id")).collect()
val data2 = df2.collect()
assert(data1 sameElements data2)
}
}

def validateGeoParquetMetadata(path: String)(body: org.json4s.JValue => Unit): Unit = {
val parquetFiles = new File(path).listFiles().filter(_.getName.endsWith(".parquet"))
parquetFiles.foreach { filePath =>
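To double-check the on-disk annotation that the timestamp_ntz test above depends on, one could inspect a written file's footer directly; a sketch, not part of the commit, assuming standard parquet-mr and Hadoop APIs:

import java.io.File
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.hadoop.util.HadoopInputFile

// Open the footer of one output file and print its Parquet schema; the
// timestamp_ntz column is expected to read: int64 ... (TIMESTAMP(MICROS,false)).
val parquetFile = new File(geoparquetoutputlocation).listFiles()
  .filter(_.getName.endsWith(".parquet")).head
val reader = ParquetFileReader.open(
  HadoopInputFile.fromPath(new Path(parquetFile.getPath), new Configuration()))
try println(reader.getFileMetaData.getSchema)
finally reader.close()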
