Skip to content

Commit

Permalink
ARROW-11221: [Rust] DF Implement GROUP BY support for Float32/Float64
Browse files Browse the repository at this point in the history
  • Loading branch information
ovr committed Jan 12, 2021
1 parent 8e5d09e commit f63a13c
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 7 deletions.
1 change: 1 addition & 0 deletions rust/datafusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ tokio = { version = "0.2", features = ["macros", "blocking", "rt-core", "rt-thre
log = "^0.4"
md-5 = "^0.9.1"
sha2 = "^0.9.1"
ordered-float = "2.0"

[dev-dependencies]
rand = "0.8"
Expand Down
64 changes: 59 additions & 5 deletions rust/datafusion/src/physical_plan/group_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@

//! Defines scalars used to construct groups, ex. in GROUP BY clauses.
use ordered_float::OrderedFloat;
use std::convert::{From, TryFrom};

use crate::error::{DataFusionError, Result};
use crate::scalar::ScalarValue;

/// Enumeration of types that can be used in a GROUP BY expression (all primitives except
/// for floating point numerics)
/// Enumeration of types that can be used in a GROUP BY expression
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub(crate) enum GroupByScalar {
Float32(OrderedFloat<f32>),
Float64(OrderedFloat<f64>),
UInt8(u8),
UInt16(u16),
UInt32(u32),
Expand All @@ -44,6 +46,12 @@ impl TryFrom<&ScalarValue> for GroupByScalar {

fn try_from(scalar_value: &ScalarValue) -> Result<Self> {
Ok(match scalar_value {
ScalarValue::Float32(Some(v)) => {
GroupByScalar::Float32(OrderedFloat::from(*v))
}
ScalarValue::Float64(Some(v)) => {
GroupByScalar::Float64(OrderedFloat::from(*v))
}
ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v),
ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v),
ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v),
Expand All @@ -53,7 +61,9 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v),
ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v),
ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
ScalarValue::Int8(None)
ScalarValue::Float32(None)
| ScalarValue::Float64(None)
| ScalarValue::Int8(None)
| ScalarValue::Int16(None)
| ScalarValue::Int32(None)
| ScalarValue::Int64(None)
Expand All @@ -80,6 +90,8 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
impl From<&GroupByScalar> for ScalarValue {
fn from(group_by_scalar: &GroupByScalar) -> Self {
match group_by_scalar {
GroupByScalar::Float32(v) => ScalarValue::Float32(Some((*v).into())),
GroupByScalar::Float64(v) => ScalarValue::Float64(Some((*v).into())),
GroupByScalar::Int8(v) => ScalarValue::Int8(Some(*v)),
GroupByScalar::Int16(v) => ScalarValue::Int16(Some(*v)),
GroupByScalar::Int32(v) => ScalarValue::Int32(Some(*v)),
Expand All @@ -101,6 +113,48 @@ mod tests {

use crate::error::{DataFusionError, Result};

macro_rules! scalar_eq_test {
($TYPE:expr, $VALUE:expr) => {{
let scalar_value = $TYPE($VALUE);
let a = GroupByScalar::try_from(&scalar_value).unwrap();

let scalar_value = $TYPE($VALUE);
let b = GroupByScalar::try_from(&scalar_value).unwrap();

assert_eq!(a, b);
}};
}

#[test]
fn test_scalar_ne_non_std() -> Result<()> {
// Test only Scalars with non native Eq, Hash
scalar_eq_test!(ScalarValue::Float32, Some(1.0));
scalar_eq_test!(ScalarValue::Float64, Some(1.0));

Ok(())
}

macro_rules! scalar_ne_test {
($TYPE:expr, $LVALUE:expr, $RVALUE:expr) => {{
let scalar_value = $TYPE($LVALUE);
let a = GroupByScalar::try_from(&scalar_value).unwrap();

let scalar_value = $TYPE($RVALUE);
let b = GroupByScalar::try_from(&scalar_value).unwrap();

assert_ne!(a, b);
}};
}

#[test]
fn test_scalar_eq_non_std() -> Result<()> {
// Test only Scalars with non native Eq, Hash
scalar_ne_test!(ScalarValue::Float32, Some(1.0), Some(2.0));
scalar_ne_test!(ScalarValue::Float64, Some(1.0), Some(2.0));

Ok(())
}

#[test]
fn from_scalar_holding_none() -> Result<()> {
let scalar_value = ScalarValue::Int8(None);
Expand All @@ -120,14 +174,14 @@ mod tests {
#[test]
fn from_scalar_unsupported() -> Result<()> {
// Use any ScalarValue type not supported by GroupByScalar.
let scalar_value = ScalarValue::Float32(Some(1.1));
let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string()));
let result = GroupByScalar::try_from(&scalar_value);

match result {
Err(DataFusionError::Internal(error_message)) => assert_eq!(
error_message,
String::from(
"Cannot convert a ScalarValue with associated DataType Float32"
"Cannot convert a ScalarValue with associated DataType LargeUtf8"
)
),
_ => panic!("Unexpected result"),
Expand Down
12 changes: 10 additions & 2 deletions rust/datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ use arrow::error::{ArrowError, Result as ArrowResult};
use arrow::record_batch::RecordBatch;
use arrow::{
array::{
ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
Int8Array, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
},
compute,
};
Expand Down Expand Up @@ -685,6 +685,14 @@ fn create_batch_from_map(
// 2.
let mut groups = (0..num_group_expr)
.map(|i| match &group_by_values[i] {
GroupByScalar::Float32(n) => {
Arc::new(Float32Array::from(vec![(*n).into()] as Vec<f32>))
as ArrayRef
}
GroupByScalar::Float64(n) => {
Arc::new(Float64Array::from(vec![(*n).into()] as Vec<f64>))
as ArrayRef
}
GroupByScalar::Int8(n) => {
Arc::new(Int8Array::from(vec![*n])) as ArrayRef
}
Expand Down

0 comments on commit f63a13c

Please sign in to comment.