Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MapType in joins #3011

Merged
merged 5 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,9 @@ Accelerator supports are described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down Expand Up @@ -582,7 +582,7 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
Expand Down Expand Up @@ -628,7 +628,7 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
Expand Down Expand Up @@ -697,7 +697,7 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
Expand All @@ -720,7 +720,7 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
Expand Down
2 changes: 2 additions & 0 deletions integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,8 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):

single_level_array_gens_no_decimal = [ArrayGen(sub_gen) for sub_gen in all_basic_gens + [null_gen]]

# Generator list for map<string,string> columns: keys are non-null strings matching
# 'key_[0-9]', values are nullable strings. Used by the map-type join tests.
map_string_string_gen = [MapGen(StringGen(pattern='key_[0-9]', nullable=False), StringGen())]

# Be careful to not make these too large or data generation takes forever
# This is only a few nested array gens, because nesting can be very deep
nested_array_gens_sample = [ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10),
Expand Down
33 changes: 33 additions & 0 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ def do_join(spark):
conf.update(_sortmerge_join_conf)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=conf)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', map_string_string_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_sortmerge_join_map(data_gen, join_type, batch_size):
    """Check GPU vs CPU results for sort-merge joins on tables with a map<string,string> column."""
    def do_join(spark):
        left, right = create_nested_df(spark, short_gen, data_gen, 500, 500)
        return left.join(right, left.key == right.r_key, join_type)
    # Start from the shared sort-merge conf and set the batch size afterwards so the
    # parametrized value always wins; updating in the other order would let a base-conf
    # key silently override it and defeat the multi-stream-batch coverage.
    conf = dict(_sortmerge_join_conf)
    conf['spark.rapids.sql.batchSizeBytes'] = batch_size
    assert_gpu_and_cpu_are_equal_collect(do_join, conf=conf)

# For floating point values the normalization is done using a higher order function. We could probably work around this
# for now it falls back to the CPU
@allow_non_gpu('SortMergeJoinExec', 'SortExec', 'KnownFloatingPointNormalized', 'ArrayTransform', 'LambdaFunction',
Expand Down Expand Up @@ -154,6 +166,16 @@ def do_join(spark):
return left.join(right, left.key == right.r_key, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=_hash_join_conf)

@validate_execs_in_gpu_plan('GpuShuffledHashJoinExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', map_string_string_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
def test_hash_join_map(data_gen, join_type):
    """Shuffled hash join over tables carrying a map<string,string> column.

    Joins a 50-row table against a 500-row table on key == r_key and requires
    the plan to contain GpuShuffledHashJoinExec, comparing GPU and CPU output.
    """
    def do_join(spark):
        lhs, rhs = create_nested_df(spark, short_gen, data_gen, 50, 500)
        return lhs.join(rhs, lhs.key == rhs.r_key, join_type)
    assert_gpu_and_cpu_are_equal_collect(do_join, conf=_hash_join_conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
Expand All @@ -178,6 +200,17 @@ def do_join(spark):
return left.join(broadcast(right), left.key == right.r_key, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', map_string_string_gen, ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
# can handle what spark is doing
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
def test_broadcast_join_right_table_map(data_gen, join_type):
    """Join with a broadcast hint on the right-hand table when a map column is present."""
    def do_join(spark):
        lhs, rhs = create_nested_df(spark, short_gen, data_gen, 500, 500)
        return lhs.join(broadcast(rhs), lhs.key == rhs.r_key, join_type)
    assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [all_basic_struct_gen], ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,20 +231,20 @@ abstract class SparkBaseShims extends SparkShims {
GpuOverrides.exec[SortMergeJoinExec](
"Sort merge join, replacing with shuffled hash join",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
jlowe marked this conversation as resolved.
Show resolved Hide resolved
(join, conf, p, r) => new GpuSortMergeJoinMeta(join, conf, p, r)),
GpuOverrides.exec[BroadcastHashJoinExec](
"Implementation of join using broadcast data",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(join, conf, p, r) => new GpuBroadcastHashJoinMeta(join, conf, p, r)),
GpuOverrides.exec[ShuffledHashJoinExec](
"Implementation of join using hashed shuffled data",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(join, conf, p, r) => new GpuShuffledHashJoinMeta(join, conf, p, r)),
GpuOverrides.exec[ArrowEvalPythonExec](
"The backend of the Scalar Pandas UDFs. Accelerates the data transfer between the" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,20 +383,20 @@ class Spark311Shims extends Spark301Shims {
GpuOverrides.exec[SortMergeJoinExec](
"Sort merge join, replacing with shuffled hash join",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(join, conf, p, r) => new GpuSortMergeJoinMeta(join, conf, p, r)),
GpuOverrides.exec[BroadcastHashJoinExec](
"Implementation of join using broadcast data",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(join, conf, p, r) => new GpuBroadcastHashJoinMeta(join, conf, p, r)),
GpuOverrides.exec[ShuffledHashJoinExec](
"Implementation of join using hashed shuffled data",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
), TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(join, conf, p, r) => new GpuShuffledHashJoinMeta(join, conf, p, r))
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2971,8 +2971,8 @@ object GpuOverrides {
exec[BroadcastExchangeExec](
"The backend for broadcast exchange of data",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL)
, TypeSig.all),
TypeSig.STRUCT + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL +
TypeSig.DECIMAL), TypeSig.all),
(exchange, conf, p, r) => new GpuBroadcastMeta(exchange, conf, p, r)),
exec[BroadcastNestedLoopJoinExec](
"Implementation of join using brute force",
Expand Down Expand Up @@ -3018,8 +3018,8 @@ object GpuOverrides {
"The backend for the sort operator",
// The SortOrder TypeSig will govern what types can actually be used as sorting key data type.
// The types below are allowed as inputs and outputs.
ExecChecks(pluginSupportedOrderableSig + (TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ExecChecks(pluginSupportedOrderableSig + (TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.MAP).nested(), TypeSig.all),
(sort, conf, p, r) => new GpuSortMeta(sort, conf, p, r)),
exec[ExpandExec](
"The backend for the expand operator",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids
import ai.rapids.cudf.{ColumnVector, ColumnView, DeviceMemoryBuffer, DType, GatherMap, NvtxColor, NvtxRange, OrderByArg, Scalar, Table}
import com.nvidia.spark.rapids.RapidsBuffer.SpillCallback

import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, DateType, DecimalType, IntegerType, LongType, NullType, NumericType, StringType, StructType, TimestampType}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, DateType, DecimalType, IntegerType, LongType, MapType, NullType, NumericType, StringType, StructType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
Expand Down Expand Up @@ -465,7 +465,7 @@ object JoinGathererImpl {
}
case _: NumericType | DateType | TimestampType | BooleanType | NullType =>
Some(GpuColumnVector.getNonNestedRapidsType(dt).getSizeInBytes * 8 + 1)
case StringType | BinaryType | ArrayType(_, _) if nullValueCalc =>
case StringType | BinaryType | ArrayType(_, _) | MapType(_, _, _) if nullValueCalc =>
// Single offset value and a validity value
Some((DType.INT32.getSizeInBytes * 8) + 1)
case x if nullValueCalc =>
Expand Down