Skip to content

Commit

Permalink
Support MapType in joins (#3011)
Browse files Browse the repository at this point in the history
* Update joins to support map types for non-key columns

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Revert temp debug logging

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Remove comment

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* scalastyle and docs

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Add psnote explaining that ARRAY, MAP, STRUCT cannot be used as join keys

Signed-off-by: Andy Grove <andygrove@nvidia.com>
  • Loading branch information
andygrove authored Jul 28, 2021
1 parent f9895b3 commit 471df42
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 31 deletions.
26 changes: 13 additions & 13 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,9 @@ Accelerator supports are described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down Expand Up @@ -582,7 +582,7 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
Expand Down Expand Up @@ -627,9 +627,9 @@ Accelerator supports are described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (Cannot be used as join key; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (Cannot be used as join key; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (Cannot be used as join key; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down Expand Up @@ -696,9 +696,9 @@ Accelerator supports are described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (Cannot be used as join key; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (Cannot be used as join key; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (Cannot be used as join key; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -719,9 +719,9 @@ Accelerator supports are described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (Cannot be used as join key; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (Cannot be used as join key; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (Cannot be used as join key; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down
2 changes: 2 additions & 0 deletions integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,8 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):

single_level_array_gens_no_decimal = [ArrayGen(sub_gen) for sub_gen in all_basic_gens + [null_gen]]

map_string_string_gen = [MapGen(StringGen(pattern='key_[0-9]', nullable=False), StringGen())]

# Be careful to not make these too large of data generation takes for ever
# This is only a few nested array gens, because nesting can be very deep
nested_array_gens_sample = [ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10),
Expand Down
33 changes: 33 additions & 0 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ def do_join(spark):
conf.update(_sortmerge_join_conf)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=conf)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', map_string_string_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_sortmerge_join_map(data_gen, join_type, batch_size):
def do_join(spark):
left, right = create_nested_df(spark, short_gen, data_gen, 500, 500)
return left.join(right, left.key == right.r_key, join_type)
conf = {'spark.rapids.sql.batchSizeBytes': batch_size}
conf.update(_sortmerge_join_conf)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=conf)

# For floating point values the normalization is done using a higher order function. We could probably work around this
# for now it falls back to the CPU
@allow_non_gpu('SortMergeJoinExec', 'SortExec', 'KnownFloatingPointNormalized', 'ArrayTransform', 'LambdaFunction',
Expand Down Expand Up @@ -154,6 +166,16 @@ def do_join(spark):
return left.join(right, left.key == right.r_key, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=_hash_join_conf)

@validate_execs_in_gpu_plan('GpuShuffledHashJoinExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', map_string_string_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
def test_hash_join_map(data_gen, join_type):
def do_join(spark):
left, right = create_nested_df(spark, short_gen, data_gen, 50, 500)
return left.join(right, left.key == right.r_key, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=_hash_join_conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
Expand All @@ -178,6 +200,17 @@ def do_join(spark):
return left.join(broadcast(right), left.key == right.r_key, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', map_string_string_gen, ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
# can handle what spark is doing
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
def test_broadcast_join_right_table_map(data_gen, join_type):
def do_join(spark):
left, right = create_nested_df(spark, short_gen, data_gen, 500, 500)
return left.join(broadcast(right), left.key == right.r_key, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [all_basic_struct_gen], ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,20 +231,32 @@ abstract class SparkBaseShims extends SparkShims {
GpuOverrides.exec[SortMergeJoinExec](
"Sort merge join, replacing with shuffled hash join",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP)
.withPsNote(TypeEnum.ARRAY, "Cannot be used as join key")
.withPsNote(TypeEnum.STRUCT, "Cannot be used as join key")
.withPsNote(TypeEnum.MAP, "Cannot be used as join key")
.nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(join, conf, p, r) => new GpuSortMergeJoinMeta(join, conf, p, r)),
GpuOverrides.exec[BroadcastHashJoinExec](
"Implementation of join using broadcast data",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP)
.withPsNote(TypeEnum.ARRAY, "Cannot be used as join key")
.withPsNote(TypeEnum.STRUCT, "Cannot be used as join key")
.withPsNote(TypeEnum.MAP, "Cannot be used as join key")
.nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(join, conf, p, r) => new GpuBroadcastHashJoinMeta(join, conf, p, r)),
GpuOverrides.exec[ShuffledHashJoinExec](
"Implementation of join using hashed shuffled data",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP)
.withPsNote(TypeEnum.ARRAY, "Cannot be used as join key")
.withPsNote(TypeEnum.STRUCT, "Cannot be used as join key")
.withPsNote(TypeEnum.MAP, "Cannot be used as join key")
.nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(join, conf, p, r) => new GpuShuffledHashJoinMeta(join, conf, p, r)),
GpuOverrides.exec[ArrowEvalPythonExec](
"The backend of the Scalar Pandas UDFs. Accelerates the data transfer between the" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,20 +383,33 @@ class Spark311Shims extends Spark301Shims {
GpuOverrides.exec[SortMergeJoinExec](
"Sort merge join, replacing with shuffled hash join",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP)
.withPsNote(TypeEnum.ARRAY, "Cannot be used as join key")
.withPsNote(TypeEnum.STRUCT, "Cannot be used as join key")
.withPsNote(TypeEnum.MAP, "Cannot be used as join key")
.nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(join, conf, p, r) => new GpuSortMergeJoinMeta(join, conf, p, r)),
GpuOverrides.exec[BroadcastHashJoinExec](
"Implementation of join using broadcast data",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP)
.withPsNote(TypeEnum.ARRAY, "Cannot be used as join key")
.withPsNote(TypeEnum.STRUCT, "Cannot be used as join key")
.withPsNote(TypeEnum.MAP, "Cannot be used as join key")
.nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL)
, TypeSig.all),
(join, conf, p, r) => new GpuBroadcastHashJoinMeta(join, conf, p, r)),
GpuOverrides.exec[ShuffledHashJoinExec](
"Implementation of join using hashed shuffled data",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP)
.withPsNote(TypeEnum.ARRAY, "Cannot be used as join key")
.withPsNote(TypeEnum.STRUCT, "Cannot be used as join key")
.withPsNote(TypeEnum.MAP, "Cannot be used as join key")
.nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(join, conf, p, r) => new GpuShuffledHashJoinMeta(join, conf, p, r))
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2991,8 +2991,8 @@ object GpuOverrides {
exec[BroadcastExchangeExec](
"The backend for broadcast exchange of data",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL)
, TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(exchange, conf, p, r) => new GpuBroadcastMeta(exchange, conf, p, r)),
exec[BroadcastNestedLoopJoinExec](
"Implementation of join using brute force",
Expand Down Expand Up @@ -3042,8 +3042,8 @@ object GpuOverrides {
"The backend for the sort operator",
// The SortOrder TypeSig will govern what types can actually be used as sorting key data type.
// The types below are allowed as inputs and outputs.
ExecChecks(pluginSupportedOrderableSig + (TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ExecChecks(pluginSupportedOrderableSig + (TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.MAP).nested(), TypeSig.all),
(sort, conf, p, r) => new GpuSortMeta(sort, conf, p, r)),
exec[ExpandExec](
"The backend for the expand operator",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids
import ai.rapids.cudf.{ColumnVector, ColumnView, DeviceMemoryBuffer, DType, GatherMap, NvtxColor, NvtxRange, OrderByArg, Scalar, Table}
import com.nvidia.spark.rapids.RapidsBuffer.SpillCallback

import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, DateType, DecimalType, IntegerType, LongType, NullType, NumericType, StringType, StructType, TimestampType}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, DateType, DecimalType, IntegerType, LongType, MapType, NullType, NumericType, StringType, StructType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
Expand Down Expand Up @@ -465,7 +465,7 @@ object JoinGathererImpl {
}
case _: NumericType | DateType | TimestampType | BooleanType | NullType =>
Some(GpuColumnVector.getNonNestedRapidsType(dt).getSizeInBytes * 8 + 1)
case StringType | BinaryType | ArrayType(_, _) if nullValueCalc =>
case StringType | BinaryType | ArrayType(_, _) | MapType(_, _, _) if nullValueCalc =>
// Single offset value and a validity value
Some((DType.INT32.getSizeInBytes * 8) + 1)
case x if nullValueCalc =>
Expand Down

0 comments on commit 471df42

Please sign in to comment.