Fix hash_aggregate test failures due to TypedImperativeAggregate #3178
Conversation
build
The one-line change seems OK to me, but I'll defer to @abellina or @kuhushukla since they've done more with distinct handling.
The main comment I have is that I wonder if it's time for us to check whether we are indeed in Databricks. We do have the shims, and we should be able to check or add a check for "isDatabricks". The comments are helpful @sperlingxx.
// The 3rd stage of AggWithOneDistinct, which combines (partial) reduce-side
// nonDistinctAggExpressions and map-side distinctAggExpressions. For this stage, we need to
// switch the position of distinctAttributes and nonDistinctAttributes.
if (modeInfo.uniqueModes.length > 1 && aggregateExpressions.exists(_.isDistinct)) {
This makes me a little nervous that we are missing something. The Spark aggregation code does not look at distinct at all. It really just looks at the individual modes for each operation. Why is it that we need to do this to get the aggregation right, but the Spark code does not?
For AggWithOneDistinct, Spark plans a 4-stage stack of AggregateExec. Each stage owns a unique set of modes:
- Stage 1: Partial mode, only includes nonDistinct ones
- Stage 2: PartialMerge mode, only includes nonDistinct ones
- Stage 3: PartialMerge mode for nonDistinct ones and Partial mode for Distinct ones
- Stage 4: Final mode for both nonDistinct and Distinct AggregateExpressions
In contrast, the Databricks runtime seems to apply a quite different planning strategy for AggWithOneDistinct. From the dumped plan trees, we infer that the Databricks runtime only plans a 2-stage stack for AggWithOneDistinct: a Map-stage and a Reduce-stage.
- Map-stage: Partial mode, only includes nonDistinct ones
- Reduce-stage: Final mode for nonDistinct ones and Complete mode for Distinct ones
Apparently, the Map-stage corresponds to Stage 1 and Stage 2; the Reduce-stage corresponds to Stage 3 and Stage 4.
The condition here was used to match Stage 3, so it checked whether modeInfo contains both PartialMerge and Partial. Now we want to adapt to the Databricks runtime as well. The input projections of the Reduce-stage are exactly the same as those of Stage 3, though they contain different AggregateModes. Therefore, we change the condition here to match the Reduce-stage of the Databricks runtime as well as Stage 3 of Spark. In fact, the condition modeInfo.uniqueModes.length > 1 alone is enough to distinguish Stage 3 and the Reduce-stage from the other stages; the second condition aggregateExpressions.exists(_.isDistinct) is there to increase robustness in case of unknown special cases.
In addition, the input projections for Stage 1 fully fit the Map-stage of the Databricks runtime, so we don't need to change anything there.
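To make this concrete, here is a minimal, self-contained sketch (not the plugin's actual code; the stage data below is illustrative and only uses Spark's AggregateMode case objects as labels) showing that the relaxed check selects exactly Spark's Stage 3 and the Databricks Reduce-stage:

```scala
import org.apache.spark.sql.catalyst.expressions.aggregate._

object DistinctStageCheck {
  // Each stage is modeled as (mode, isDistinct) pairs, one per aggregate expression,
  // following the stage layouts described above (illustrative data, not plugin code).
  val sparkStages: Seq[(String, Seq[(AggregateMode, Boolean)])] = Seq(
    "Spark Stage 1"     -> Seq((Partial, false)),
    "Spark Stage 2"     -> Seq((PartialMerge, false)),
    "Spark Stage 3"     -> Seq((PartialMerge, false), (Partial, true)),
    "Spark Stage 4"     -> Seq((Final, false), (Final, true)))

  val databricksStages: Seq[(String, Seq[(AggregateMode, Boolean)])] = Seq(
    "Databricks Map"    -> Seq((Partial, false)),
    "Databricks Reduce" -> Seq((Final, false), (Complete, true)))

  // The relaxed check from this PR: more than one unique mode in the stage and at
  // least one distinct aggregate expression present.
  def needsSwappedProjection(aggs: Seq[(AggregateMode, Boolean)]): Boolean = {
    val uniqueModes = aggs.map(_._1).distinct
    uniqueModes.length > 1 && aggs.exists(_._2)
  }

  def main(args: Array[String]): Unit = {
    (sparkStages ++ databricksStages).foreach { case (name, aggs) =>
      // Prints true only for Spark Stage 3 and the Databricks Reduce-stage.
      println(s"$name -> ${needsSwappedProjection(aggs)}")
    }
  }
}
```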
Alternatively, a condition like (modeInfo.hasPartialMergeMode && modeInfo.hasPartialMode) || (modeInfo.hasFinalMode && modeInfo.hasCompleteMode) may look more straightforward.
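A small sketch of that explicit form, assuming a ModeInfo-like helper with boolean accessors for each mode (the accessor names come from the suggestion above, not from the plugin's actual class):

```scala
object ExplicitStageCheck {
  // Assumed helper shape, for illustration only.
  case class ModeInfo(
      hasPartialMode: Boolean,
      hasPartialMergeMode: Boolean,
      hasFinalMode: Boolean,
      hasCompleteMode: Boolean)

  def isMergeDistinctStage(modeInfo: ModeInfo): Boolean =
    // Spark Stage 3: PartialMerge (nonDistinct) + Partial (distinct)
    (modeInfo.hasPartialMergeMode && modeInfo.hasPartialMode) ||
      // Databricks Reduce-stage: Final (nonDistinct) + Complete (distinct)
      (modeInfo.hasFinalMode && modeInfo.hasCompleteMode)
}
```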
I am okay with this as a short term fix. The problem is not with your logic. The problem is that we keep hacking special cases onto something that should not need them.
Each aggregation comes with a mode. Each mode tells the aggregation what to do as a part of that stage. Originally the code assumed that there would only ever be one mode for all of the aggregations. I thought we had ripped that all out and each aggregation does the right thing.
To successfully do an aggregation there are a few steps used.
- Initial projection to get input data setup properly.
- Initial aggregation to produce partial result(s)
- Merge aggregation to combine partial results (This requires that the input schema and the output schema be the same)
- Final projection to take the combined partial results and produce a final result.
In general the steps follow the pattern 1, 2, 3*, 4, which means steps 1, 2, and 4 are required and step 3 can be done as often as needed because its input and output schemas are the same.
Step 4 requires that all of the data for a given group by key is on the same task and has been merged into a single output row. There are several different ways to do this, which is why we end up with several aggregation modes.
- Partial mode means that we do Step 1 and Step 2. Then we can do Step 3 as many times as needed depending on how we are doing memory management and how many batches are needed.
- PartialMerge mode means we can do Step 3 at least once and possibly more times depending on how we are doing memory management and how many batches are needed.
- Final mode means that we do the same steps as with PartialMerge but do Step 4 when we are done doing the partial merges.
- Complete mode is something only Databricks does, but it essentially means we do Step 1, Step 2, Step 3* (depending on memory management requirements), and Step 4 all at once.
I know that the details are a lot more complicated, but conceptually it should not be too difficult. I will file a follow on issue for us to figure this out.
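As a rough illustration of that mental model (the step names and the ModeSteps object are invented for this sketch, not the plugin's real types), each AggregateMode simply selects which of the four steps run:

```scala
import org.apache.spark.sql.catalyst.expressions.aggregate._

object ModeSteps {
  // Hypothetical names for the four steps described above.
  sealed trait Step
  case object InputProjection    extends Step // 1. project input data into shape
  case object InitialAggregation extends Step // 2. produce partial result(s)
  case object MergeAggregation   extends Step // 3. combine partial results (schema-preserving)
  case object FinalProjection    extends Step // 4. produce the final result

  // Which steps each AggregateMode implies, per the description above. Step 3 may
  // repeat as often as memory management and batching require.
  def stepsFor(mode: AggregateMode): Seq[Step] = mode match {
    case Partial      => Seq(InputProjection, InitialAggregation, MergeAggregation)
    case PartialMerge => Seq(MergeAggregation)
    case Final        => Seq(MergeAggregation, FinalProjection)
    case Complete     => Seq(InputProjection, InitialAggregation, MergeAggregation, FinalProjection)
  }
}
```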
I think the main ask is to not do this wholesale, assuming that a hash aggregate exec has a certain shape. If this function could decide per aggregate expression mode what the right binding should be, it would be more robust to new aggregate exec setups that mix and match modes (if we encounter new ones). That said, I don't think this is your fault, as the setupReferences code was built that way; it needs to be reworked separately.
I am removing my approval for now. Bobby's question is good.
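For reference, a very rough sketch of what deciding the binding per aggregate expression could look like, using only the public Catalyst aggregate API; the helper and its output shape are hypothetical and not how setupReferences currently works:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate._

object PerExpressionBinding {
  // Hypothetical per-expression binding: each aggregate expression picks its own
  // input shape from its mode, instead of assuming one shape for the whole exec.
  def bindPerExpression(aggExpressions: Seq[AggregateExpression]): Seq[Seq[Expression]] =
    aggExpressions.map { agg =>
      agg.mode match {
        case Partial | Complete   => agg.aggregateFunction.children                 // raw inputs
        case PartialMerge | Final => agg.aggregateFunction.inputAggBufferAttributes // partial agg buffers
      }
    }
}
```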
Signed-off-by: sperlingxx <lovedreamf@gmail.com>
build
This reverts commit 1658125.
Signed-off-by: sperlingxx <lovedreamf@gmail.com>
build
Fixes #3131
#3131 reported two kinds of bugs introduced by the PR for GpuCollectList as TypedImperativeAggregate.
One of them is merely a bug in the test code: a missing allow_non_gpu entry for Coalesce, which failed test_hash_groupby_collect_with_multi_distinct_fallback because the Coalesce expressions inserted by RewriteDistinctAggregates are not supported columnarly. This bug can be fixed by filling in the missing allow_non_gpu value (Coalesce).
The other one is about incorrect input projections in the final stage of aggregation. After dumping the plan tree of the original CPU plan, we found that the Databricks runtime overrides the planning of single distinct aggregates. In Apache Spark, the method AggUtils.planAggregateWithOneDistinct creates 4 stages of physical plans for each aggregation (logical plan). Meanwhile, the dumped result from the Databricks runtime only contains 2 stages (joint-Partial and joint-Final). For the query, the plan tree turned out to be:
According to the dumped plan tree, we can infer that stage 1 and stage 2 produced by Spark's AggregateWithOneDistinct are fused into a joint Partial stage in the Databricks runtime. Similarly, stage 3 and stage 4 are fused into a joint Merge stage.
In this PR, we tried to fix the bug without increasing the complexity of the Databricks shim. To adapt to the optimization of AggregateWithOneDistinct in the Databricks runtime, we changed the condition in boundInputReferences to extend its ability to handle stages that include both nonDistinctAggExpressions and DistinctAggExpressions. This approach has been manually verified against Databricks runtime 3.0.1.
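For illustration, a minimal sketch of the attribute reordering that the combined distinct/non-distinct stage needs, per the code comment quoted earlier; which attribute group comes first is an assumption here, not the plugin's actual ordering:

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute

object MergeDistinctProjection {
  // Illustrative only: for the stage that carries both non-distinct (merge-side) and
  // distinct (map-side) aggregates, swap the positions of the distinct and non-distinct
  // attribute groups relative to the other stages.
  def inputOrder(
      groupingAttributes: Seq[Attribute],
      nonDistinctAttributes: Seq[Attribute],
      distinctAttributes: Seq[Attribute]): Seq[Attribute] =
    groupingAttributes ++ nonDistinctAttributes ++ distinctAttributes
}
```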