Support int96RebaseModeInWrite and int96RebaseModeInRead (NVIDIA#3330)

* support int96 rebase mode

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* removed unnecessary changes to the SparkBaseShim

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* removed unnecessary extra line

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* split the date and time exception checks

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* skip test if Spark version <3.1.1

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* added more tests and addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed the failing test

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* added method to return if running before 311

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* changed the API to return existence of a separate INT96 rebase conf

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* updated the 32x shim to override the correct method

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* added DB support

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* The default value for int96 rebase in databricks is legacy, explicitly set the value

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* skip ANSI test for hash aggregate until we root cause the failure

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Adding resolution to the failed test

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Revert "Adding resolution to the failed test"

This reverts commit 14658b4.

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Adding resolution to the failed test

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Update integration_tests/src/main/python/hash_aggregate_test.py

Co-authored-by: Jason Lowe <jlowe@nvidia.com>

Co-authored-by: Raza Jafri <rjafri@nvidia.com>
Co-authored-by: Jason Lowe <jlowe@nvidia.com>
3 people authored Sep 23, 2021
1 parent 2bf1d4d commit fc40c00
Showing 21 changed files with 1,261 additions and 82 deletions.
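
For background (not part of the commit itself): Spark 3.1.x introduces separate rebase configs for INT96 timestamps, so tests and the plugin now have to manage four keys instead of two; the shims below only override the INT96-specific ones for the 3.1.x builds. A minimal, hedged PySpark sketch of the keys this commit wires through, assuming an existing SparkSession named spark and mirroring the CORRECTED values the integration tests below use:

# Illustrative only: set both the datetime and the INT96-specific rebase modes
# to CORRECTED so reads and writes use the proleptic Gregorian calendar as-is.
# 'spark' is assumed to be an already-created SparkSession.
spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', 'CORRECTED')
spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInRead', 'CORRECTED')
# The INT96-specific keys exist only on newer Spark versions, and Databricks
# defaults them to LEGACY, which forces the write back onto the CPU.
spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', 'CORRECTED')
spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInRead', 'CORRECTED')
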
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/cache_test.py
@@ -202,6 +202,9 @@ def write_read_parquet_cached(spark):
# rapids-spark doesn't support LEGACY read for parquet
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
'spark.sql.legacy.parquet.datetimeRebaseModeInRead' : 'CORRECTED',
# set the int96 rebase mode values because it's LEGACY in Databricks, which would preclude this op from running on the GPU
'spark.sql.legacy.parquet.int96RebaseModeInWrite' : 'CORRECTED',
'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED',
'spark.sql.inMemoryColumnarStorage.enableVectorizedReader' : enable_vectorized,
'spark.sql.parquet.outputTimestampType': ts_write}

1 change: 1 addition & 0 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -1160,6 +1160,7 @@ def do_it(spark):


@pytest.mark.parametrize('data_gen', _no_overflow_ansi_gens, ids=idfn)
@ignore_order(local=True)
def test_no_fallback_when_ansi_enabled(data_gen):
def do_it(spark):
df = gen_df(spark, [('a', data_gen), ('b', data_gen)], length=100)
2 changes: 2 additions & 0 deletions integration_tests/src/main/python/parquet_test.py
@@ -82,6 +82,8 @@ def test_read_round_trip(spark_tmp_path, parquet_gens, read_func, reader_confs,
conf=rebase_write_corrected_conf)
all_confs = copy_and_update(reader_confs, {
'spark.sql.sources.useV1SourceList': v1_enabled_list,
# set the int96 rebase mode values because it's LEGACY in Databricks, which would preclude this op from running on the GPU
'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED',
'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'})
# once https://github.com/NVIDIA/spark-rapids/issues/1126 is in we can remove spark.sql.legacy.parquet.datetimeRebaseModeInRead config which is a workaround
# for nested timestamp/date support
65 changes: 56 additions & 9 deletions integration_tests/src/main/python/parquet_write_test.py
@@ -23,6 +23,7 @@
import pyspark.sql.functions as f
import pyspark.sql.utils
import random
from spark_session import is_before_spark_311

# test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
# non-cloud
@@ -41,6 +42,11 @@ def limited_timestamp(nullable=True):
return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc),
nullable=nullable)

# TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS
# TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070
def limited_int96():
return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc))

parquet_basic_gen =[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen,
# we are limiting TimestampGen to avoid overflowing the INT96 value
@@ -214,24 +220,44 @@ def test_write_sql_save_table(spark_tmp_path, parquet_gens, ts_type, spark_tmp_t
data_path,
conf=all_confs)

def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, ts_rebase, ts_write):
def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, int96_rebase, datetime_rebase, ts_write):
spark.conf.set('spark.sql.parquet.outputTimestampType', ts_write)
spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', ts_rebase)
spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', ts_rebase) # for spark 310
spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', datetime_rebase)
spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', int96_rebase) # for spark 310
with pytest.raises(Exception) as e_info:
df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get())
assert e_info.match(r".*SparkUpgradeException.*")

# TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS
# TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070
@pytest.mark.parametrize('ts_write_data_gen', [('INT96', TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc))),
('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
@pytest.mark.parametrize('ts_rebase', ['EXCEPTION'])
def test_ts_write_fails_datetime_exception(spark_tmp_path, ts_write_data_gen, ts_rebase, spark_tmp_table_factory):
@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
@pytest.mark.parametrize('rebase', ["CORRECTED","EXCEPTION"])
def test_ts_write_fails_datetime_exception(spark_tmp_path, ts_write_data_gen, spark_tmp_table_factory, rebase):
ts_write, gen = ts_write_data_gen
data_path = spark_tmp_path + '/PARQUET_DATA'
with_gpu_session(
lambda spark : writeParquetUpgradeCatchException(spark, unary_op_df(spark, gen), data_path, spark_tmp_table_factory, ts_rebase, ts_write))
int96_rebase = "EXCEPTION" if (ts_write == "INT96") else rebase
date_time_rebase = "EXCEPTION" if (ts_write == "TIMESTAMP_MICROS") else rebase
if is_before_spark_311() and ts_write == 'INT96':
all_confs = {'spark.sql.parquet.outputTimestampType': ts_write}
all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': date_time_rebase,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': int96_rebase})
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
conf=all_confs)
else:
with_gpu_session(
lambda spark : writeParquetUpgradeCatchException(spark,
unary_op_df(spark, gen),
data_path,
spark_tmp_table_factory,
int96_rebase, date_time_rebase, ts_write))
with_cpu_session(
lambda spark: writeParquetUpgradeCatchException(spark,
unary_op_df(spark, gen), data_path,
spark_tmp_table_factory,
int96_rebase, date_time_rebase, ts_write))
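
The branch above keys off is_before_spark_311, imported from spark_session.py at the top of the file. As a rough illustration only (not the repository's actual helper), a version guard like that can be derived from the session's version string:

# Hedged sketch of a version check similar to is_before_spark_311; the real
# helper lives in integration_tests/src/main/python/spark_session.py and may
# be implemented differently.
def _version_tuple(version_str):
    # '3.1.1' -> (3, 1, 1); drop any suffix such as '-SNAPSHOT'
    return tuple(int(part) for part in version_str.split('-')[0].split('.'))

def is_before_spark_311_sketch(spark):
    # SparkSession.version returns the running Spark version, e.g. '3.0.1'
    return _version_tuple(spark.version) < (3, 1, 1)

On Spark versions before 3.1.1 this INT96 case is not expected to raise, so the test instead verifies that CPU and GPU writes round-trip identically.
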

def writeParquetNoOverwriteCatchException(spark, df, data_path, table_name):
with pytest.raises(Exception) as e_info:
@@ -319,6 +345,27 @@ def generate_map_with_empty_validity(spark, path):
lambda spark, path: spark.read.parquet(path),
data_path)

@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
@pytest.mark.parametrize('date_time_rebase_write', ["CORRECTED"])
@pytest.mark.parametrize('date_time_rebase_read', ["EXCEPTION", "CORRECTED"])
@pytest.mark.parametrize('int96_rebase_write', ["CORRECTED"])
@pytest.mark.parametrize('int96_rebase_read', ["EXCEPTION", "CORRECTED"])
def test_roundtrip_with_rebase_values(spark_tmp_path, ts_write_data_gen, date_time_rebase_read,
date_time_rebase_write, int96_rebase_read, int96_rebase_write):
ts_write, gen = ts_write_data_gen
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = {'spark.sql.parquet.outputTimestampType': ts_write}
all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': date_time_rebase_write,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': int96_rebase_write})
all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInRead': date_time_rebase_read,
'spark.sql.legacy.parquet.int96RebaseModeInRead': int96_rebase_read})

assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
conf=all_confs)
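
A side note on why the EXCEPTION read modes pass in this round trip: Spark records how timestamps were written in the Parquet footer's key/value metadata, and the EXCEPTION modes only fail for files that actually need rebasing. A hedged sketch (pyarrow is an assumption here, not something the test uses) for peeking at that footer metadata:

# Illustrative only: dump the key/value metadata Spark stores in a Parquet
# footer, which the rebase EXCEPTION modes consult when deciding whether a
# file needs the legacy-calendar treatment.
import pyarrow.parquet as pq

def footer_metadata(parquet_file_path):
    kv = pq.ParquetFile(parquet_file_path).metadata.metadata or {}
    # keys and values are stored as bytes; decode them for readability
    return {key.decode(): value.decode() for key, value in kv.items()}
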

@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/3476')
@pytest.mark.allow_non_gpu("DataWritingCommandExec", "HiveTableScanExec")
@pytest.mark.parametrize('allow_non_empty', [True, False])
@@ -32,6 +32,22 @@ class Spark311Shims extends SparkBaseShims {
classOf[RapidsShuffleManager].getCanonicalName
}

override def int96ParquetRebaseRead(conf: SQLConf): String = {
conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ)
}

override def int96ParquetRebaseWrite(conf: SQLConf): String = {
conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE)
}

override def int96ParquetRebaseReadKey: String = {
SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ.key
}

override def int96ParquetRebaseWriteKey: String = {
SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE.key
}

override def hasCastFloatTimestampUpcast: Boolean = false

override def getParquetFilters(
@@ -97,6 +97,22 @@ class Spark311CDHShims extends SparkBaseShims {
sessionCatalog.createTable(newTable, ignoreIfExists = false, validateLocation = false)
}

override def int96ParquetRebaseRead(conf: SQLConf): String = {
conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ)
}

override def int96ParquetRebaseWrite(conf: SQLConf): String = {
conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE)
}

override def int96ParquetRebaseReadKey: String = {
SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ.key
}

override def int96ParquetRebaseWriteKey: String = {
SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE.key
}

override def hasCastFloatTimestampUpcast: Boolean = false

override def getParquetFilters(
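
These per-version overrides exist because the INT96-specific keys are not defined on every Spark release the plugin supports, so each shim reports the conf value (and its key) appropriate for its Spark version. A hedged way to probe for the key from PySpark (illustrative only; the exact error raised for an unknown conf varies by version):

# Illustrative only: check whether the running Spark defines the separate
# INT96 rebase conf. On releases without it, RuntimeConfig.get raises an
# error instead of returning a default value.
def has_int96_rebase_conf(spark):
    try:
        spark.conf.get('spark.sql.legacy.parquet.int96RebaseModeInWrite')
        return True
    except Exception:
        # conf key is not registered on this Spark version
        return False
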
