Support HashAggregate on struct and nested struct #3354

Merged 3 commits on Sep 1, 2021
6 changes: 3 additions & 3 deletions docs/supported_ops.md
@@ -488,7 +488,7 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><em>PS<br/>not allowed for grouping expressions;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions if containing Array or Map as child;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
@@ -536,7 +536,7 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><em>PS<br/>not allowed for grouping expressions;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions if containing Array or Map as child;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
@@ -560,7 +560,7 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><em>PS<br/>not allowed for grouping expressions;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions if containing Array or Map as child;<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
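For context, a minimal PySpark sketch of the behavior these documentation notes describe: grouping by a struct key made only of scalar children is the newly supported case, while a struct key containing an Array or Map child still falls back. The session setup and column names are illustrative assumptions, not taken from this PR.

# Illustrative only (not part of this PR).
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName("struct-groupby-sketch").getOrCreate()

df = (spark.createDataFrame([(1, "a", 10), (1, "a", 20), (2, "b", 30)], ["k1", "k2", "v"])
      .withColumn("key", F.struct("k1", "k2")))   # struct key with INT/STRING children only

# With the RAPIDS plugin enabled, this aggregate is expected to stay on the GPU;
# a struct key containing an ARRAY or MAP child would instead fall back to the CPU.
df.groupBy("key").agg(F.sum("v").alias("total")).show()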
74 changes: 74 additions & 0 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -86,6 +86,35 @@
('a', RepeatSeqGen(StringGen(pattern='[0-9]{0,30}'), length= 20)),
('b', IntegerGen()),
('c', NullGen())]
# grouping single-level structs
_grpkey_structs_with_non_nested_children = [
('a', RepeatSeqGen(StructGen([
['aa', IntegerGen()],
['ab', StringGen(pattern='[0-9]{0,30}')],
['ac', DecimalGen()]]), length=20)),
('b', IntegerGen()),
('c', NullGen())]
# grouping multiple-level structs
_grpkey_nested_structs = [
('a', RepeatSeqGen(StructGen([
['aa', IntegerGen()],
['ab', StringGen(pattern='[0-9]{0,30}')],
['ac', StructGen([['aca', LongGen()],
['acb', BooleanGen()],
['acc', StructGen([['acca', StringGen()]])]])]]),
length=20)),
('b', IntegerGen()),
('c', NullGen())]
# grouping multiple-level structs with arrays in children
_grpkey_nested_structs_with_array_child = [
('a', RepeatSeqGen(StructGen([
['aa', IntegerGen()],
['ab', ArrayGen(IntegerGen())],
['ac', ArrayGen(StructGen([['aca', LongGen()]]))]]),
length=20)),
('b', IntegerGen()),
('c', NullGen())]

# grouping NullType
_grpkey_nulls = [
('a', NullGen()),
@@ -687,6 +716,51 @@ def test_hash_agg_with_nan_keys(data_gen, parameterless):
'from hash_agg_table group by a',
_no_nans_float_conf)

@ignore_order
@pytest.mark.parametrize('data_gen', [_grpkey_structs_with_non_nested_children,
_grpkey_nested_structs], ids=idfn)
def test_hash_agg_with_struct_keys(data_gen):
conf = _no_nans_float_conf.copy()
conf.update({'spark.sql.legacy.allowParameterlessCount': 'true'})
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, data_gen, length=1024),
"hash_agg_table",
'select a, '
'count(*) as count_stars, '
'count() as count_parameterless, '
'count(b) as count_bees, '
'sum(b) as sum_of_bees, '
'max(c) as max_seas, '
'min(c) as min_seas, '
'count(distinct c) as count_distinct_cees, '
'avg(c) as average_seas '
'from hash_agg_table group by a',
conf)

@ignore_order(local=True)
@allow_non_gpu('HashAggregateExec', 'Avg', 'Count', 'Max', 'Min', 'Sum', 'Average',
'Cast', 'Literal', 'Alias', 'AggregateExpression',
'ShuffleExchangeExec', 'HashPartitioning')
@pytest.mark.parametrize('data_gen', [_grpkey_nested_structs_with_array_child], ids=idfn)
def test_hash_agg_with_struct_of_array_fallback(data_gen):
conf = _no_nans_float_conf.copy()
conf.update({'spark.sql.legacy.allowParameterlessCount': 'true'})
assert_cpu_and_gpu_are_equal_sql_with_capture(
lambda spark : gen_df(spark, data_gen, length=100),
'select a, '
'count(*) as count_stars, '
'count() as count_parameterless, '
'count(b) as count_bees, '
'sum(b) as sum_of_bees, '
'max(c) as max_seas, '
'min(c) as min_seas, '
'avg(c) as average_seas '
'from hash_agg_table group by a',
"hash_agg_table",
exist_classes='HashAggregateExec',
non_exist_classes='GpuHashAggregateExec',
conf=conf)


@approximate_float
@ignore_order
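A hedged sketch of how the fallback exercised by test_hash_agg_with_struct_of_array_fallback could be checked interactively, outside the test harness. It assumes a session with the RAPIDS plugin enabled and made-up column names; GPU-replaced plan nodes carry a Gpu prefix in explain() output.

# Illustrative only (not part of this PR): a struct grouping key with an ARRAY child
# should leave a CPU HashAggregate in the plan rather than a GpuHashAggregate.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

df = (spark.createDataFrame([(1, 10), (1, 20), (2, 30)], ["k", "v"])
      .withColumn("key", F.struct(F.array("k").alias("ks"))))  # struct with an ARRAY child

df.groupBy("key").agg(F.sum("v").alias("total")).explain()
# Expect HashAggregate (CPU) rather than GpuHashAggregate in the printed plan.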
@@ -3227,7 +3227,8 @@ object GpuOverrides {
.nested()
.withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions")
.withPsNote(TypeEnum.MAP, "not allowed for grouping expressions")
.withPsNote(TypeEnum.STRUCT, "not allowed for grouping expressions"),
.withPsNote(TypeEnum.STRUCT,
"not allowed for grouping expressions if containing Array or Map as child"),
TypeSig.all),
(agg, conf, p, r) => new GpuHashAggregateMeta(agg, conf, p, r)),
exec[ObjectHashAggregateExec](
Expand All @@ -3238,7 +3239,8 @@ object GpuOverrides {
.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64)
.withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions")
.withPsNote(TypeEnum.MAP, "not allowed for grouping expressions")
.withPsNote(TypeEnum.STRUCT, "not allowed for grouping expressions"),
.withPsNote(TypeEnum.STRUCT,
"not allowed for grouping expressions if containing Array or Map as child"),
TypeSig.all),
(agg, conf, p, r) => new GpuObjectHashAggregateExecMeta(agg, conf, p, r)),
exec[SortAggregateExec](
Expand All @@ -3249,7 +3251,8 @@ object GpuOverrides {
.nested()
.withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions")
.withPsNote(TypeEnum.MAP, "not allowed for grouping expressions")
.withPsNote(TypeEnum.STRUCT, "not allowed for grouping expressions"),
.withPsNote(TypeEnum.STRUCT,
"not allowed for grouping expressions if containing Array or Map as child"),
TypeSig.all),
(agg, conf, p, r) => new GpuSortAggregateExecMeta(agg, conf, p, r)),
exec[SortExec](
@@ -896,16 +896,17 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan](
resultExpressions

override def tagPlanForGpu(): Unit = {
agg.groupingExpressions
.find(_.dataType match {
case _@(ArrayType(_, _) | MapType(_, _, _)) | _@StructType(_) => true
case _ => false
})
.foreach(_ =>
willNotWorkOnGpu("Nested types in grouping expressions are not supported"))
if (agg.resultExpressions.isEmpty) {
willNotWorkOnGpu("result expressions is empty")
}
// We don't support Arrays and Maps as GroupBy keys yet, even when they are nested in
// Structs. So, we need to run a recursive type check on the structs.
val allTypesAreSupported = agg.groupingExpressions.forall(e =>
!TrampolineUtil.dataTypeExistsRecursively(e.dataType,
dt => dt.isInstanceOf[ArrayType] || dt.isInstanceOf[MapType]))
if (!allTypesAreSupported) {
willNotWorkOnGpu("ArrayTypes or MayTypes in grouping expressions are not supported")
}

tagForReplaceMode()

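To make the intent of the new check concrete, here is a small Python analogue of the recursive type test; the real implementation uses TrampolineUtil.dataTypeExistsRecursively in Scala, and the function name and structure below are mine, not the plugin's.

# Illustrative analogue (not the plugin code): a grouping expression is rejected if its
# type contains an ArrayType or MapType anywhere in the tree, including nested inside
# StructTypes.
from pyspark.sql.types import (ArrayType, DataType, IntegerType, MapType, StringType,
                               StructField, StructType)

def contains_array_or_map(dt: DataType) -> bool:
    if isinstance(dt, (ArrayType, MapType)):
        return True
    if isinstance(dt, StructType):
        return any(contains_array_or_map(f.dataType) for f in dt.fields)
    return False

# A struct<aa:int, ab:string> key is acceptable; struct<aa:int, ab:array<int>> is not.
ok_key = StructType([StructField("aa", IntegerType()), StructField("ab", StringType())])
bad_key = StructType([StructField("aa", IntegerType()),
                      StructField("ab", ArrayType(IntegerType()))])
assert not contains_array_or_map(ok_key)
assert contains_array_or_map(bad_key)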