Minor: add partial assertion for skip aggregation probe (#12640)
* add partial assertion for skip aggr probe and improve comments.

* fix fmt.

* use pattern match for aggr mode to improve readability.

* only check `should_skip_aggregation` in partial aggr.

* turn some condition checks into assertions.

* distinguish the partial and terminal branches more clearly (a minimal sketch of the resulting control flow is included below).
Rachelint authored Sep 30, 2024
1 parent da70fab commit f1aa27f
Showing 1 changed file with 55 additions and 8 deletions.
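
For orientation, here is a minimal sketch of the control-flow shape this commit moves to. It is not DataFusion's implementation: only `AggregateMode`, `ExecutionState`, and the method names mirror the diff below; `StreamSketch`, `probe_says_skip`, and the stubbed bodies are illustrative assumptions.

// Sketch only: spilling, emission, timers, and the real probe are omitted.
#[derive(Debug, PartialEq)]
enum AggregateMode {
    Partial,
    Final,
    FinalPartitioned,
    Single,
    SinglePartitioned,
}

enum ExecutionState {
    ReadingInput,
    SkippingAggregation,
    Done,
}

struct StreamSketch {
    mode: AggregateMode,
    input_done: bool,
    // Stand-in for the skip-aggregation probe's verdict.
    probe_says_skip: bool,
}

impl StreamSketch {
    // Mirrors the new doc comments: only meaningful in Partial aggregation.
    fn should_skip_aggregation(&self) -> bool {
        self.probe_says_skip
    }

    // Two explicit arms instead of one arm with scattered mode checks.
    fn handle_batch(&mut self) {
        match self.mode {
            // Partial operator: group the batch and update the probe; no spilling here.
            AggregateMode::Partial => { /* group_aggregate_batch(); update probe */ }
            // Terminal operators (Final/FinalPartitioned/Single/SinglePartitioned)
            // may spill first; the old runtime mode check becomes an assertion.
            _ => {
                assert_ne!(self.mode, AggregateMode::Partial);
                /* spill_previous_if_necessary(); group_aggregate_batch() */
            }
        }
    }

    // The skip-aggregation probe is consulted only in Partial mode.
    fn next_state(&self) -> ExecutionState {
        if self.input_done {
            ExecutionState::Done
        } else if self.mode == AggregateMode::Partial && self.should_skip_aggregation() {
            ExecutionState::SkippingAggregation
        } else {
            ExecutionState::ReadingInput
        }
    }
}

fn main() {
    let mut s = StreamSketch {
        mode: AggregateMode::Partial,
        input_done: false,
        probe_says_skip: true,
    };
    s.handle_batch();
    assert!(matches!(s.next_state(), ExecutionState::SkippingAggregation));
}

The terminal arm can then assert it never sees Partial mode, and the probe-related methods are reached only from the Partial arm, which is what the new assertions and doc comments in the diff below lock in.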
63 changes: 55 additions & 8 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -609,14 +609,11 @@ impl Stream for GroupedHashAggregateStream {
match &self.exec_state {
ExecutionState::ReadingInput => 'reading_input: {
match ready!(self.input.poll_next_unpin(cx)) {
- // new batch to aggregate
- Some(Ok(batch)) => {
+ // New batch to aggregate in partial aggregation operator
+ Some(Ok(batch)) if self.mode == AggregateMode::Partial => {
let timer = elapsed_compute.timer();
let input_rows = batch.num_rows();

- // Make sure we have enough capacity for `batch`, otherwise spill
- extract_ok!(self.spill_previous_if_necessary(&batch));

// Do the grouping
extract_ok!(self.group_aggregate_batch(batch));

@@ -649,10 +646,49 @@

timer.done();
}

+ // New batch to aggregate in terminal aggregation operator
+ // (Final/FinalPartitioned/Single/SinglePartitioned)
+ Some(Ok(batch)) => {
+ let timer = elapsed_compute.timer();
+
+ // Make sure we have enough capacity for `batch`, otherwise spill
+ extract_ok!(self.spill_previous_if_necessary(&batch));
+
+ // Do the grouping
+ extract_ok!(self.group_aggregate_batch(batch));
+
+ // If we can begin emitting rows, do so,
+ // otherwise keep consuming input
+ assert!(!self.input_done);
+
+ // If the number of group values equals or exceeds the soft limit,
+ // emit all groups and switch to producing output
+ if self.hit_soft_group_limit() {
+ timer.done();
+ extract_ok!(self.set_input_done_and_produce_output());
+ // make sure the exec_state just set is not overwritten below
+ break 'reading_input;
+ }
+
+ if let Some(to_emit) = self.group_ordering.emit_to() {
+ let batch = extract_ok!(self.emit(to_emit, false));
+ self.exec_state = ExecutionState::ProducingOutput(batch);
+ timer.done();
+ // make sure the exec_state just set is not overwritten below
+ break 'reading_input;
+ }
+
+ timer.done();
+ }
+
+ // Found error from input stream
Some(Err(e)) => {
// inner had error, return to caller
return Poll::Ready(Some(Err(e)));
}

+ // Found end from input stream
None => {
// inner is done, emit all rows and switch to producing output
extract_ok!(self.set_input_done_and_produce_output());
@@ -691,7 +727,12 @@
(
if self.input_done {
ExecutionState::Done
- } else if self.should_skip_aggregation() {
+ }
+ // In Partial aggregation, we also need to check
+ // if we should trigger partial skipping
+ else if self.mode == AggregateMode::Partial
+ && self.should_skip_aggregation()
+ {
ExecutionState::SkippingAggregation
} else {
ExecutionState::ReadingInput
@@ -879,10 +920,10 @@ impl GroupedHashAggregateStream {
if self.group_values.len() > 0
&& batch.num_rows() > 0
&& matches!(self.group_ordering, GroupOrdering::None)
- && !matches!(self.mode, AggregateMode::Partial)
&& !self.spill_state.is_stream_merging
&& self.update_memory_reservation().is_err()
{
+ assert_ne!(self.mode, AggregateMode::Partial);
// Use input batch (Partial mode) schema for spilling because
// the spilled data will be merged and re-evaluated later.
self.spill_state.spill_schema = batch.schema();
@@ -927,9 +968,9 @@
fn emit_early_if_necessary(&mut self) -> Result<()> {
if self.group_values.len() >= self.batch_size
&& matches!(self.group_ordering, GroupOrdering::None)
- && matches!(self.mode, AggregateMode::Partial)
&& self.update_memory_reservation().is_err()
{
+ assert_eq!(self.mode, AggregateMode::Partial);
let n = self.group_values.len() / self.batch_size * self.batch_size;
let batch = self.emit(EmitTo::First(n), false)?;
self.exec_state = ExecutionState::ProducingOutput(batch);
@@ -1002,6 +1043,8 @@ impl GroupedHashAggregateStream {
}

/// Updates skip aggregation probe state.
+ ///
+ /// Notice: It should only be called in Partial aggregation
fn update_skip_aggregation_probe(&mut self, input_rows: usize) {
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
// Skip aggregation probe is not supported if stream has any spills,
@@ -1013,6 +1056,8 @@

/// In case the probe indicates that aggregation may be
/// skipped, forces stream to produce currently accumulated output.
+ ///
+ /// Notice: It should only be called in Partial aggregation
fn switch_to_skip_aggregation(&mut self) -> Result<()> {
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
if probe.should_skip() {
@@ -1026,6 +1071,8 @@

/// Returns true if the aggregation probe indicates that aggregation
/// should be skipped.
+ ///
+ /// Notice: It should only be called in Partial aggregation
fn should_skip_aggregation(&self) -> bool {
self.skip_aggregation_probe
.as_ref()
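
Addendum on the emit_early_if_necessary hunk above: the line `let n = self.group_values.len() / self.batch_size * self.batch_size;` rounds the emit count down to a whole number of batches. A tiny worked example, with illustrative numbers that are assumptions rather than DataFusion defaults:

fn main() {
    // Illustrative values only: 10_000 buffered group values, batch size 4096.
    let (len, batch_size) = (10_000usize, 4096usize);
    // Integer division happens first, so only whole batches are emitted early:
    // 10_000 / 4096 = 2, then 2 * 4096 = 8192 rows emitted, 1808 stay buffered.
    let n = len / batch_size * batch_size;
    assert_eq!(n, 8192);
    assert_eq!(len - n, 1808);
}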