diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 5008f49250b0..ffb51b2e8a1f 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -20,6 +20,7 @@ use std::any::Any; use std::sync::Arc; use std::task::{Context, Poll}; +use std::vec; use ahash::RandomState; use futures::{ @@ -32,6 +33,7 @@ use crate::physical_plan::{ Accumulator, AggregateExpr, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, SQLMetric, }; +use crate::scalar::ScalarValue; use arrow::{ array::{Array, UInt32Builder}, @@ -623,10 +625,12 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec) -> Result<( DataType::UInt64 => { dictionary_create_key_for_col::(col, row, vec)?; } - _ => return Err(DataFusionError::Internal(format!( + _ => { + return Err(DataFusionError::Internal(format!( "Unsupported GROUP BY type (dictionary index type not supported creating key) {}", col.data_type(), - ))), + ))) + } }, _ => { // This is internal because we should have caught this before. @@ -957,20 +961,6 @@ impl RecordBatchStream for HashAggregateStream { } } -/// Given Vec>, concatenates the inners `Vec` into `ArrayRef`, returning `Vec` -/// This assumes that `arrays` is not empty. -fn concatenate(arrays: Vec>) -> ArrowResult> { - (0..arrays[0].len()) - .map(|column| { - let array_list = arrays - .iter() - .map(|a| a[column].as_ref()) - .collect::>(); - compute::concat(&array_list) - }) - .collect::>>() -} - /// Create a RecordBatch with all group keys and accumulator' states or values. fn create_batch_from_map( mode: &AggregateMode, @@ -978,84 +968,72 @@ fn create_batch_from_map( num_group_expr: usize, output_schema: &Schema, ) -> ArrowResult { - // 1. for each key - // 2. create single-row ArrayRef with all group expressions - // 3. create single-row ArrayRef with all aggregate states or values - // 4. collect all in a vector per key of vec, vec[i][j] - // 5. concatenate the arrays over the second index [j] into a single vec. - let arrays = accumulators - .iter() - .map(|(_, (group_by_values, accumulator_set, _))| { - // 2. - let mut groups = (0..num_group_expr) - .map(|i| match &group_by_values[i] { - GroupByScalar::Float32(n) => { - Arc::new(Float32Array::from(vec![(*n).into()] as Vec)) - as ArrayRef - } - GroupByScalar::Float64(n) => { - Arc::new(Float64Array::from(vec![(*n).into()] as Vec)) - as ArrayRef - } - GroupByScalar::Int8(n) => { - Arc::new(Int8Array::from(vec![*n])) as ArrayRef - } - GroupByScalar::Int16(n) => Arc::new(Int16Array::from(vec![*n])), - GroupByScalar::Int32(n) => Arc::new(Int32Array::from(vec![*n])), - GroupByScalar::Int64(n) => Arc::new(Int64Array::from(vec![*n])), - GroupByScalar::UInt8(n) => Arc::new(UInt8Array::from(vec![*n])), - GroupByScalar::UInt16(n) => Arc::new(UInt16Array::from(vec![*n])), - GroupByScalar::UInt32(n) => Arc::new(UInt32Array::from(vec![*n])), - GroupByScalar::UInt64(n) => Arc::new(UInt64Array::from(vec![*n])), - GroupByScalar::Utf8(str) => { - Arc::new(StringArray::from(vec![&***str])) - } - GroupByScalar::LargeUtf8(str) => { - Arc::new(LargeStringArray::from(vec![&***str])) - } - GroupByScalar::Boolean(b) => Arc::new(BooleanArray::from(vec![*b])), - GroupByScalar::TimeMillisecond(n) => { - Arc::new(TimestampMillisecondArray::from(vec![*n])) - } - GroupByScalar::TimeMicrosecond(n) => { - Arc::new(TimestampMicrosecondArray::from(vec![*n])) - } - GroupByScalar::TimeNanosecond(n) => { - Arc::new(TimestampNanosecondArray::from_vec(vec![*n], None)) - } - GroupByScalar::Date32(n) => Arc::new(Date32Array::from(vec![*n])), - }) - .collect::>(); + if accumulators.is_empty() { + return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned()))); + } + let (_, (_, accs, _)) = accumulators.iter().next().unwrap(); + let mut acc_data_types: Vec = vec![]; - // 3. - groups.extend( - finalize_aggregation(accumulator_set, mode) - .map_err(DataFusionError::into_arrow_external_error)?, - ); + // Calculate number/shape of state arrays + match mode { + AggregateMode::Partial => { + for acc in accs.iter() { + let state = acc + .state() + .map_err(DataFusionError::into_arrow_external_error)?; + acc_data_types.push(state.len()); + } + } + AggregateMode::Final | AggregateMode::FinalPartitioned => { + acc_data_types = vec![1; accs.len()]; + } + } - Ok(groups) + let mut columns = (0..num_group_expr) + .map(|i| { + ScalarValue::iter_to_array(accumulators.into_iter().map( + |(_, (group_by_values, _, _))| ScalarValue::from(&group_by_values[i]), + )) }) - // 4. - .collect::>>>()?; + .collect::>>() + .map_err(|x| x.into_arrow_external_error())?; + + // add state / evaluated arrays + for (x, &state_len) in acc_data_types.iter().enumerate() { + for y in 0..state_len { + match mode { + AggregateMode::Partial => { + let res = ScalarValue::iter_to_array(accumulators.into_iter().map( + |(_, (_, accumulator, _))| { + let x = accumulator[x].state().unwrap(); + x[y].clone() + }, + )) + .map_err(DataFusionError::into_arrow_external_error)?; + + columns.push(res); + } + AggregateMode::Final | AggregateMode::FinalPartitioned => { + let res = ScalarValue::iter_to_array(accumulators.into_iter().map( + |(_, (_, accumulator, _))| accumulator[x].evaluate().unwrap(), + )) + .map_err(DataFusionError::into_arrow_external_error)?; + columns.push(res); + } + } + } + } - let batch = if !arrays.is_empty() { - // 5. - let columns = concatenate(arrays)?; + // cast output if needed (e.g. for types like Dictionary where + // the intermediate GroupByScalar type was not the same as the + // output + let columns = columns + .iter() + .zip(output_schema.fields().iter()) + .map(|(col, desired_field)| cast(col, desired_field.data_type())) + .collect::>>()?; - // cast output if needed (e.g. for types like Dictionary where - // the intermediate GroupByScalar type was not the same as the - // output - let columns = columns - .iter() - .zip(output_schema.fields().iter()) - .map(|(col, desired_field)| cast(col, desired_field.data_type())) - .collect::>>()?; - - RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)? - } else { - RecordBatch::new_empty(Arc::new(output_schema.to_owned())) - }; - Ok(batch) + RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns) } fn create_accumulators( diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index f3fa5b2c5de5..ac7deeed22c7 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -21,10 +21,10 @@ use crate::error::{DataFusionError, Result}; use arrow::{ array::*, datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type, - Int32Type, Int64Type, Int8Type, IntervalUnit, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit, TimeUnit, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }, }; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; @@ -311,7 +311,7 @@ impl ScalarValue { /// ]; /// /// // Build an Array from the list of ScalarValues - /// let array = ScalarValue::iter_to_array(scalars.iter()) + /// let array = ScalarValue::iter_to_array(scalars.into_iter()) /// .unwrap(); /// /// let expected: ArrayRef = std::sync::Arc::new( @@ -324,8 +324,8 @@ impl ScalarValue { /// /// assert_eq!(&array, &expected); /// ``` - pub fn iter_to_array<'a>( - scalars: impl IntoIterator, + pub fn iter_to_array( + scalars: impl IntoIterator, ) -> Result { let mut scalars = scalars.into_iter().peekable(); @@ -344,10 +344,10 @@ impl ScalarValue { macro_rules! build_array_primitive { ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ { - let values = scalars + let array = scalars .map(|sv| { if let ScalarValue::$SCALAR_TY(v) = sv { - Ok(*v) + Ok(v) } else { Err(DataFusionError::Internal(format!( "Inconsistent types in ScalarValue::iter_to_array. \ @@ -356,9 +356,8 @@ impl ScalarValue { ))) } }) - .collect::>>()?; + .collect::>()?; - let array: $ARRAY_TY = values.iter().collect(); Arc::new(array) } }}; @@ -369,7 +368,7 @@ impl ScalarValue { macro_rules! build_array_string { ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ { - let values = scalars + let array = scalars .map(|sv| { if let ScalarValue::$SCALAR_TY(v) = sv { Ok(v) @@ -381,19 +380,74 @@ impl ScalarValue { ))) } }) - .collect::>>()?; - - // it is annoying that one can not create - // StringArray et al directly from iter of &String, - // requiring this map to &str - let values = values.iter().map(|s| s.as_ref()); - - let array: $ARRAY_TY = values.collect(); + .collect::>()?; Arc::new(array) } }}; } + macro_rules! build_array_list_primitive { + ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ + Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( + scalars.into_iter().map(|x| match x { + ScalarValue::List(xs, _) => xs.map(|x| { + x.iter() + .map(|x| match x { + ScalarValue::$SCALAR_TY(i) => *i, + sv => panic!("Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", data_type, sv), + }) + .collect::>>() + }), + sv => panic!("Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", data_type, sv), + }), + )) + }}; + } + + macro_rules! build_array_list_string { + ($BUILDER:ident, $SCALAR_TY:ident) => {{ + let mut builder = ListBuilder::new($BUILDER::new(0)); + + for scalar in scalars.into_iter() { + match scalar { + ScalarValue::List(Some(xs), _) => { + for s in xs { + match s { + ScalarValue::$SCALAR_TY(Some(val)) => { + builder.values().append_value(val)?; + } + ScalarValue::$SCALAR_TY(None) => { + builder.values().append_null()?; + } + sv => return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected Utf8, got {:?}", + sv + ))), + } + } + builder.append(true)?; + } + ScalarValue::List(None, _) => { + builder.append(false)?; + } + sv => { + return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected List, got {:?}", + sv + ))) + } + } + } + + Arc::new(builder.finish()) + + }} + } + let array: ArrayRef = match &data_type { DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), @@ -430,6 +484,42 @@ impl ScalarValue { DataType::Interval(IntervalUnit::YearMonth) => { build_array_primitive!(IntervalYearMonthArray, IntervalYearMonth) } + DataType::List(fields) if fields.data_type() == &DataType::Int8 => { + build_array_list_primitive!(Int8Type, Int8, i8) + } + DataType::List(fields) if fields.data_type() == &DataType::Int16 => { + build_array_list_primitive!(Int16Type, Int16, i16) + } + DataType::List(fields) if fields.data_type() == &DataType::Int32 => { + build_array_list_primitive!(Int32Type, Int32, i32) + } + DataType::List(fields) if fields.data_type() == &DataType::Int64 => { + build_array_list_primitive!(Int64Type, Int64, i64) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { + build_array_list_primitive!(UInt8Type, UInt8, u8) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { + build_array_list_primitive!(UInt16Type, UInt16, u16) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { + build_array_list_primitive!(UInt32Type, UInt32, u32) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { + build_array_list_primitive!(UInt64Type, UInt64, u64) + } + DataType::List(fields) if fields.data_type() == &DataType::Float32 => { + build_array_list_primitive!(Float32Type, Float32, f32) + } + DataType::List(fields) if fields.data_type() == &DataType::Float64 => { + build_array_list_primitive!(Float64Type, Float64, f64) + } + DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { + build_array_list_string!(StringBuilder, Utf8) + } + DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { + build_array_list_string!(LargeStringBuilder, LargeUtf8) + } _ => { return Err(DataFusionError::Internal(format!( "Unsupported creation of {:?} array from ScalarValue {:?}", @@ -1102,7 +1192,7 @@ mod tests { let scalars: Vec<_> = $INPUT.iter().map(|v| ScalarValue::$SCALAR_T(*v)).collect(); - let array = ScalarValue::iter_to_array(scalars.iter()).unwrap(); + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); @@ -1119,7 +1209,7 @@ mod tests { .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_string()))) .collect(); - let array = ScalarValue::iter_to_array(scalars.iter()).unwrap(); + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); @@ -1136,7 +1226,7 @@ mod tests { .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_vec()))) .collect(); - let array = ScalarValue::iter_to_array(scalars.iter()).unwrap(); + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); let expected: $ARRAYTYPE = $INPUT.iter().map(|v| v.map(|v| v.to_vec())).collect(); @@ -1210,7 +1300,7 @@ mod tests { fn scalar_iter_to_array_empty() { let scalars = vec![] as Vec; - let result = ScalarValue::iter_to_array(scalars.iter()).unwrap_err(); + let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); assert!( result .to_string() @@ -1226,7 +1316,7 @@ mod tests { // If the scalar values are not all the correct type, error here let scalars: Vec = vec![Boolean(Some(true)), Int32(Some(5))]; - let result = ScalarValue::iter_to_array(scalars.iter()).unwrap_err(); + let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); assert!(result.to_string().contains("Inconsistent types in ScalarValue::iter_to_array. Expected Boolean, got Int32(5)"), "{}", result); }