Support ANSI intervals to/from Parquet #4810

Merged
7 changes: 5 additions & 2 deletions integration_tests/src/main/python/asserts.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
# limitations under the License.

from conftest import is_incompat, should_sort_on_spark, should_sort_locally, get_float_check, get_limit, spark_jvm
from datetime import date, datetime
from datetime import date, datetime, timedelta
from decimal import Decimal
import math
from pyspark.sql import Row
@@ -92,6 +92,9 @@ def _assert_equal(cpu, gpu, float_check, path):
assert cpu == gpu, "GPU and CPU decimal values are different at {}".format(path)
elif isinstance(cpu, bytearray):
assert cpu == gpu, "GPU and CPU bytearray values are different at {}".format(path)
elif isinstance(cpu, timedelta):
# Used by interval type DayTimeInterval for Pyspark 3.3.0+
assert cpu == gpu, "GPU and CPU timedelta values are different at {}".format(path)
elif (cpu == None):
assert cpu == gpu, "GPU and CPU are not both null at {}".format(path)
else:
27 changes: 27 additions & 0 deletions integration_tests/src/main/python/data_gen.py
@@ -612,6 +612,33 @@ def make_null():
return None
self._start(rand, make_null)

# DayTimeIntervalGen is for Spark 3.3.0+
# DayTimeIntervalType(startField, endField): Represents a day-time interval which is made up of a contiguous subset of the following fields:
# SECOND, seconds within minutes and possibly fractions of a second [0..59.999999],
# MINUTE, minutes within hours [0..59],
# HOUR, hours within days [0..23],
# DAY, days in the range [0..106751991].
# For more details: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
# Note: 106751991 / 365 ≈ 292471 years, far beyond year 9999, so this documented upper bound looks suspicious
class DayTimeIntervalGen(DataGen):
"""Generate DayTimeIntervalType values"""
def __init__(self, max_days=None, nullable=True, special_cases=[timedelta(seconds=0)]):
super().__init__(DayTimeIntervalType(), nullable=nullable, special_cases=special_cases)
if max_days is None:
self._max_days = 106751991
else:
self._max_days = max_days
def start(self, rand):
self._start(rand,
lambda : timedelta(
microseconds = rand.randint(0, 999999),
seconds = rand.randint(0, 59),
minutes = rand.randint(0, 59),
hours = rand.randint(0, 23),
days = rand.randint(0, self._max_days),
)
)

def skip_if_not_utc():
if (not is_tz_utc()):
skip_unless_precommit_tests('The java system time zone is not set to UTC')
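
The field ranges documented for DayTimeIntervalGen above map directly onto Python's timedelta, and Spark ultimately stores a day-time interval as one signed 64-bit microsecond count. A small self-contained sanity check of that mapping (plain Python, not part of this PR):

```python
# A timedelta built from the DAY/HOUR/MINUTE/SECOND/microsecond fields above equals
# the value given by Spark's documented formula:
#   (24*60*60 * DAY + 60*60 * HOUR + 60 * MINUTE + SECOND) * 1000000
from datetime import timedelta

d = timedelta(days=2, hours=3, minutes=4, seconds=5, microseconds=6)
micros_from_formula = (24 * 60 * 60 * 2 + 60 * 60 * 3 + 60 * 4 + 5) * 1_000_000 + 6
micros_from_timedelta = d // timedelta(microseconds=1)  # total microseconds in the timedelta
assert micros_from_formula == micros_from_timedelta == 183845000006
```
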
14 changes: 12 additions & 2 deletions integration_tests/src/main/python/date_time_test.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@
from datetime import date, datetime, timezone
from marks import incompat, allow_non_gpu
from pyspark.sql.types import *
from spark_session import with_spark_session, is_before_spark_311
from spark_session import with_spark_session, is_before_spark_311, is_before_spark_330
import pyspark.sql.functions as f

# We only support literal intervals for TimeSub
@@ -41,6 +41,16 @@ def test_timeadd(data_gen):
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
.selectExpr("a + (interval {} days {} seconds)".format(days, seconds)))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_timeadd_daytime_column():
gen_list = [
# cap the timestamp column at year 1000
('t', TimestampGen(end = datetime(1000, 1, 1, tzinfo=timezone.utc))),
# cap the interval at roughly 8000 years so the added result stays within the timestamp range
('d', DayTimeIntervalGen(max_days = 8000 * 365))]
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND"))

@pytest.mark.parametrize('data_gen', vals, ids=idfn)
def test_dateaddinterval(data_gen):
days, seconds = data_gen
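
For reference, the same addition can be exercised outside the test harness. A minimal sketch, assuming a local SparkSession on Spark 3.3.0+ (the data and schema here are illustrative, not taken from the tests):

```python
from datetime import datetime, timedelta
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, TimestampType, DayTimeIntervalType

spark = SparkSession.builder.master("local[1]").getOrCreate()
schema = StructType([
    StructField("t", TimestampType()),
    StructField("d", DayTimeIntervalType()),
])
df = spark.createDataFrame(
    [(datetime(2022, 1, 1, 12, 0, 0), timedelta(days=1, hours=2, minutes=3, seconds=4))],
    schema)
# Adding a day-time interval column or literal to a timestamp yields a timestamp.
df.selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND").show(truncate=False)
```
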
18 changes: 18 additions & 0 deletions integration_tests/src/main/python/parquet_test.py
@@ -789,3 +789,21 @@ def test_parquet_read_field_id(spark_tmp_path):
lambda spark: spark.read.schema(readSchema).parquet(data_path),
'FileSourceScanExec',
{"spark.sql.parquet.fieldId.read.enabled": "true"}) # default is false

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_parquet_read_daytime_interval_cpu_file(spark_tmp_path):
data_path = spark_tmp_path + '/PARQUET_DATA'
gen_list = [('_c1', DayTimeIntervalGen())]
# write DayTimeInterval with CPU
with_cpu_session(lambda spark: gen_df(spark, gen_list).coalesce(1).write.mode("overwrite").parquet(data_path))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.parquet(data_path))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_parquet_read_daytime_interval_gpu_file(spark_tmp_path):
data_path = spark_tmp_path + '/PARQUET_DATA'
gen_list = [('_c1', DayTimeIntervalGen())]
# write DayTimeInterval with GPU
with_gpu_session(lambda spark: gen_df(spark, gen_list).coalesce(1).write.mode("overwrite").parquet(data_path))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.parquet(data_path))
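
The same Parquet round trip can be reproduced without the test harness. A minimal sketch, assuming Spark 3.3.0+ (the path and rows are illustrative):

```python
from datetime import timedelta
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, DayTimeIntervalType

spark = SparkSession.builder.master("local[1]").getOrCreate()
path = "/tmp/daytime_interval_demo"
schema = StructType([StructField("_c1", DayTimeIntervalType())])
spark.createDataFrame([(timedelta(days=3, hours=4),), (None,)], schema) \
    .coalesce(1).write.mode("overwrite").parquet(path)

readback = spark.read.parquet(path)
readback.printSchema()     # _c1: interval day to second
print(readback.collect())  # timedelta values (and nulls) survive the round trip
```
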
11 changes: 11 additions & 0 deletions integration_tests/src/main/python/parquet_write_test.py
@@ -418,3 +418,14 @@ def test_parquet_write_field_id(spark_tmp_path):
data_path,
'DataWritingCommandExec',
conf = {"spark.sql.parquet.fieldId.write.enabled" : "true"}) # default is true

@pytest.mark.order(1) # at the head of xdist worker queue if pytest-order is installed
@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_write_daytime_interval(spark_tmp_path):
gen_list = [('_c1', DayTimeIntervalGen())]
data_path = spark_tmp_path + '/PARQUET_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
conf=writer_confs)
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids.shims.v2

import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.GpuRowToColumnConverter.TypeConverter

import org.apache.spark.sql.types.DataType

object GpuTypeShims {

/**
* Whether this Shim supports the data type for the row-to-column converter
* @param otherType the data type that should be checked in the Shim
* @return true if the Shim supports otherType, false otherwise
*/
def hasConverterForType(otherType: DataType) : Boolean = false

/**
* Get the TypeConverter for the data type in this Shim.
* Note: callers should check hasConverterForType first.
* @param t the data type
* @param nullable whether the column is nullable
* @return the row-to-column converter for the data type
*/
def getConverterForType(t: DataType, nullable: Boolean): TypeConverter = {
throw new RuntimeException(s"No converter is found for type $t.")
}

/**
* Get the cuDF type for the Spark data type
* @param t the Spark data type
* @return the cuDF type if the Shim supports it, or null otherwise
*/
def toRapidsOrNull(t: DataType): DType = null
}
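
This default shim deliberately supports nothing: hasConverterForType returns false, getConverterForType throws, and toRapidsOrNull returns null, so builds for older Spark versions simply fall back. A hypothetical Python analogue of that per-version shim pattern (method names mirror the Scala ones; this is not the plugin's real dispatch code):

```python
class BaseTypeShims:
    """Default shim for Spark versions without ANSI interval support."""
    def has_converter_for_type(self, spark_type):
        return False

    def to_rapids_or_none(self, spark_type):
        return None  # no matching cuDF type, so the caller falls back


class Spark330TypeShims(BaseTypeShims):
    """Shim for Spark 3.3.0+, where DayTimeIntervalType is a 64-bit microsecond count."""
    def has_converter_for_type(self, spark_type):
        return spark_type == "daytime-interval"

    def to_rapids_or_none(self, spark_type):
        # Spark stores DayTimeIntervalType as int64, so map it to a 64-bit integer type.
        return "INT64" if spark_type == "daytime-interval" else None
```
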
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.v2

import java.util.concurrent.TimeUnit

import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar}
import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.v2.ShimBinaryExpression
@@ -59,48 +59,59 @@ case class GpuTimeAdd(start: Expression,
override def columnarEval(batch: ColumnarBatch): Any = {
withResourceIfAllowed(left.columnarEval(batch)) { lhs =>
withResourceIfAllowed(right.columnarEval(batch)) { rhs =>
// lhs is start, rhs is interval
(lhs, rhs) match {
case (l: GpuColumnVector, intvlS: GpuScalar) =>
val interval = intvlS.dataType match {
case (l: GpuColumnVector, intervalS: GpuScalar) =>
// get long type interval
val interval = intervalS.dataType match {
case CalendarIntervalType =>
// Scalar does not support 'CalendarInterval' now, so use
// the Scala value instead.
// Skip the null check because it will be detected by the following calls.
val intvl = intvlS.getValue.asInstanceOf[CalendarInterval]
if (intvl.months != 0) {
val calendarI = intervalS.getValue.asInstanceOf[CalendarInterval]
if (calendarI.months != 0) {
throw new UnsupportedOperationException("Months aren't supported at the moment")
}
intvl.days * microSecondsInOneDay + intvl.microseconds
calendarI.days * microSecondsInOneDay + calendarI.microseconds
case _: DayTimeIntervalType =>
// Scalar does not support 'DayTimeIntervalType' now, so use
// the Scala value instead.
intvlS.getValue.asInstanceOf[Long]
intervalS.getValue.asInstanceOf[Long]
case _ =>
throw new UnsupportedOperationException("GpuTimeAdd unsupported data type: " +
intvlS.dataType)
throw new UnsupportedOperationException(
"GpuTimeAdd unsupported data type: " + intervalS.dataType)
}

// add interval
if (interval != 0) {
withResource(Scalar.fromLong(interval)) { us_s =>
withResource(l.getBase.bitCastTo(DType.INT64)) { us =>
withResource(intervalMath(us_s, us)) { longResult =>
GpuColumnVector.from(longResult.castTo(DType.TIMESTAMP_MICROSECONDS),
dataType)
}
}
withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { d =>
GpuColumnVector.from(timestampAddDuration(l.getBase, d), dataType)
}
} else {
l.incRefCount()
}
case (l: GpuColumnVector, r: GpuColumnVector) =>
(l.dataType(), r.dataType) match {
case (_: TimestampType, _: DayTimeIntervalType) =>
// DayTimeIntervalType is stored as long
// bitCastTo is similar to reinterpret_cast; it is cheap, so its cost can be ignored.
withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration =>
GpuColumnVector.from(timestampAddDuration(l.getBase, duration), dataType)
}
case _ =>
throw new UnsupportedOperationException(
"GpuTimeAdd takes column and interval as an argument only")
}
case _ =>
throw new UnsupportedOperationException("GpuTimeAdd takes column and interval as an " +
"argument only")
throw new UnsupportedOperationException(
"GpuTimeAdd takes column and interval as an argument only")
}
}
}
}

private def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = {
us.add(us_s)
private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = {
// Do not use cv.add(duration), because it invokes BinaryOperable.implicitConversion,
// which currently returns a Long result.
// Instead, explicitly specify TIMESTAMP_MICROSECONDS as the return type.
cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS)
}
}
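
The kernel change above works because both operands are 64-bit microsecond counts: a TIMESTAMP_MICROSECONDS column is microseconds since the epoch and a DayTimeIntervalType value is a microsecond duration, so the addition is plain int64 arithmetic reinterpreted back as a timestamp. A plain-Python illustration of that equivalence (not cuDF code):

```python
from datetime import datetime, timedelta, timezone

epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
ts = datetime(2022, 3, 1, 12, 0, 0, tzinfo=timezone.utc)
interval = timedelta(days=1, hours=2, minutes=3, seconds=4)

ts_micros = (ts - epoch) // timedelta(microseconds=1)  # timestamp as int64 microseconds
iv_micros = interval // timedelta(microseconds=1)      # interval as int64 microseconds

# Integer addition on the microsecond representations matches timestamp + interval.
assert epoch + timedelta(microseconds=ts_micros + iv_micros) == ts + interval
```
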
@@ -0,0 +1,96 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids.shims.v2

import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.GpuRowToColumnConverter.{LongConverter, NotNullLongConverter, TypeConverter}

import org.apache.spark.sql.types.{DataType, DayTimeIntervalType}

/**
* Spark stores ANSI YearMonthIntervalType as int32 and ANSI DayTimeIntervalType as int64
* internally when computing.
* See the comments of YearMonthIntervalType; the following is copied from Spark:
* Internally, values of year-month intervals are stored in `Int` values as amount of months
* that are calculated by the formula:
* -/+ (12 * YEAR + MONTH)
* See the comments of DayTimeIntervalType; the following is copied from Spark:
* Internally, values of day-time intervals are stored in `Long` values as amount of time in terms
* of microseconds that are calculated by the formula:
* -/+ (24*60*60 * DAY + 60*60 * HOUR + 60 * MINUTE + SECOND) * 1000000
*
* Spark also stores ANSI intervals as int32 and int64 in Parquet files:
* - year-month intervals as `INT32`
* - day-time intervals as `INT64`
* To load the values as intervals back, Spark puts the info about interval types
* to the extra key `org.apache.spark.sql.parquet.row.metadata`:
* $ java -jar parquet-tools-1.12.0.jar meta ./part-...-c000.snappy.parquet
* creator: parquet-mr version 1.12.1 (build 2a5c06c58fa987f85aa22170be14d927d5ff6e7d)
* extra: org.apache.spark.version = 3.3.0
* extra: org.apache.spark.sql.parquet.row.metadata =
* {"type":"struct","fields":[...,
* {"name":"i","type":"interval year to month","nullable":false,"metadata":{}}]}
* file schema: spark_schema
* --------------------------------------------------------------------------------
* ...
* i: REQUIRED INT32 R:0 D:0
*
* For details, see https://issues.apache.org/jira/browse/SPARK-36825
*/
object GpuTypeShims {

/**
* Whether this Shim supports the data type for the row-to-column converter
* @param otherType the data type that should be checked in the Shim
* @return true if the Shim supports otherType, false otherwise
*/
def hasConverterForType(otherType: DataType) : Boolean = {
otherType match {
case DayTimeIntervalType(_, _) => true
case _ => false
}
}

/**
* Get the TypeConverter for the data type in this Shim.
* Note: callers should check hasConverterForType first.
* @param t the data type
* @param nullable whether the column is nullable
* @return the row-to-column converter for the data type
*/
def getConverterForType(t: DataType, nullable: Boolean): TypeConverter = {
(t, nullable) match {
case (DayTimeIntervalType(_, _), true) => LongConverter
case (DayTimeIntervalType(_, _), false) => NotNullLongConverter
case _ => throw new RuntimeException(s"No converter is found for type $t.")
}
}

/**
* Get the cuDF type for the Spark data type
* @param t the Spark data type
* @return the cuDF type if the Shim supports it, or null otherwise
*/
def toRapidsOrNull(t: DataType): DType = {
t match {
case _: DayTimeIntervalType =>
// use int64 as Spark does
DType.INT64
case _ =>
null
}
}
}
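
A hedged end-to-end check of the Parquet encoding described in the comment above; it assumes Spark 3.3.0+ and pyarrow are installed, and the path is illustrative. The day-time interval column should appear in the Parquet footer as a plain INT64 column, with the interval type recorded under the org.apache.spark.sql.parquet.row.metadata key:

```python
import glob
from datetime import timedelta

import pyarrow.parquet as pq
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, DayTimeIntervalType

spark = SparkSession.builder.master("local[1]").getOrCreate()
path = "/tmp/ansi_interval_encoding_demo"
schema = StructType([StructField("i", DayTimeIntervalType())])
spark.createDataFrame([(timedelta(hours=1),)], schema) \
    .coalesce(1).write.mode("overwrite").parquet(path)

part = glob.glob(path + "/*.parquet")[0]
pf = pq.ParquetFile(part)
col = pf.schema.column(0)
print(col.name, col.physical_type)  # expected: i INT64
print(pf.metadata.metadata)         # includes the Spark schema with the interval type
```
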