Added support for Array[Struct] (#3619)
Signed-off-by: Raza Jafri <rjafri@nvidia.com>

Co-authored-by: Raza Jafri <rjafri@nvidia.com>
razajafri authored Sep 23, 2021
1 parent 2575d4e commit 64039e6
Showing 7 changed files with 35 additions and 20 deletions.
19 changes: 16 additions & 3 deletions integration_tests/src/main/python/cache_test.py
@@ -146,7 +146,11 @@ def op_df(spark, length=2048, seed=0):
assert_gpu_and_cpu_are_equal_collect(op_df, conf = conf)

@pytest.mark.parametrize('data_gen', [all_basic_struct_gen, StructGen([['child0', StructGen([['child1', byte_gen]])]]),
-                                      decimal_struct_gen] + all_gen, ids=idfn)
+                                      ArrayGen(
+                                          StructGen([['child0', StringGen()],
+                                                     ['child1',
+                                                      StructGen([['child0', IntegerGen()]])]])),
+                                      decimal_struct_gen] + single_level_array_gens_no_null + all_gen, ids=idfn)
@pytest.mark.parametrize('enable_vectorized_conf', enable_vectorized_confs, ids=idfn)
@allow_non_gpu('CollectLimitExec')
def test_cache_partial_load(data_gen, enable_vectorized_conf):
@@ -182,13 +186,18 @@ def n_fold(spark):
# This test doesn't allow negative scale for Decimals as ` df.write.mode('overwrite').parquet(data_path)`
# writes parquet which doesn't allow negative decimals
@pytest.mark.parametrize('data_gen', [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(),
+                                      ArrayGen(
+                                          StructGen([['child0', StringGen()],
+                                                     ['child1',
+                                                      StructGen([['child0', IntegerGen()]])]])),
                                       pytest.param(FloatGen(special_cases=[FLOAT_MIN, FLOAT_MAX, 0.0, 1.0, -1.0]), marks=[incompat]),
                                       pytest.param(DoubleGen(special_cases=double_special_cases), marks=[incompat]),
                                       BooleanGen(), DateGen(), TimestampGen(), decimal_gen_default, decimal_gen_scale_precision,
-                                      decimal_gen_same_scale_precision, decimal_gen_64bit], ids=idfn)
+                                      decimal_gen_same_scale_precision, decimal_gen_64bit] + single_level_array_gens_no_null, ids=idfn)
@pytest.mark.parametrize('ts_write', ['TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS'])
@pytest.mark.parametrize('enable_vectorized', ['true', 'false'], ids=idfn)
@ignore_order
@allow_non_gpu("SortExec", "ShuffleExchangeExec", "RangePartitioning")
def test_cache_columnar(spark_tmp_path, data_gen, enable_vectorized, ts_write):
data_path_gpu = spark_tmp_path + '/PARQUET_DATA'
def read_parquet_cached(data_path):
@@ -211,7 +220,11 @@ def write_read_parquet_cached(spark):
assert_gpu_and_cpu_are_equal_collect(read_parquet_cached(data_path_gpu), conf)

@pytest.mark.parametrize('data_gen', [all_basic_struct_gen, StructGen([['child0', StructGen([['child1', byte_gen]])]]),
-                                      decimal_struct_gen]+ all_gen, ids=idfn)
+                                      decimal_struct_gen,
+                                      ArrayGen(
+                                          StructGen([['child0', StringGen()],
+                                                     ['child1',
+                                                      StructGen([['child0', IntegerGen()]])]]))] + single_level_array_gens_no_null + all_gen, ids=idfn)
@pytest.mark.parametrize('enable_vectorized_conf', enable_vectorized_confs, ids=idfn)
def test_cache_cpu_gpu_mixed(data_gen, enable_vectorized_conf):
def func(spark):
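The ArrayGen(StructGen(...)) entries above exercise caching of an array-of-struct column end to end. As a rough standalone sketch of the shape being tested (not part of the suite; it assumes a session with the RAPIDS plugin and the Parquet cache serializer already configured):

    from pyspark.sql import Row, SparkSession

    spark = SparkSession.builder.getOrCreate()

    # One column 'a' of type array<struct<child0:string, child1:struct<child0:int>>>,
    # mirroring the ArrayGen(StructGen(...)) generator used in the tests above.
    df = spark.createDataFrame([
        Row(a=[Row(child0="x", child1=Row(child0=1))]),
        Row(a=[Row(child0="y", child1=Row(child0=2))]),
    ])

    df.cache().count()  # materializes the cached relation
    df.collect()        # reads back through the cache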
7 changes: 5 additions & 2 deletions integration_tests/src/main/python/data_gen.py
@@ -862,8 +862,9 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
decimal_gens = [decimal_gen_neg_scale] + decimal_gens_no_neg

# all of the basic gens
-all_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
-                  string_gen, boolean_gen, date_gen, timestamp_gen, null_gen]
+all_basic_gens_no_null = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
+                          string_gen, boolean_gen, date_gen, timestamp_gen]
+all_basic_gens = all_basic_gens_no_null + [null_gen]

all_basic_gens_no_nan = [byte_gen, short_gen, int_gen, long_gen, FloatGen(no_nans=True), DoubleGen(no_nans=True),
string_gen, boolean_gen, date_gen, timestamp_gen, null_gen]
@@ -891,6 +892,8 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):

single_level_array_gens = [ArrayGen(sub_gen) for sub_gen in all_basic_gens + decimal_gens]

+single_level_array_gens_no_null = [ArrayGen(sub_gen) for sub_gen in all_basic_gens_no_null + decimal_gens_no_neg]
+
single_level_array_gens_no_nan = [ArrayGen(sub_gen) for sub_gen in all_basic_gens_no_nan + decimal_gens]

single_level_array_gens_no_decimal = [ArrayGen(sub_gen) for sub_gen in all_basic_gens]
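The new single_level_array_gens_no_null list wraps each non-null basic generator, plus the non-negative-scale decimal generators, in ArrayGen. Hand-expanded for reference (a sketch; assumes the data_gen.py definitions above are in scope):

    single_level_array_gens_no_null = [
        ArrayGen(byte_gen), ArrayGen(short_gen), ArrayGen(int_gen), ArrayGen(long_gen),
        ArrayGen(float_gen), ArrayGen(double_gen), ArrayGen(string_gen),
        ArrayGen(boolean_gen), ArrayGen(date_gen), ArrayGen(timestamp_gen),
    ] + [ArrayGen(g) for g in decimal_gens_no_neg]

null_gen and the negative-scale decimal generators are left out because these tests round-trip through Parquet, which rejects NullType columns and negative decimal scales.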
@@ -41,11 +41,11 @@ class Spark311CDHShims extends SparkBaseShims {
super.getExecs ++ Seq(
GpuOverrides.exec[InMemoryTableScanExec](
"Implementation of InMemoryTableScanExec to use GPU accelerated Caching",
-        ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested()
-            .withPsNote(TypeEnum.DECIMAL,
-              "Negative scales aren't supported at the moment even with " +
-                "spark.sql.legacy.allowNegativeScaleOfDecimal set to true. This is because Parquet " +
-                "doesn't support negative scale for decimal values"),
+        ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT
+            + TypeSig.ARRAY).nested().withPsNote(TypeEnum.DECIMAL,
+          "Negative scales aren't supported at the moment even with " +
+            "spark.sql.legacy.allowNegativeScaleOfDecimal set to true. " +
+            "This is because Parquet doesn't support negative scale for decimal values"),
TypeSig.all),
(scan, conf, p, r) => new SparkPlanMeta[InMemoryTableScanExec](scan, conf, p, r) {
override def tagPlanForGpu(): Unit = {
@@ -528,8 +528,8 @@ class Spark320Shims extends Spark32XShims {
}),
GpuOverrides.exec[InMemoryTableScanExec](
"Implementation of InMemoryTableScanExec to use GPU accelerated Caching",
-        ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested(),
-          TypeSig.all),
+        ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT
+            + TypeSig.ARRAY).nested(), TypeSig.all),
(scan, conf, p, r) => new SparkPlanMeta[InMemoryTableScanExec](scan, conf, p, r) {
override def tagPlanForGpu(): Unit = {
if (!scan.relation.cacheBuilder.serializer.isInstanceOf[ParquetCachedBatchSerializer]) {
@@ -284,8 +284,7 @@ class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer with Arm {

def isSupportedByCudf(dataType: DataType): Boolean = {
dataType match {
-      // TODO: when arrays are supported for cudf writes add it here.
-      // https://github.com/NVIDIA/spark-rapids/issues/2054
+      case a: ArrayType => isSupportedByCudf(a.elementType)
case s: StructType => s.forall(field => isSupportedByCudf(field.dataType))
case _ => GpuColumnVector.isNonNestedSupportedType(dataType)
}
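With the TODO resolved, isSupportedByCudf now reduces arbitrarily nested Array/Struct combinations to a check on their leaf types. A rough Python analogue of the Scala logic (illustrative only; is_non_nested_supported is a hypothetical stand-in for GpuColumnVector.isNonNestedSupportedType, with a deliberately truncated allow-list):

    from pyspark.sql.types import (ArrayType, DataType, IntegerType,
                                   StringType, StructType)

    def is_non_nested_supported(data_type: DataType) -> bool:
        # Hypothetical stand-in for GpuColumnVector.isNonNestedSupportedType.
        return isinstance(data_type, (IntegerType, StringType))

    def is_supported_by_cudf(data_type: DataType) -> bool:
        if isinstance(data_type, ArrayType):
            # New in this commit: an array is supported if its element type is.
            return is_supported_by_cudf(data_type.elementType)
        if isinstance(data_type, StructType):
            # A struct is supported only if every field's type is.
            return all(is_supported_by_cudf(f.dataType) for f in data_type.fields)
        return is_non_nested_supported(data_type)

    # array<struct<child0:string, child1:int>> passes via two levels of recursion.
    t = ArrayType(StructType().add("child0", StringType()).add("child1", IntegerType()))
    assert is_supported_by_cudf(t)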
@@ -464,12 +464,12 @@ abstract class SparkBaseShims extends Spark30XShims {
}),
GpuOverrides.exec[InMemoryTableScanExec](
"Implementation of InMemoryTableScanExec to use GPU accelerated Caching",
-        ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested(),
-          TypeSig.all),
+        ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT
+            + TypeSig.ARRAY).nested(), TypeSig.all),
(scan, conf, p, r) => new SparkPlanMeta[InMemoryTableScanExec](scan, conf, p, r) {
override def tagPlanForGpu(): Unit = {
if (!scan.relation.cacheBuilder.serializer
-              .isInstanceOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
+            .isInstanceOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
willNotWorkOnGpu("ParquetCachedBatchSerializer is not being used")
}
}
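Both SparkBaseShims variants only keep InMemoryTableScanExec on the GPU when the cached relation was built with com.nvidia.spark.ParquetCachedBatchSerializer; anything else falls back via willNotWorkOnGpu. For context, a sketch of how a session opts in (spark.sql.cache.serializer is a static conf, so it must be set before the session starts; the plugin jars are assumed to be on the classpath):

    from pyspark.sql import SparkSession

    spark = (SparkSession.builder
             .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
             .config("spark.sql.cache.serializer",
                     "com.nvidia.spark.ParquetCachedBatchSerializer")
             .getOrCreate())

    # df.cache() in this session now goes through the Parquet cached batch
    # serializer, so the shim's isInstanceOf check passes.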
@@ -437,12 +437,12 @@ abstract class SparkBaseShims extends Spark31XShims {
}),
GpuOverrides.exec[InMemoryTableScanExec](
"Implementation of InMemoryTableScanExec to use GPU accelerated Caching",
-        ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested(),
-          TypeSig.all),
+        ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT
+            + TypeSig.ARRAY).nested(), TypeSig.all),
(scan, conf, p, r) => new SparkPlanMeta[InMemoryTableScanExec](scan, conf, p, r) {
override def tagPlanForGpu(): Unit = {
if (!scan.relation.cacheBuilder.serializer
-              .isInstanceOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
+            .isInstanceOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
willNotWorkOnGpu("ParquetCachedBatchSerializer is not being used")
}
}
