Skip to content

Commit

Permalink
addressed review comments and added test
Browse files Browse the repository at this point in the history
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
  • Loading branch information
razajafri committed Aug 23, 2021
1 parent 039f94e commit 7b04994
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
33 changes: 31 additions & 2 deletions integration_tests/src/main/python/parquet_write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from marks import *
from pyspark.sql.types import *
from spark_session import with_cpu_session, with_gpu_session
import pyspark.sql.functions as f
import random

# test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
Expand All @@ -34,18 +35,30 @@
# Spark confs applied when writing: both the datetime and the INT96 rebase modes
# for Parquet writes are set to 'CORRECTED'.
writer_confs={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'CORRECTED'}


def limited_timestamp(nullable=True):
    """Return a TimestampGen restricted to 1677-09-22 through 2262-04-11 (UTC).

    The range is limited to avoid overflowing the INT96 value, see
    https://github.com/rapidsai/cudf/issues/8070

    :param nullable: forwarded unchanged to TimestampGen.
    """
    earliest = datetime(1677, 9, 22, tzinfo=timezone.utc)
    latest = datetime(2262, 4, 11, tzinfo=timezone.utc)
    return TimestampGen(start=earliest, end=latest, nullable=nullable)

parquet_basic_gen =[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen,
# we are limiting TimestampGen to avoid overflowing the INT96 value
# see https://github.com/rapidsai/cudf/issues/8070
TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc))]
limited_timestamp()]

# One MapGen per basic type: the key generator is built with nullable=False and the value
# generator with its defaults, followed by the simple string-to-string map generator.
# The comprehension variable is named gen_factory rather than `f`, because `f` is this
# module's alias for pyspark.sql.functions and shadowing it is easy to misread.
parquet_basic_map_gens = [MapGen(gen_factory(nullable=False), gen_factory()) for gen_factory in [BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, DateGen, limited_timestamp]] + [simple_string_to_string_map_gen]

# Struct generators used by the Parquet write tests:
#   1. a struct with one child ('child0', 'child1', ...) per entry of parquet_basic_gen,
#   2. a struct whose only child is itself a struct holding a byte_gen child,
#   3. a struct with a string-to-string map child (non-nullable keys) and an integer child.
parquet_struct_gen = [StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(parquet_basic_gen)]),
StructGen([['child0', StructGen([[ 'child1', byte_gen]])]]),
StructGen([['child0', MapGen(StringGen(nullable=False), StringGen())], ['child1', IntegerGen()]])]

# Array generators: an array (max_length=10) of every basic and struct generator, followed by
# an array of arrays (both levels max_length=10) of the same element generators.
parquet_array_gen = [ArrayGen(sub_gen, max_length=10) for sub_gen in parquet_basic_gen + parquet_struct_gen] + \
[ArrayGen(ArrayGen(sub_gen, max_length=10), max_length=10) for sub_gen in parquet_basic_gen + parquet_struct_gen]

parquet_map_gens = map_gens_sample + [MapGen(StructGen([['child0', StringGen()], ['child1', StringGen()]], nullable=False), FloatGen()), MapGen(StructGen([['child0', StringGen(nullable=True)]], nullable=False), StringGen())]
# Sample of map generators: the basic maps plus three maps with nested or repeated content:
#   - 'key_[0-9]' string keys to arrays of strings,
#   - repeating non-nullable integer keys (RepeatSeqGen, 10) to longs,
#   - 'key_[0-9]' string keys to string-to-string maps.
parquet_map_gens_sample = parquet_basic_map_gens + [MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen), max_length=10),
MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),
MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]

# Full map generator list: the sample above plus two maps keyed by a non-nullable struct.
parquet_map_gens = parquet_map_gens_sample + [MapGen(StructGen([['child0', StringGen()], ['child1', StringGen()]], nullable=False), FloatGen()), MapGen(StructGen([['child0', StringGen(nullable=True)]], nullable=False), StringGen())]
# A list holding a single entry: the concatenation of all generator lists above.
parquet_write_gens_list = [parquet_basic_gen + parquet_struct_gen + parquet_array_gen + parquet_decimal_gens + parquet_map_gens]
# NOTE(review): presumably the Parquet output timestamp types the write tests iterate over;
# confirm against the tests that use this list.
parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']

Expand Down Expand Up @@ -278,3 +291,19 @@ def test_buckets_write_fallback(spark_tmp_path, spark_tmp_table_factory):
lambda spark, path: spark.read.parquet(path),
data_path,
'DataWritingCommandExec')

def test_write_map_nullable(spark_tmp_path):
    """Compare GPU and CPU Parquet writes of a map column built from two generated columns.

    The data is generated as a non-nullable struct with 'number' (integer) and
    'word' (long) fields; rows with a null 'number' are filtered out before the
    two columns are combined into a single 'map' column and written as one file.
    """
    data_path = spark_tmp_path + '/PARQUET_DATA'

    def generate_map_with_empty_validity(spark, path):
        gen_data = StructGen([['number', IntegerGen()], ['word', LongGen()]], nullable=False)
        # Generate the DataFrame once; an earlier duplicate call whose result was
        # discarded has been removed.
        df = gen_df(spark, gen_data)
        # 'number' becomes the map key, so keep only rows where it is not null.
        df_no_nulls = df.filter("number is not null")
        df_map = (df_no_nulls
                  .withColumn("map", f.create_map(["number", "word"]))
                  .drop("number")
                  .drop("word"))
        # coalesce(1) makes the write produce a single partition.
        df_map.coalesce(1).write.parquet(path)

    assert_gpu_and_cpu_writes_are_equal_collect(
        generate_map_with_empty_validity,
        lambda spark, path: spark.read.parquet(path),
        data_path)
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,16 @@ object GpuParquetFileFormat {
case s: StructType =>
builder.withStructColumn(
parquetWriterOptionsFromSchema(
// we are setting this to nullable, in case the parent is a Map's key and wants to
// set this to false
structBuilder(name, nullable),
s,
writeInt96).build())
case a: ArrayType =>
builder.withListColumn(
parquetWriterOptionsFromField(
// we are setting this to nullable, in case the parent is a Map's key and wants to
// set this to false
listBuilder(name, nullable),
a.elementType,
name,
Expand Down

0 comments on commit 7b04994

Please sign in to comment.