Revert "Support int96RebaseModeInWrite and `int96RebaseModeInRead' (#3330) #3627

Closed — wants to merge 2 commits
3 changes: 0 additions & 3 deletions integration_tests/src/main/python/cache_test.py
@@ -202,9 +202,6 @@ def write_read_parquet_cached(spark):
             # rapids-spark doesn't support LEGACY read for parquet
             conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
                   'spark.sql.legacy.parquet.datetimeRebaseModeInRead' : 'CORRECTED',
-                  # set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU
-                  'spark.sql.legacy.parquet.int96RebaseModeInWrite' : 'CORRECTED',
-                  'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED',
                   'spark.sql.inMemoryColumnarStorage.enableVectorizedReader' : enable_vectorized,
                   'spark.sql.parquet.outputTimestampType': ts_write}

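For context, the datetimeRebaseModeIn{Write,Read} and int96RebaseModeIn{Write,Read} settings control how Spark 3.x handles dates and timestamps that predate the 1582 switch to the Gregorian calendar: LEGACY rebases values to the old hybrid Julian/Gregorian calendar, CORRECTED writes and reads them as-is, and EXCEPTION fails on ambiguous values. A minimal PySpark sketch of the write-side settings — the app name and output path are illustrative, not taken from this PR:

    from pyspark.sql import SparkSession

    spark = (SparkSession.builder
             .appName("rebase-mode-demo")  # hypothetical app name
             # Write ancient timestamps as-is (proleptic Gregorian) instead of
             # rebasing them to the legacy hybrid Julian/Gregorian calendar.
             .config("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "CORRECTED")
             # Spark 3.1+ adds a separate mode for INT96-encoded timestamps.
             .config("spark.sql.legacy.parquet.int96RebaseModeInWrite", "CORRECTED")
             .getOrCreate())

    df = spark.sql("SELECT TIMESTAMP '1500-01-01 00:00:00' AS ts")
    df.write.mode("overwrite").parquet("/tmp/rebase_demo")  # hypothetical path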
1 change: 0 additions & 1 deletion integration_tests/src/main/python/hash_aggregate_test.py
@@ -1160,7 +1160,6 @@ def do_it(spark):


 @pytest.mark.parametrize('data_gen', _no_overflow_ansi_gens, ids=idfn)
-@ignore_order(local=True)
 def test_no_fallback_when_ansi_enabled(data_gen):
     def do_it(spark):
         df = gen_df(spark, [('a', data_gen), ('b', data_gen)], length=100)
2 changes: 0 additions & 2 deletions integration_tests/src/main/python/parquet_test.py
@@ -82,8 +82,6 @@ def test_read_round_trip(spark_tmp_path, parquet_gens, read_func, reader_confs,
         conf=rebase_write_corrected_conf)
     all_confs = copy_and_update(reader_confs, {
         'spark.sql.sources.useV1SourceList': v1_enabled_list,
-        # set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU
-        'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED',
         'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'})
     # once https://github.com/NVIDIA/spark-rapids/issues/1126 is in we can remove spark.sql.legacy.parquet.datetimeRebaseModeInRead config which is a workaround
     # for nested timestamp/date support
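The copy_and_update helper seen above comes from the repo's integration-test framework; judging by its call sites, it merges a shared base conf dict with per-test overrides without mutating the base. A sketch of that assumed behavior (the real implementation may differ):

    def copy_and_update(base_conf, updates):
        # Shallow-copy the shared base conf so test-specific overrides
        # don't leak back into the dict reused by other tests.
        merged = dict(base_conf)
        merged.update(updates)
        return merged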
65 changes: 9 additions & 56 deletions integration_tests/src/main/python/parquet_write_test.py
@@ -23,7 +23,6 @@
 import pyspark.sql.functions as f
 import pyspark.sql.utils
 import random
-from spark_session import is_before_spark_311

 # test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
 # non-cloud
@@ -42,11 +41,6 @@ def limited_timestamp(nullable=True):
     return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc),
                         nullable=nullable)

-# TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS
-# TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070
-def limited_int96():
-    return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc))
-
 parquet_basic_gen =[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
                     string_gen, boolean_gen, date_gen,
                     # we are limiting TimestampGen to avoid overflowing the INT96 value
@@ -220,44 +214,24 @@ def test_write_sql_save_table(spark_tmp_path, parquet_gens, ts_type, spark_tmp_t
             data_path,
             conf=all_confs)

-def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, int96_rebase, datetime_rebase, ts_write):
+def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, ts_rebase, ts_write):
     spark.conf.set('spark.sql.parquet.outputTimestampType', ts_write)
-    spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', datetime_rebase)
-    spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', int96_rebase) # for spark 310
+    spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', ts_rebase)
+    spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', ts_rebase) # for spark 310
     with pytest.raises(Exception) as e_info:
         df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get())
     assert e_info.match(r".*SparkUpgradeException.*")

-# TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS
-# TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070
-@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
-@pytest.mark.parametrize('rebase', ["CORRECTED","EXCEPTION"])
-def test_ts_write_fails_datetime_exception(spark_tmp_path, ts_write_data_gen, spark_tmp_table_factory, rebase):
+@pytest.mark.parametrize('ts_write_data_gen', [('INT96', TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc))),
+                                               ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
+@pytest.mark.parametrize('ts_rebase', ['EXCEPTION'])
+def test_ts_write_fails_datetime_exception(spark_tmp_path, ts_write_data_gen, ts_rebase, spark_tmp_table_factory):
     ts_write, gen = ts_write_data_gen
     data_path = spark_tmp_path + '/PARQUET_DATA'
-    int96_rebase = "EXCEPTION" if (ts_write == "INT96") else rebase
-    date_time_rebase = "EXCEPTION" if (ts_write == "TIMESTAMP_MICROS") else rebase
-    if is_before_spark_311() and ts_write == 'INT96':
-        all_confs = {'spark.sql.parquet.outputTimestampType': ts_write}
-        all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': date_time_rebase,
-                          'spark.sql.legacy.parquet.int96RebaseModeInWrite': int96_rebase})
-        assert_gpu_and_cpu_writes_are_equal_collect(
-            lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.parquet(path),
-            lambda spark, path: spark.read.parquet(path),
-            data_path,
-            conf=all_confs)
-    else:
-        with_gpu_session(
-            lambda spark : writeParquetUpgradeCatchException(spark,
-                                                             unary_op_df(spark, gen),
-                                                             data_path,
-                                                             spark_tmp_table_factory,
-                                                             int96_rebase, date_time_rebase, ts_write))
-        with_cpu_session(
-            lambda spark: writeParquetUpgradeCatchException(spark,
-                                                            unary_op_df(spark, gen), data_path,
-                                                            spark_tmp_table_factory,
-                                                            int96_rebase, date_time_rebase, ts_write))
+    with_gpu_session(
+        lambda spark : writeParquetUpgradeCatchException(spark, unary_op_df(spark, gen), data_path, spark_tmp_table_factory, ts_rebase, ts_write))

 def writeParquetNoOverwriteCatchException(spark, df, data_path, table_name):
     with pytest.raises(Exception) as e_info:
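writeParquetUpgradeCatchException relies on pytest's exception introspection: pytest.raises captures the raised exception in e_info, and ExceptionInfo.match() applies re.search to its string form. A self-contained sketch of the same pattern with a hypothetical helper, no Spark required:

    import pytest

    def parse_positive(s):
        n = int(s)
        if n <= 0:
            raise ValueError("expected a positive integer, got %d" % n)
        return n

    def test_parse_positive_rejects_zero():
        # The block must raise, or pytest.raises itself fails the test;
        # e_info.match() then asserts the message matches the regex.
        with pytest.raises(ValueError) as e_info:
            parse_positive("0")
        assert e_info.match(r".*positive integer.*")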
@@ -345,27 +319,6 @@ def generate_map_with_empty_validity(spark, path):
         lambda spark, path: spark.read.parquet(path),
         data_path)

-@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
-@pytest.mark.parametrize('date_time_rebase_write', ["CORRECTED"])
-@pytest.mark.parametrize('date_time_rebase_read', ["EXCEPTION", "CORRECTED"])
-@pytest.mark.parametrize('int96_rebase_write', ["CORRECTED"])
-@pytest.mark.parametrize('int96_rebase_read', ["EXCEPTION", "CORRECTED"])
-def test_roundtrip_with_rebase_values(spark_tmp_path, ts_write_data_gen, date_time_rebase_read,
-                                      date_time_rebase_write, int96_rebase_read, int96_rebase_write):
-    ts_write, gen = ts_write_data_gen
-    data_path = spark_tmp_path + '/PARQUET_DATA'
-    all_confs = {'spark.sql.parquet.outputTimestampType': ts_write}
-    all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': date_time_rebase_write,
-                      'spark.sql.legacy.parquet.int96RebaseModeInWrite': int96_rebase_write})
-    all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInRead': date_time_rebase_read,
-                      'spark.sql.legacy.parquet.int96RebaseModeInRead': int96_rebase_read})
-
-    assert_gpu_and_cpu_writes_are_equal_collect(
-        lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.parquet(path),
-        lambda spark, path: spark.read.parquet(path),
-        data_path,
-        conf=all_confs)
-
 @pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/3476')
 @pytest.mark.allow_non_gpu("DataWritingCommandExec", "HiveTableScanExec")
 @pytest.mark.parametrize('allow_non_empty', [True, False])
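The 1677-09-22 to 2262-04-11 bounds that recur in these TimestampGen calls match the range of a 64-bit nanosecond timestamp, which is how the GPU side represents values per the linked cuDF issue (that linkage is our reading of rapidsai/cudf#8070, not stated in the diff). A quick check of the arithmetic:

    from datetime import datetime, timedelta, timezone

    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    # +/- 2**63 nanoseconds around the Unix epoch, expressed in microseconds
    # because timedelta has no nanosecond field.
    half_range = timedelta(microseconds=(2 ** 63) // 1000)
    print(epoch - half_range)  # 1677-09-21 00:12:43+00:00 (approx.)
    print(epoch + half_range)  # 2262-04-11 23:47:16+00:00 (approx.)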
Spark311Shims.scala
@@ -32,22 +32,6 @@ class Spark311Shims extends SparkBaseShims {
     classOf[RapidsShuffleManager].getCanonicalName
   }

-  override def int96ParquetRebaseRead(conf: SQLConf): String = {
-    conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ)
-  }
-
-  override def int96ParquetRebaseWrite(conf: SQLConf): String = {
-    conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE)
-  }
-
-  override def int96ParquetRebaseReadKey: String = {
-    SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ.key
-  }
-
-  override def int96ParquetRebaseWriteKey: String = {
-    SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE.key
-  }
-
   override def hasCastFloatTimestampUpcast: Boolean = false

   override def getParquetFilters(
Spark311CDHShims.scala
@@ -97,22 +97,6 @@ class Spark311CDHShims extends SparkBaseShims {
     sessionCatalog.createTable(newTable, ignoreIfExists = false, validateLocation = false)
   }

-  override def int96ParquetRebaseRead(conf: SQLConf): String = {
-    conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ)
-  }
-
-  override def int96ParquetRebaseWrite(conf: SQLConf): String = {
-    conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE)
-  }
-
-  override def int96ParquetRebaseReadKey: String = {
-    SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ.key
-  }
-
-  override def int96ParquetRebaseWriteKey: String = {
-    SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE.key
-  }
-
   override def hasCastFloatTimestampUpcast: Boolean = false

   override def getParquetFilters(
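The two shim classes above exist because the INT96-specific SQLConf entries only appear in Spark 3.1+; the base shim for older Sparks presumably falls back to the generic datetime rebase settings. A rough Python rendering of that dispatch pattern (the real shims are Scala, and the fallback shown is our assumption):

    class SparkBaseShims:
        # Pre-3.1 Spark has no INT96-specific rebase conf, so assume the
        # generic datetime key applies (illustrative fallback).
        def int96_parquet_rebase_write_key(self):
            return "spark.sql.legacy.parquet.datetimeRebaseModeInWrite"

    class Spark311Shims(SparkBaseShims):
        # Spark 3.1.x exposes a dedicated INT96 mode; the shim routes
        # callers to it without version checks at every call site.
        def int96_parquet_rebase_write_key(self):
            return "spark.sql.legacy.parquet.int96RebaseModeInWrite"

    def pick_shims(version_tuple):
        # e.g. (3, 1, 1); the real plugin selects shims via its ShimLoader.
        return Spark311Shims() if version_tuple >= (3, 1, 1) else SparkBaseShims()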