diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index 13d3a61dac75..92f307432eca 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -26,7 +26,7 @@ use std::vec;
 use ahash::RandomState;
 use arrow::row::{OwnedRow, RowConverter, SortField};
 use datafusion_physical_expr::hash_utils::create_hashes;
-use futures::stream::BoxStream;
+use futures::ready;
 use futures::stream::{Stream, StreamExt};
 
 use crate::error::Result;
@@ -75,19 +75,10 @@ use hashbrown::raw::RawTable;
 /// [Compact]: datafusion_row::layout::RowType::Compact
 /// [WordAligned]: datafusion_row::layout::RowType::WordAligned
 pub(crate) struct GroupedHashAggregateStream {
-    stream: BoxStream<'static, ArrowResult<RecordBatch>>,
-    schema: SchemaRef,
-}
-
-/// Actual implementation of [`GroupedHashAggregateStream`].
-///
-/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
-/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
-/// [`futures::stream::unfold`]. The latter requires a state object, which is [`GroupedHashAggregateStreamInner`].
-struct GroupedHashAggregateStreamInner {
     schema: SchemaRef,
     input: SendableRecordBatchStream,
     mode: AggregateMode,
+    exec_state: ExecutionState,
     normal_aggr_expr: Vec<Arc<dyn AggregateExpr>>,
     row_aggr_state: RowAggregationState,
     /// Aggregate expressions not supporting row accumulation
@@ -115,6 +106,14 @@ struct GroupedHashAggregateStreamInner {
     indices: [Vec<Range<usize>>; 2],
 }
 
+#[derive(Debug)]
+/// Tracks what phase the aggregation is in
+enum ExecutionState {
+    ReadingInput,
+    ProducingOutput,
+    Done,
+}
+
 fn aggr_state_schema(aggr_expr: &[Arc<dyn AggregateExpr>]) -> Result<SchemaRef> {
     let fields = aggr_expr
         .iter()
@@ -201,9 +200,12 @@ impl GroupedHashAggregateStream {
 
         timer.done();
 
-        let inner = GroupedHashAggregateStreamInner {
+        let exec_state = ExecutionState::ReadingInput;
+
+        Ok(GroupedHashAggregateStream {
             schema: Arc::clone(&schema),
             mode,
+            exec_state,
             input,
             group_by,
             normal_aggr_expr,
@@ -219,91 +221,75 @@ impl GroupedHashAggregateStream {
             batch_size,
             row_group_skip_position: 0,
             indices: [normal_agg_indices, row_agg_indices],
-        };
+        })
+    }
+}
 
-        let stream = futures::stream::unfold(inner, |mut this| async move {
-            let elapsed_compute = this.baseline_metrics.elapsed_compute();
+impl Stream for GroupedHashAggregateStream {
+    type Item = ArrowResult<RecordBatch>;
 
-            loop {
-                let result: ArrowResult<Option<RecordBatch>> =
-                    match this.input.next().await {
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+
+        loop {
+            match self.exec_state {
+                ExecutionState::ReadingInput => {
+                    match ready!(self.input.poll_next_unpin(cx)) {
+                        // new batch to aggregate
                         Some(Ok(batch)) => {
                             let timer = elapsed_compute.timer();
-                            let result = group_aggregate_batch(
-                                &this.mode,
-                                &this.random_state,
-                                &this.group_by,
-                                &this.normal_aggr_expr,
-                                &mut this.row_accumulators,
-                                &mut this.row_converter,
-                                this.row_aggr_layout.clone(),
-                                batch,
-                                &mut this.row_aggr_state,
-                                &this.normal_aggregate_expressions,
-                                &this.row_aggregate_expressions,
-                            );
-
+                            let result = self.group_aggregate_batch(batch);
                             timer.done();
 
                             // allocate memory
                             // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
                            // overshooting a bit. Also this means we either store the whole record batch or not.
-                            match result.and_then(|allocated| {
-                                this.row_aggr_state.reservation.try_grow(allocated)
-                            }) {
-                                Ok(_) => continue,
-                                Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+                            let result = result.and_then(|allocated| {
+                                self.row_aggr_state.reservation.try_grow(allocated)
+                            });
+
+                            if let Err(e) = result {
+                                return Poll::Ready(Some(Err(
+                                    ArrowError::ExternalError(Box::new(e)),
+                                )));
                             }
                         }
-                        Some(Err(e)) => Err(e),
+                        // inner had error, return to caller
+                        Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+                        // inner is done, producing output
                         None => {
-                            let timer = this.baseline_metrics.elapsed_compute().timer();
-                            let result = create_batch_from_map(
-                                &this.mode,
-                                &this.row_converter,
-                                &this.row_aggr_schema,
-                                this.batch_size,
-                                this.row_group_skip_position,
-                                &mut this.row_aggr_state,
-                                &mut this.row_accumulators,
-                                &this.schema,
-                                &this.indices,
-                            );
-
-                            timer.done();
-                            result
+                            self.exec_state = ExecutionState::ProducingOutput;
                         }
-                    };
-
-                this.row_group_skip_position += this.batch_size;
-                return match result {
-                    Ok(Some(result)) => {
-                        let batch = result.record_output(&this.baseline_metrics);
-                        Some((Ok(batch), this))
                     }
-                    Ok(None) => None,
-                    Err(error) => Some((Err(error), this)),
-                };
-            }
-        });
+                }
 
-        // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream.
-        let stream = stream.fuse();
-        let stream = Box::pin(stream);
+                ExecutionState::ProducingOutput => {
+                    let timer = elapsed_compute.timer();
+                    let result = self.create_batch_from_map();
 
-        Ok(Self { schema, stream })
-    }
-}
+                    timer.done();
+                    self.row_group_skip_position += self.batch_size;
 
-impl Stream for GroupedHashAggregateStream {
-    type Item = ArrowResult<RecordBatch>;
-
-    fn poll_next(
-        mut self: std::pin::Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
-        let this = &mut *self;
-        this.stream.poll_next_unpin(cx)
+                    match result {
+                        // made output
+                        Ok(Some(result)) => {
+                            let batch = result.record_output(&self.baseline_metrics);
+                            return Poll::Ready(Some(Ok(batch)));
+                        }
+                        // end of output
+                        Ok(None) => {
+                            self.exec_state = ExecutionState::Done;
+                        }
+                        // error making output
+                        Err(error) => return Poll::Ready(Some(Err(error))),
+                    }
+                }
+                ExecutionState::Done => return Poll::Ready(None),
+            }
+        }
    }
 }
@@ -313,222 +299,227 @@ impl RecordBatchStream for GroupedHashAggregateStream {
     }
 }
 
-/// Perform group-by aggregation for the given [`RecordBatch`].
-///
-/// If successfull, this returns the additional number of bytes that were allocated during this process.
-///
-/// TODO: Make this a member function of [`GroupedHashAggregateStream`]
-#[allow(clippy::too_many_arguments)]
-fn group_aggregate_batch(
-    mode: &AggregateMode,
-    random_state: &RandomState,
-    grouping_set: &PhysicalGroupBy,
-    normal_aggr_expr: &[Arc<dyn AggregateExpr>],
-    row_accumulators: &mut [RowAccumulatorItem],
-    row_converter: &mut RowConverter,
-    state_layout: Arc<RowLayout>,
-    batch: RecordBatch,
-    aggr_state: &mut RowAggregationState,
-    normal_aggregate_expressions: &[Vec<Arc<dyn PhysicalExpr>>],
-    row_aggregate_expressions: &[Vec<Arc<dyn PhysicalExpr>>],
-) -> Result<usize> {
-    // Evaluate the grouping expressions:
-    let group_by_values = evaluate_group_by(grouping_set, &batch)?;
-    // Keep track of memory allocated:
-    let mut allocated = 0usize;
-    let RowAggregationState {
-        map: row_map,
-        group_states: row_group_states,
-        ..
-    } = aggr_state;
-
-    // Evaluate the aggregation expressions.
-    // We could evaluate them after the `take`, but since we need to evaluate all
-    // of them anyways, it is more performant to do it while they are together.
-    let row_aggr_input_values = evaluate_many(row_aggregate_expressions, &batch)?;
-    let normal_aggr_input_values = evaluate_many(normal_aggregate_expressions, &batch)?;
-
-    let row_converter_size_pre = row_converter.size();
-    for group_values in &group_by_values {
-        let group_rows = row_converter.convert_columns(group_values)?;
-
-        // 1.1 construct the key from the group values
-        // 1.2 construct the mapping key if it does not exist
-        // 1.3 add the row' index to `indices`
-
-        // track which entries in `aggr_state` have rows in this batch to aggregate
-        let mut groups_with_rows = vec![];
-
-        // 1.1 Calculate the group keys for the group values
-        let mut batch_hashes = vec![0; batch.num_rows()];
-        create_hashes(group_values, random_state, &mut batch_hashes)?;
-
-        for (row, hash) in batch_hashes.into_iter().enumerate() {
-            let entry = row_map.get_mut(hash, |(_hash, group_idx)| {
-                // verify that a group that we are inserting with hash is
-                // actually the same key value as the group in
-                // existing_idx (aka group_values @ row)
-                let group_state = &row_group_states[*group_idx];
-                group_rows.row(row) == group_state.group_by_values.row()
-            });
-
-            match entry {
-                // Existing entry for this group value
-                Some((_hash, group_idx)) => {
-                    let group_state = &mut row_group_states[*group_idx];
+impl GroupedHashAggregateStream {
+    /// Perform group-by aggregation for the given [`RecordBatch`].
+    ///
+    /// If successful, this returns the additional number of bytes that were allocated during this process.
+    ///
+    fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<usize> {
+        // Evaluate the grouping expressions:
+        let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
+        // Keep track of memory allocated:
+        let mut allocated = 0usize;
+        let RowAggregationState {
+            map: row_map,
+            group_states: row_group_states,
+            ..
+        } = &mut self.row_aggr_state;
+
+        // Evaluate the aggregation expressions.
+        // We could evaluate them after the `take`, but since we need to evaluate all
+        // of them anyway, it is more performant to do it while they are together.
+        let row_aggr_input_values =
+            evaluate_many(&self.row_aggregate_expressions, &batch)?;
+        let normal_aggr_input_values =
+            evaluate_many(&self.normal_aggregate_expressions, &batch)?;
+
+        let row_converter_size_pre = self.row_converter.size();
+        for group_values in &group_by_values {
+            let group_rows = self.row_converter.convert_columns(group_values)?;
+
+            // 1.1 construct the key from the group values
+            // 1.2 construct the mapping key if it does not exist
+            // 1.3 add the row's index to `indices`
+
+            // track which entries in `aggr_state` have rows in this batch to aggregate
+            let mut groups_with_rows = vec![];
+
+            // 1.1 Calculate the group keys for the group values
+            let mut batch_hashes = vec![0; batch.num_rows()];
+            create_hashes(group_values, &self.random_state, &mut batch_hashes)?;
+
+            for (row, hash) in batch_hashes.into_iter().enumerate() {
+                let entry = row_map.get_mut(hash, |(_hash, group_idx)| {
+                    // verify that a group that we are inserting with hash is
+                    // actually the same key value as the group in
+                    // existing_idx (aka group_values @ row)
+                    let group_state = &row_group_states[*group_idx];
+                    group_rows.row(row) == group_state.group_by_values.row()
+                });
+
+                match entry {
+                    // Existing entry for this group value
+                    Some((_hash, group_idx)) => {
+                        let group_state = &mut row_group_states[*group_idx];
+
+                        // 1.3
+                        if group_state.indices.is_empty() {
+                            groups_with_rows.push(*group_idx);
+                        };
-                    // 1.3
-                    if group_state.indices.is_empty() {
-                        groups_with_rows.push(*group_idx);
-                    };
+                        group_state
+                            .indices
+                            .push_accounted(row as u32, &mut allocated); // remember this row
+                    }
+                    // 1.2 Need to create new entry
+                    None => {
+                        let accumulator_set =
+                            aggregates::create_accumulators(&self.normal_aggr_expr)?;
+                        // Add new entry to group_states and save newly created index
+                        let group_state = RowGroupState {
+                            group_by_values: group_rows.row(row).owned(),
+                            aggregation_buffer: vec![
+                                0;
+                                self.row_aggr_layout
+                                    .fixed_part_width()
+                            ],
+                            accumulator_set,
+                            indices: vec![row as u32], // 1.3
+                        };
+                        let group_idx = row_group_states.len();
+
+                        // NOTE: do NOT include the `RowGroupState` struct size in here because this is captured by
+                        // `group_states` (see allocation down below)
+                        allocated += (std::mem::size_of::<u8>()
+                            * group_state.group_by_values.as_ref().len())
+                            + (std::mem::size_of::<u8>()
+                                * group_state.aggregation_buffer.capacity())
+                            + (std::mem::size_of::<u32>()
+                                * group_state.indices.capacity());
+
+                        // Allocation done by normal accumulators
+                        allocated += (std::mem::size_of::<Box<dyn Accumulator>>()
+                            * group_state.accumulator_set.capacity())
+                            + group_state
+                                .accumulator_set
+                                .iter()
+                                .map(|accu| accu.size())
+                                .sum::<usize>();
-                    group_state
-                        .indices
-                        .push_accounted(row as u32, &mut allocated); // remember this row
-                }
-                // 1.2 Need to create new entry
-                None => {
-                    let accumulator_set =
-                        aggregates::create_accumulators(normal_aggr_expr)?;
-                    // Add new entry to group_states and save newly created index
-                    let group_state = RowGroupState {
-                        group_by_values: group_rows.row(row).owned(),
-                        aggregation_buffer: vec![0; state_layout.fixed_part_width()],
-                        accumulator_set,
-                        indices: vec![row as u32], // 1.3
-                    };
-                    let group_idx = row_group_states.len();
-
-                    // NOTE: do NOT include the `RowGroupState` struct size in here because this is captured by
-                    // `group_states` (see allocation down below)
-                    allocated += (std::mem::size_of::<u8>()
-                        * group_state.group_by_values.as_ref().len())
-                        + (std::mem::size_of::<u8>()
-                            * group_state.aggregation_buffer.capacity())
-                        + (std::mem::size_of::<u32>() * group_state.indices.capacity());
-
-                    // Allocation done by normal accumulators
-                    allocated += (std::mem::size_of::<Box<dyn Accumulator>>()
-                        * group_state.accumulator_set.capacity())
-                        + group_state
-                            .accumulator_set
-                            .iter()
-                            .map(|accu| accu.size())
-                            .sum::<usize>();
-
-                    // for hasher function, use precomputed hash value
-                    row_map.insert_accounted(
-                        (hash, group_idx),
-                        |(hash, _group_index)| *hash,
-                        &mut allocated,
-                    );
-
-                    row_group_states.push_accounted(group_state, &mut allocated);
-
-                    groups_with_rows.push(group_idx);
-                }
-            };
-        }
+                        // for hasher function, use precomputed hash value
+                        row_map.insert_accounted(
+                            (hash, group_idx),
+                            |(hash, _group_index)| *hash,
+                            &mut allocated,
+                        );
 
-        // Collect all indices + offsets based on keys in this vec
-        let mut batch_indices: UInt32Builder = UInt32Builder::with_capacity(0);
-        let mut offsets = vec![0];
-        let mut offset_so_far = 0;
-        for &group_idx in groups_with_rows.iter() {
-            let indices = &row_group_states[group_idx].indices;
-            batch_indices.append_slice(indices);
-            offset_so_far += indices.len();
-            offsets.push(offset_so_far);
-        }
-        let batch_indices = batch_indices.finish();
+                        row_group_states.push_accounted(group_state, &mut allocated);
 
-        let row_values = get_at_indices(&row_aggr_input_values, &batch_indices);
-        let normal_values = get_at_indices(&normal_aggr_input_values, &batch_indices);
+                        groups_with_rows.push(group_idx);
+                    }
+                };
+            }
 
-        // 2.1 for each key in this batch
-        // 2.2 for each aggregation
-        // 2.3 `slice` from each of its arrays the keys' values
-        // 2.4 update / merge the accumulator with the values
-        // 2.5 clear indices
-        groups_with_rows
-            .iter()
-            .zip(offsets.windows(2))
-            .try_for_each(|(group_idx, offsets)| {
-                let group_state = &mut row_group_states[*group_idx];
-                // 2.2
-                row_accumulators
-                    .iter_mut()
-                    .zip(row_values.iter())
-                    .map(|(accumulator, aggr_array)| {
-                        (
-                            accumulator,
-                            aggr_array
-                                .iter()
-                                .map(|array| {
-                                    // 2.3
-                                    array.slice(offsets[0], offsets[1] - offsets[0])
-                                })
-                                .collect::<Vec<ArrayRef>>(),
-                        )
-                    })
-                    .try_for_each(|(accumulator, values)| {
-                        let mut state_accessor =
-                            RowAccessor::new_from_layout(state_layout.clone());
-                        state_accessor
-                            .point_to(0, group_state.aggregation_buffer.as_mut_slice());
-                        match mode {
-                            AggregateMode::Partial => {
-                                accumulator.update_batch(&values, &mut state_accessor)
-                            }
-                            AggregateMode::FinalPartitioned | AggregateMode::Final => {
-                                // note: the aggregation here is over states, not values, thus the merge
-                                accumulator.merge_batch(&values, &mut state_accessor)
-                            }
-                        }
-                    })
-                    // 2.5
-                    .and(Ok(()))?;
-                // normal accumulators
-                group_state
-                    .accumulator_set
-                    .iter_mut()
-                    .zip(normal_values.iter())
-                    .map(|(accumulator, aggr_array)| {
-                        (
-                            accumulator,
-                            aggr_array
-                                .iter()
-                                .map(|array| {
-                                    // 2.3
-                                    array.slice(offsets[0], offsets[1] - offsets[0])
-                                })
-                                .collect::<Vec<ArrayRef>>(),
-                        )
-                    })
-                    .try_for_each(|(accumulator, values)| {
-                        let size_pre = accumulator.size();
-                        let res = match mode {
-                            AggregateMode::Partial => accumulator.update_batch(&values),
-                            AggregateMode::FinalPartitioned | AggregateMode::Final => {
-                                // note: the aggregation here is over states, not values, thus the merge
-                                accumulator.merge_batch(&values)
-                            }
-                        };
-                        let size_post = accumulator.size();
-                        allocated += size_post.saturating_sub(size_pre);
-                        res
-                    })
-                    // 2.5
-                    .and({
-                        group_state.indices.clear();
-                        Ok(())
-                    })?;
-
-                Ok::<(), DataFusionError>(())
-            })?;
+            // Collect all indices + offsets based on keys in this vec
+            let mut batch_indices: UInt32Builder = UInt32Builder::with_capacity(0);
+            let mut offsets = vec![0];
+            let mut offset_so_far = 0;
+            for &group_idx in groups_with_rows.iter() {
+                let indices = &row_group_states[group_idx].indices;
+                batch_indices.append_slice(indices);
+                offset_so_far += indices.len();
+                offsets.push(offset_so_far);
+            }
+            let batch_indices = batch_indices.finish();
+
+            let row_values = get_at_indices(&row_aggr_input_values, &batch_indices);
+            let normal_values = get_at_indices(&normal_aggr_input_values, &batch_indices);
+
+            // 2.1 for each key in this batch
+            // 2.2 for each aggregation
+            // 2.3 `slice` from each of its arrays the keys' values
+            // 2.4 update / merge the accumulator with the values
+            // 2.5 clear indices
+            groups_with_rows
+                .iter()
+                .zip(offsets.windows(2))
+                .try_for_each(|(group_idx, offsets)| {
+                    let group_state = &mut row_group_states[*group_idx];
+                    // 2.2
+                    self.row_accumulators
+                        .iter_mut()
+                        .zip(row_values.iter())
+                        .map(|(accumulator, aggr_array)| {
+                            (
+                                accumulator,
+                                aggr_array
+                                    .iter()
+                                    .map(|array| {
+                                        // 2.3
+                                        array.slice(offsets[0], offsets[1] - offsets[0])
+                                    })
+                                    .collect::<Vec<ArrayRef>>(),
+                            )
+                        })
+                        .try_for_each(|(accumulator, values)| {
+                            let mut state_accessor = RowAccessor::new_from_layout(
+                                self.row_aggr_layout.clone(),
+                            );
+                            state_accessor.point_to(
+                                0,
+                                group_state.aggregation_buffer.as_mut_slice(),
+                            );
+                            match self.mode {
+                                AggregateMode::Partial => {
+                                    accumulator.update_batch(&values, &mut state_accessor)
+                                }
+                                AggregateMode::FinalPartitioned
+                                | AggregateMode::Final => {
+                                    // note: the aggregation here is over states, not values, thus the merge
+                                    accumulator.merge_batch(&values, &mut state_accessor)
+                                }
+                            }
+                        })
+                        // 2.5
+                        .and(Ok(()))?;
+                    // normal accumulators
+                    group_state
+                        .accumulator_set
+                        .iter_mut()
+                        .zip(normal_values.iter())
+                        .map(|(accumulator, aggr_array)| {
+                            (
+                                accumulator,
+                                aggr_array
+                                    .iter()
+                                    .map(|array| {
+                                        // 2.3
+                                        array.slice(offsets[0], offsets[1] - offsets[0])
+                                    })
+                                    .collect::<Vec<ArrayRef>>(),
+                            )
+                        })
+                        .try_for_each(|(accumulator, values)| {
+                            let size_pre = accumulator.size();
+                            let res = match self.mode {
+                                AggregateMode::Partial => {
+                                    accumulator.update_batch(&values)
+                                }
+                                AggregateMode::FinalPartitioned
+                                | AggregateMode::Final => {
+                                    // note: the aggregation here is over states, not values, thus the merge
+                                    accumulator.merge_batch(&values)
+                                }
+                            };
+                            let size_post = accumulator.size();
+                            allocated += size_post.saturating_sub(size_pre);
+                            res
+                        })
+                        // 2.5
+                        .and({
+                            group_state.indices.clear();
+                            Ok(())
+                        })?;
+
+                    Ok::<(), DataFusionError>(())
+                })?;
+        }
+        allocated += self
+            .row_converter
+            .size()
+            .saturating_sub(row_converter_size_pre);
+        Ok(allocated)
     }
-    allocated += row_converter.size().saturating_sub(row_converter_size_pre);
-    Ok(allocated)
 }
 
 /// The state that is built for each output group.
@@ -576,138 +567,131 @@ impl std::fmt::Debug for RowAggregationState {
     }
 }
 
-/// Create a RecordBatch with all group keys and accumulator' states or values.
-#[allow(clippy::too_many_arguments)]
-fn create_batch_from_map(
-    mode: &AggregateMode,
-    converter: &RowConverter,
-    aggr_schema: &Schema,
-    batch_size: usize,
-    skip_items: usize,
-    row_aggr_state: &mut RowAggregationState,
-    row_accumulators: &mut [RowAccumulatorItem],
-    output_schema: &Schema,
-    // Stores the location of each accumulator in the output_schema
-    indices: &[Vec<Range<usize>>],
-) -> ArrowResult<Option<RecordBatch>> {
-    if skip_items > row_aggr_state.group_states.len() {
-        return Ok(None);
-    }
-    if row_aggr_state.group_states.is_empty() {
-        let schema = Arc::new(output_schema.to_owned());
-        return Ok(Some(RecordBatch::new_empty(schema)));
-    }
-
-    let end_idx = min(skip_items + batch_size, row_aggr_state.group_states.len());
-    let group_state_chunk = &row_aggr_state.group_states[skip_items..end_idx];
-
-    if group_state_chunk.is_empty() {
-        let schema = Arc::new(output_schema.to_owned());
-        return Ok(Some(RecordBatch::new_empty(schema)));
-    }
+impl GroupedHashAggregateStream {
+    /// Create a RecordBatch with all group keys and accumulators' states or values.
+    fn create_batch_from_map(&mut self) -> ArrowResult<Option<RecordBatch>> {
+        let skip_items = self.row_group_skip_position;
+        if skip_items > self.row_aggr_state.group_states.len() {
+            return Ok(None);
+        }
+        if self.row_aggr_state.group_states.is_empty() {
+            let schema = self.schema.clone();
+            return Ok(Some(RecordBatch::new_empty(schema)));
+        }
 
-    // Buffers for each distinct group (i.e. row accumulator memories)
-    let mut state_buffers = group_state_chunk
-        .iter()
-        .map(|gs| gs.aggregation_buffer.clone())
-        .collect::<Vec<_>>();
+        let end_idx = min(
+            skip_items + self.batch_size,
+            self.row_aggr_state.group_states.len(),
+        );
+        let group_state_chunk = &self.row_aggr_state.group_states[skip_items..end_idx];
 
-    let output_fields = output_schema.fields();
-    // Store row accumulator results (either final output or intermediate state):
-    let row_columns = match mode {
-        AggregateMode::Partial => {
-            read_as_batch(&state_buffers, aggr_schema, RowType::WordAligned)
+        if group_state_chunk.is_empty() {
+            let schema = self.schema.clone();
+            return Ok(Some(RecordBatch::new_empty(schema)));
        }
-        AggregateMode::Final | AggregateMode::FinalPartitioned => {
-            let mut results = vec![];
-            for (idx, acc) in row_accumulators.iter().enumerate() {
-                let mut state_accessor =
-                    RowAccessor::new(aggr_schema, RowType::WordAligned);
-                let current = state_buffers
-                    .iter_mut()
-                    .map(|buffer| {
-                        state_accessor.point_to(0, buffer);
-                        acc.evaluate(&state_accessor)
-                    })
-                    .collect::<Result<Vec<_>>>()?;
-                // Get corresponding field for row accumulator
-                let field = &output_fields[indices[1][idx].start];
-                let result = if current.is_empty() {
-                    Ok(arrow::array::new_empty_array(field.data_type()))
-                } else {
-                    let item = ScalarValue::iter_to_array(current)?;
-                    // cast output if needed (e.g. for types like Dictionary where
-                    // the intermediate GroupByScalar type was not the same as the
-                    // output
-                    cast(&item, field.data_type()).map_err(DataFusionError::ArrowError)
-                }?;
-                results.push(result);
+
+        // Buffers for each distinct group (i.e. row accumulator memories)
+        let mut state_buffers = group_state_chunk
+            .iter()
+            .map(|gs| gs.aggregation_buffer.clone())
+            .collect::<Vec<_>>();
+
+        let output_fields = self.schema.fields();
+        // Store row accumulator results (either final output or intermediate state):
+        let row_columns = match self.mode {
+            AggregateMode::Partial => {
+                read_as_batch(&state_buffers, &self.row_aggr_schema, RowType::WordAligned)
             }
+            AggregateMode::Final | AggregateMode::FinalPartitioned => {
+                let mut results = vec![];
+                for (idx, acc) in self.row_accumulators.iter().enumerate() {
+                    let mut state_accessor =
+                        RowAccessor::new(&self.row_aggr_schema, RowType::WordAligned);
+                    let current = state_buffers
+                        .iter_mut()
+                        .map(|buffer| {
+                            state_accessor.point_to(0, buffer);
+                            acc.evaluate(&state_accessor)
+                        })
+                        .collect::<Result<Vec<_>>>()?;
+                    // Get corresponding field for row accumulator
+                    let field = &output_fields[self.indices[1][idx].start];
+                    let result = if current.is_empty() {
+                        Ok(arrow::array::new_empty_array(field.data_type()))
+                    } else {
+                        let item = ScalarValue::iter_to_array(current)?;
+                        // cast output if needed (e.g. for types like Dictionary where
+                        // the intermediate GroupByScalar type was not the same as the
+                        // output
+                        cast(&item, field.data_type())
+                            .map_err(DataFusionError::ArrowError)
+                    }?;
+                    results.push(result);
+                }
+                results
+            }
+        };
-            results
-        }
-    };
-
-    // Store normal accumulator results (either final output or intermediate state):
-    let mut columns = vec![];
-    for (idx, &Range { start, end }) in indices[0].iter().enumerate() {
-        for (field_idx, field) in output_fields[start..end].iter().enumerate() {
-            let current = match mode {
-                AggregateMode::Partial => ScalarValue::iter_to_array(
-                    group_state_chunk.iter().map(|row_group_state| {
-                        row_group_state.accumulator_set[idx]
-                            .state()
-                            .map(|v| v[field_idx].clone())
-                            .expect("Unexpected accumulator state in hash aggregate")
-                    }),
-                ),
-                AggregateMode::Final | AggregateMode::FinalPartitioned => {
-                    ScalarValue::iter_to_array(group_state_chunk.iter().map(
-                        |row_group_state| {
-                            row_group_state.accumulator_set[idx]
-                                .evaluate()
-                                .expect("Unexpected accumulator state in hash aggregate")
-                        },
-                    ))
-                }
-            }?;
-            // Cast output if needed (e.g. for types like Dictionary where
-            // the intermediate GroupByScalar type was not the same as the
-            // output
-            let result = cast(&current, field.data_type())?;
-            columns.push(result);
+
+        // Store normal accumulator results (either final output or intermediate state):
+        let mut columns = vec![];
+        for (idx, &Range { start, end }) in self.indices[0].iter().enumerate() {
+            for (field_idx, field) in output_fields[start..end].iter().enumerate() {
+                let current = match self.mode {
+                    AggregateMode::Partial => ScalarValue::iter_to_array(
+                        group_state_chunk.iter().map(|row_group_state| {
+                            row_group_state.accumulator_set[idx]
+                                .state()
+                                .map(|v| v[field_idx].clone())
+                                .expect("Unexpected accumulator state in hash aggregate")
+                        }),
+                    ),
+                    AggregateMode::Final | AggregateMode::FinalPartitioned => {
+                        ScalarValue::iter_to_array(group_state_chunk.iter().map(
+                            |row_group_state| {
+                                row_group_state.accumulator_set[idx].evaluate().expect(
+                                    "Unexpected accumulator state in hash aggregate",
+                                )
+                            },
+                        ))
+                    }
+                }?;
+                // Cast output if needed (e.g. for types like Dictionary where
+                // the intermediate GroupByScalar type was not the same as the
+                // output
+                let result = cast(&current, field.data_type())?;
+                columns.push(result);
+            }
         }
-    // Stores the group by fields
-    let group_buffers = group_state_chunk
-        .iter()
-        .map(|gs| gs.group_by_values.row())
-        .collect::<Vec<_>>();
-    let mut output: Vec<ArrayRef> = converter.convert_rows(group_buffers)?;
+        // Stores the group by fields
+        let group_buffers = group_state_chunk
+            .iter()
+            .map(|gs| gs.group_by_values.row())
+            .collect::<Vec<_>>();
+        let mut output: Vec<ArrayRef> = self.row_converter.convert_rows(group_buffers)?;
 
-    // The size of the place occupied by row and normal accumulators
-    let extra: usize = indices
-        .iter()
-        .flatten()
-        .map(|Range { start, end }| end - start)
-        .sum();
-    let empty_arr = new_null_array(&DataType::Null, 1);
-    output.extend(std::iter::repeat(empty_arr).take(extra));
-
-    // Write results of both accumulator types to the corresponding location in
-    // the output schema:
-    let results = [columns.into_iter(), row_columns.into_iter()];
-    for (outer, mut current) in results.into_iter().enumerate() {
-        for &Range { start, end } in indices[outer].iter() {
-            for item in output.iter_mut().take(end).skip(start) {
-                *item = current.next().expect("Columns cannot be empty");
+        // The size of the place occupied by row and normal accumulators
+        let extra: usize = self
+            .indices
+            .iter()
+            .flatten()
+            .map(|Range { start, end }| end - start)
+            .sum();
+        let empty_arr = new_null_array(&DataType::Null, 1);
+        output.extend(std::iter::repeat(empty_arr).take(extra));
+
+        // Write results of both accumulator types to the corresponding location in
+        // the output schema:
+        let results = [columns.into_iter(), row_columns.into_iter()];
+        for (outer, mut current) in results.into_iter().enumerate() {
+            for &Range { start, end } in self.indices[outer].iter() {
+                for item in output.iter_mut().take(end).skip(start) {
+                    *item = current.next().expect("Columns cannot be empty");
+                }
             }
         }
+        Ok(Some(RecordBatch::try_new(self.schema.clone(), output)?))
     }
-    Ok(Some(RecordBatch::try_new(
-        Arc::new(output_schema.to_owned()),
-        output,
-    )?))
 }
 
 fn read_as_batch(rows: &[Vec<u8>], schema: &Schema, row_type: RowType) -> Vec<ArrayRef> {
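
Reviewer note on the pattern: the core of this change swaps the `futures::stream::unfold` + `BoxStream` wrapper for a hand-rolled state machine driven directly from `poll_next`. A minimal, self-contained sketch of the same pattern (hypothetical `State`/`SumStream` names, not the DataFusion types) looks like this:

```rust
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::ready;
use futures::stream::{Stream, StreamExt};

/// Mirrors `ExecutionState` in the diff above.
enum State {
    ReadingInput,
    ProducingOutput,
    Done,
}

/// Drains `input`, then emits a single aggregated value.
struct SumStream<S> {
    input: S,
    state: State,
    acc: u64,
}

impl<S: Stream<Item = u64> + Unpin> Stream for SumStream<S> {
    type Item = u64;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
        loop {
            match self.state {
                // `ready!` returns `Poll::Pending` to the caller whenever the
                // input is not ready, so the loop never busy-waits.
                State::ReadingInput => match ready!(self.input.poll_next_unpin(cx)) {
                    Some(v) => self.acc += v, // aggregate and stay in this state
                    None => self.state = State::ProducingOutput, // input exhausted
                },
                State::ProducingOutput => {
                    self.state = State::Done; // emit once, then finish
                    return Poll::Ready(Some(self.acc));
                }
                // A terminal state that keeps returning `None` makes the stream
                // fused by construction, which is why the old explicit `.fuse()`
                // call could be dropped.
                State::Done => return Poll::Ready(None),
            }
        }
    }
}

// Usage, e.g. with futures::executor::block_on and futures::stream::iter:
//   let s = SumStream { input: futures::stream::iter([1u64, 2, 3]),
//                       state: State::ReadingInput, acc: 0 };
//   assert_eq!(futures::executor::block_on(s.collect::<Vec<_>>()), vec![6]);
```

Because `poll_next` now has `&mut self`, the aggregation state no longer needs to live in a separate `GroupedHashAggregateStreamInner` threaded through `unfold`, which is what lets `group_aggregate_batch` and `create_batch_from_map` become plain `&mut self` methods.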
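The `ProducingOutput` arm relies on a paging contract: `create_batch_from_map` returns at most `batch_size` groups per call, `poll_next` bumps `row_group_skip_position` by `batch_size` after each call, and `Ok(None)` (skip position past the end) triggers the transition to `Done`. A simplified sketch of just that termination logic, with a plain slice standing in for the group states and a hypothetical `next_chunk` helper:

```rust
/// Returns the next chunk of at most `batch_size` groups starting at `skip`,
/// or `None` once `skip` has been advanced past the end.
/// Note: `skip == groups.len()` still yields `Some(&[])`, mirroring the
/// empty-batch case in `create_batch_from_map`; only `skip > len` ends it.
fn next_chunk(groups: &[u64], skip: usize, batch_size: usize) -> Option<&[u64]> {
    if skip > groups.len() {
        return None;
    }
    let end = (skip + batch_size).min(groups.len());
    Some(&groups[skip..end])
}

fn main() {
    let groups = [1u64, 2, 3, 4, 5];
    let (mut skip, batch_size) = (0, 2);
    // Drives the helper the same way `poll_next` drives `create_batch_from_map`.
    while let Some(chunk) = next_chunk(&groups, skip, batch_size) {
        println!("{chunk:?}"); // prints [1, 2], then [3, 4], then [5]
        skip += batch_size;
    }
}
```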