Skip to content

Commit

Permalink
Added support for Struct[Array]
Browse files — browse the repository at this point in the history
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
  • (branch information placeholder — not rendered in this page capture)
razajafri committed Sep 23, 2021
1 parent bfc1e2c commit 6b64965
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 17 deletions.
18 changes: 16 additions & 2 deletions integration_tests/src/main/python/cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ def op_df(spark, length=2048, seed=0):
assert_gpu_and_cpu_are_equal_collect(op_df, conf = conf)

@pytest.mark.parametrize('data_gen', [all_basic_struct_gen, StructGen([['child0', StructGen([['child1', byte_gen]])]]),
ArrayGen(IntegerGen()),
ArrayGen(
StructGen([['child0', StringGen()],
['child1',
StructGen([['child0', IntegerGen()]])]])),
decimal_struct_gen] + all_gen, ids=idfn)
@pytest.mark.parametrize('enable_vectorized_conf', enable_vectorized_confs, ids=idfn)
@allow_non_gpu('CollectLimitExec')
Expand Down Expand Up @@ -181,14 +186,19 @@ def n_fold(spark):

# This test doesn't allow negative scale for Decimals as ` df.write.mode('overwrite').parquet(data_path)`
# writes parquet which doesn't allow negative decimals
@pytest.mark.parametrize('data_gen', [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(),
@pytest.mark.parametrize('data_gen', [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(), ArrayGen(IntegerGen()),
ArrayGen(
StructGen([['child0', StringGen()],
['child1',
StructGen([['child0', IntegerGen()]])]])),
pytest.param(FloatGen(special_cases=[FLOAT_MIN, FLOAT_MAX, 0.0, 1.0, -1.0]), marks=[incompat]),
pytest.param(DoubleGen(special_cases=double_special_cases), marks=[incompat]),
BooleanGen(), DateGen(), TimestampGen(), decimal_gen_default, decimal_gen_scale_precision,
decimal_gen_same_scale_precision, decimal_gen_64bit], ids=idfn)
@pytest.mark.parametrize('ts_write', ['TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS'])
@pytest.mark.parametrize('enable_vectorized', ['true', 'false'], ids=idfn)
@ignore_order
@allow_non_gpu("SortExec", "ShuffleExchangeExec", "RangePartitioning")
def test_cache_columnar(spark_tmp_path, data_gen, enable_vectorized, ts_write):
data_path_gpu = spark_tmp_path + '/PARQUET_DATA'
def read_parquet_cached(data_path):
Expand All @@ -208,7 +218,11 @@ def write_read_parquet_cached(spark):
assert_gpu_and_cpu_are_equal_collect(read_parquet_cached(data_path_gpu), conf)

@pytest.mark.parametrize('data_gen', [all_basic_struct_gen, StructGen([['child0', StructGen([['child1', byte_gen]])]]),
decimal_struct_gen]+ all_gen, ids=idfn)
decimal_struct_gen, ArrayGen(IntegerGen()),
ArrayGen(
StructGen([['child0', StringGen()],
['child1',
StructGen([['child0', IntegerGen()]])]]))]+ all_gen, ids=idfn)
@pytest.mark.parametrize('enable_vectorized_conf', enable_vectorized_confs, ids=idfn)
def test_cache_cpu_gpu_mixed(data_gen, enable_vectorized_conf):
def func(spark):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ class Spark311CDHShims extends SparkBaseShims {
super.getExecs ++ Seq(
GpuOverrides.exec[InMemoryTableScanExec](
"Implementation of InMemoryTableScanExec to use GPU accelerated Caching",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested()
.withPsNote(TypeEnum.DECIMAL,
"Negative scales aren't supported at the moment even with " +
"spark.sql.legacy.allowNegativeScaleOfDecimal set to true. This is because Parquet " +
"doesn't support negative scale for decimal values"),
ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT
+ TypeSig.ARRAY).nested().withPsNote(TypeEnum.DECIMAL,
"Negative scales aren't supported at the moment even with " +
"spark.sql.legacy.allowNegativeScaleOfDecimal set to true. " +
"This is because Parquet doesn't support negative scale for decimal values"),
TypeSig.all),
(scan, conf, p, r) => new SparkPlanMeta[InMemoryTableScanExec](scan, conf, p, r) {
override def tagPlanForGpu(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,8 +528,8 @@ class Spark320Shims extends Spark32XShims {
}),
GpuOverrides.exec[InMemoryTableScanExec](
"Implementation of InMemoryTableScanExec to use GPU accelerated Caching",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested(),
TypeSig.all),
ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT
+ TypeSig.ARRAY).nested(), TypeSig.all),
(scan, conf, p, r) => new SparkPlanMeta[InMemoryTableScanExec](scan, conf, p, r) {
override def tagPlanForGpu(): Unit = {
if (!scan.relation.cacheBuilder.serializer.isInstanceOf[ParquetCachedBatchSerializer]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,7 @@ class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer with Arm {

def isSupportedByCudf(dataType: DataType): Boolean = {
dataType match {
// TODO: when arrays are supported for cudf writes add it here.
// https://github.com/NVIDIA/spark-rapids/issues/2054
case a: ArrayType => isSupportedByCudf(a.elementType)
case s: StructType => s.forall(field => isSupportedByCudf(field.dataType))
case _ => GpuColumnVector.isNonNestedSupportedType(dataType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,12 +464,12 @@ abstract class SparkBaseShims extends Spark30XShims {
}),
GpuOverrides.exec[InMemoryTableScanExec](
"Implementation of InMemoryTableScanExec to use GPU accelerated Caching",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested(),
TypeSig.all),
ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT
+ TypeSig.ARRAY).nested(), TypeSig.all),
(scan, conf, p, r) => new SparkPlanMeta[InMemoryTableScanExec](scan, conf, p, r) {
override def tagPlanForGpu(): Unit = {
if (!scan.relation.cacheBuilder.serializer
.isInstanceOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
.isInstanceOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
willNotWorkOnGpu("ParquetCachedBatchSerializer is not being used")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,12 +437,12 @@ abstract class SparkBaseShims extends Spark30XShims {
}),
GpuOverrides.exec[InMemoryTableScanExec](
"Implementation of InMemoryTableScanExec to use GPU accelerated Caching",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested(),
TypeSig.all),
ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT
+ TypeSig.ARRAY).nested(), TypeSig.all),
(scan, conf, p, r) => new SparkPlanMeta[InMemoryTableScanExec](scan, conf, p, r) {
override def tagPlanForGpu(): Unit = {
if (!scan.relation.cacheBuilder.serializer
.isInstanceOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
.isInstanceOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
willNotWorkOnGpu("ParquetCachedBatchSerializer is not being used")
}
}
Expand Down

0 comments on commit 6b64965

Please sign in to comment.