Adds pre/post steps for merge and update aggregate #3417
Conversation
Signed-off-by: Alessandro Bellina <abellina@nvidia.com>
Co-authored-by: Nghia Truong <nghiatruong.vn@gmail.com>
build
@@ -995,7 +1019,7 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan](
   override def convertToGpu(): GpuExec = {
     GpuHashAggregateExec(
       requiredChildDistributionExpressions.map(_.map(_.convertToGpu())),
-      groupingExpressions.map(_.convertToGpu()),
+      groupingExpressions.map(_.convertToGpu()).asInstanceOf[Seq[NamedExpression]],
Because of type erasure (how Java implements generics), the sequence could contain things that are not NamedExpressions and the cast would still happily pass. I would prefer that we do something more like:

groupingExpressions.map(_.convertToGpu().asInstanceOf[NamedExpression])
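Here is a minimal, self-contained sketch of the difference (plain Scala with stand-in Expression types, not the plugin's classes):

```scala
import scala.util.Try

object ErasureDemo extends App {
  trait Expression
  trait NamedExpression extends Expression
  case class Alias(name: String) extends NamedExpression
  case class Literal(v: Int) extends Expression // not a NamedExpression

  val exprs: Seq[Expression] = Seq(Alias("a"), Literal(1))

  // Whole-sequence cast: element types are erased at runtime, so this
  // "succeeds" even though the Seq contains a Literal...
  val bad = exprs.asInstanceOf[Seq[NamedExpression]]
  println("Seq-level cast passed")

  // ...and the failure only surfaces later, at the first typed use of the
  // bad element, far from the cast site.
  println(Try { val e: NamedExpression = bad(1); e }) // Failure(ClassCastException)

  // Per-element cast, as suggested above: fails fast, right at the cast site.
  println(Try(exprs.map(_.asInstanceOf[NamedExpression]))) // Failure(ClassCastException)
}
```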
I see tests passing on Databricks 7.3 and 8.2, and locally. I found a leak and addressed it with a626e65.
build
It seems that we have enough material to merge this PR. @abellina, before merging this please remove the standard deviation stuff (…).
I started doing this in d4807c1, but on re-review I found I had comments around this, since some of the changes are only for the M2 aggregates. Do you want me to remove the comments as well, or rework them so they don't refer to M2? At this point, you could merge this with the M2 reference implementation and change what you need, or take the branch and build on it yourself, or merge with the comments that point to future aggregates if you are going to put your patch up soon.
build
Thanks, I just merged the code you removed and will continue working on it.
This PR is a prequel/continuation of #3373. It incorporates part of the work by @ttnghia; I added it to make sense of the extra processing applied to the aggregate function expressions, and to be able to test it all.
The refactor adds pre/post steps to updates and merges. An example of a "pre merge" step is casting, or creating a struct, as is needed by MERGE_M2. The "pre update" case is not overloaded, so it is the attribute reference as-is (a pass-through projection). The "post update" step can be used to cast after the update (as is done in GpuM2), and later, after the merge, the "post merge" step decomposes a struct and casts its fields, as expected by Spark. These steps allow one set of casts to be removed from the grouped aggregates in aggregate.scala (see the sketch below). I did not touch reduction aggregates in this PR; I can do that next, as it was not required for the stddev_pop work.
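To make the shape of the refactor concrete, here is a hedged sketch; the trait and method names (preUpdate, postUpdate, preMerge, postMerge) are illustrative stand-ins, not necessarily the names used in this PR:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}

// Each aggregate function can expose projections that run before/after the
// update and merge steps. The default is a pass-through projection of the
// aggregation buffer attributes.
trait GpuAggregateSteps {
  // The aggregation buffer attributes this function reads/writes.
  def aggBufferAttributes: Seq[AttributeReference]

  // "pre update" is not overloaded: the attribute references pass through.
  def preUpdate: Seq[Expression] = aggBufferAttributes

  // "post update" can cast the update results (as GpuM2 does).
  def postUpdate: Seq[Expression] = aggBufferAttributes

  // "pre merge" can cast columns or assemble a struct (as MERGE_M2 needs).
  def preMerge: Seq[Expression] = aggBufferAttributes

  // "post merge" can decompose the struct and cast its fields back to the
  // types Spark expects.
  def postMerge: Seq[Expression] = aggBufferAttributes
}
```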
An untested (other than some quick examples in a shell) implementation of stddev_pop is adapted from "Support stddev and variance aggregations families [databricks]" #3373 to demonstrate how the buffers are put together to produce the final result (sqrt(M2/n)).
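For intuition on how the (n, mean, M2) buffers combine into sqrt(M2/n), here is a small, self-contained sketch in plain Scala; the merge step uses the standard parallel-variance (Chan et al.) formula, and the names are hypothetical stand-ins rather than plugin code:

```scala
// Aggregation buffer for stddev/variance: count, running mean, and M2
// (the sum of squared deviations from the mean).
case class M2Buffer(n: Double, mean: Double, m2: Double)

object StddevPopDemo extends App {
  // Merge two partial buffers (what a "merge M2" aggregation computes).
  def merge(a: M2Buffer, b: M2Buffer): M2Buffer = {
    val n = a.n + b.n
    val delta = b.mean - a.mean
    M2Buffer(
      n,
      a.mean + delta * b.n / n,
      a.m2 + b.m2 + delta * delta * a.n * b.n / n)
  }

  // Updating with one value is merging with a singleton buffer.
  def update(buf: M2Buffer, x: Double): M2Buffer = merge(buf, M2Buffer(1, x, 0))

  // Final step: stddev_pop = sqrt(M2 / n).
  def finish(buf: M2Buffer): Double = math.sqrt(buf.m2 / buf.n)

  val xs = Seq(1.0, 2.0, 3.0, 4.0)
  val (left, right) = xs.splitAt(2) // pretend these are two partitions
  val partials = Seq(left, right).map(_.foldLeft(M2Buffer(0, 0, 0))(update))
  println(finish(partials.reduce(merge))) // ≈ 1.118, matches stddev_pop
}
```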
The code here really needs testing; as such, no GpuOverrides node is added in this PR. In other words, the code is there, but it is not being actively used; the two new projections for the pre/post steps are executed by existing aggregations. I tested the diffs with the integration tests locally and on Databricks 8.2. I have not run on Databricks 7.3 yet, but I wanted to get this up to get some 👀. Note that on Databricks 8.2 I am noticing other issues with the tests (as did @revans2), especially when we run with the parallel setting. Tests were failing due to some unrelated bugs, so I'll re-run the tests and comment here tomorrow; we'll need some follow-ups for that.