Skip to content

Commit

Permalink
CollectSet supports structs (#3700)
Browse files Browse the repository at this point in the history
Signed-off-by: Bobby Wang <wbo4958@gmail.com>
  • Loading branch information
wbo4958 authored Oct 4, 2021
1 parent 8a3f8ac commit db0c2d5
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 7 deletions.
8 changes: 4 additions & 4 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -14348,7 +14348,7 @@ are limited.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -14367,7 +14367,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT</em></td>
<td><em>PS<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
Expand All @@ -14391,7 +14391,7 @@ are limited.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -14410,7 +14410,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT</em></td>
<td><em>PS<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
Expand Down
36 changes: 36 additions & 0 deletions integration_tests/src/main/python/hash_aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,18 @@ def test_hash_reduction_pivot_without_nans(data_gen, conf):
('a', RepeatSeqGen(LongGen(), length=20)),
('b', value_gen)] for value_gen in _repeat_agg_column_for_collect_list_op]

# Aggregation-column generators for collect_set: a struct of basic types and a
# struct nested one level deep, repeated so duplicate values actually occur.
_repeat_agg_column_for_collect_set_op = [
    RepeatSeqGen(gen, length=15)
    for gen in [all_basic_struct_gen,
                StructGen([['child0', all_basic_struct_gen]])]]

# Group-by key 'a' uses LongRangeGen — presumably one distinct key per row, so
# every collected set has a single element (TODO confirm generator semantics).
_gen_data_for_collect_set_op_for_unique_group_by_key = [
    [('a', LongRangeGen()), ('b', agg_gen)]
    for agg_gen in _repeat_agg_column_for_collect_set_op]

# Group-by key 'a' repeats over 20 longs, so sets can hold multiple elements.
_gen_data_for_collect_set_op = [
    [('a', RepeatSeqGen(LongGen(), length=20)), ('b', agg_gen)]
    for agg_gen in _repeat_agg_column_for_collect_set_op]

# to avoid ordering issues with collect_list we do it all in a single task
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', _gen_data_for_collect_list_op, ids=idfn)
Expand All @@ -456,6 +468,30 @@ def test_hash_groupby_collect_set(data_gen):
.groupby('a')
.agg(f.sort_array(f.collect_set('b')), f.count('b')))

@approximate_float
@ignore_order(local=True)
@incompat
@pytest.mark.parametrize('data_gen', _gen_data_for_collect_set_op, ids=idfn)
@pytest.mark.xfail(reason="the result order from collect_set cannot be guaranteed to match"
    " between CPU and GPU. Enable this once SortArray supports nested types."
    " See https://github.com/NVIDIA/spark-rapids/issues/3715")
def test_hash_groupby_collect_set_on_nested_type(data_gen):
    # collect_set over nested (struct) values with repeating group-by keys.
    # sort_array is meant to normalize element order for CPU/GPU comparison, but
    # SortArray does not yet handle nested types, hence the xfail above.
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: gen_df(spark, data_gen, length=100)
            .groupby('a')
            .agg(f.sort_array(f.collect_set('b')), f.count('b')))

# After https://github.com/NVIDIA/spark-rapids/issues/3715 is fixed, we should remove this test case
@approximate_float
@ignore_order(local=True)
@incompat
@pytest.mark.parametrize('data_gen', _gen_data_for_collect_set_op_for_unique_group_by_key, ids=idfn)
def test_hash_groupby_collect_set_on_nested_type_for_unique_group_by(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=100)
.groupby('a')
.agg(f.collect_set('b')))

@approximate_float
@ignore_order(local=True)
@incompat
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3081,10 +3081,11 @@ object GpuOverrides extends Logging {
// GpuCollectSet now supports StructType in addition to the basic types; deeper
// nested types (ARRAY, MAP) remain unsupported by drop_list_duplicates.
ExprChecks.aggNotReduction(
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL),
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL +
TypeSig.STRUCT),
TypeSig.ARRAY.nested(TypeSig.all),
Seq(ParamCheck("input", TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL,
TypeSig.all))),
Seq(ParamCheck("input", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL +
TypeSig.STRUCT).nested(), TypeSig.all))),
(c, conf, p, r) => new TypedImperativeAggExprMeta[CollectSet](c, conf, p, r) {
override def convertToGpu(childExprs: Seq[Expression]): GpuExpression =
GpuCollectSet(childExprs.head, c.mutableAggBufferOffset, c.inputAggBufferOffset)
Expand Down

0 comments on commit db0c2d5

Please sign in to comment.