diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml index 4f5f2dcf3efc3..426a43a24bb26 100644 --- a/rust/datafusion/Cargo.toml +++ b/rust/datafusion/Cargo.toml @@ -63,6 +63,7 @@ tokio = { version = "0.2", features = ["macros", "blocking", "rt-core", "rt-thre log = "^0.4" md-5 = "^0.9.1" sha2 = "^0.9.1" +ordered-float = "2.0" [dev-dependencies] rand = "0.8" diff --git a/rust/datafusion/src/physical_plan/group_scalar.rs b/rust/datafusion/src/physical_plan/group_scalar.rs index 295eb462cca5c..3d02a8dacddc5 100644 --- a/rust/datafusion/src/physical_plan/group_scalar.rs +++ b/rust/datafusion/src/physical_plan/group_scalar.rs @@ -17,15 +17,17 @@ //! Defines scalars used to construct groups, ex. in GROUP BY clauses. +use ordered_float::OrderedFloat; use std::convert::{From, TryFrom}; use crate::error::{DataFusionError, Result}; use crate::scalar::ScalarValue; -/// Enumeration of types that can be used in a GROUP BY expression (all primitives except -/// for floating point numerics) +/// Enumeration of types that can be used in a GROUP BY expression #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub(crate) enum GroupByScalar { + Float32(OrderedFloat<f32>), + Float64(OrderedFloat<f64>), UInt8(u8), UInt16(u16), UInt32(u32), @@ -44,6 +46,12 @@ impl TryFrom<&ScalarValue> for GroupByScalar { fn try_from(scalar_value: &ScalarValue) -> Result<Self> { Ok(match scalar_value { + ScalarValue::Float32(Some(v)) => { + GroupByScalar::Float32(OrderedFloat::from(*v)) + } + ScalarValue::Float64(Some(v)) => { + GroupByScalar::Float64(OrderedFloat::from(*v)) + } ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v), ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v), ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v), @@ -53,7 +61,9 @@ impl TryFrom<&ScalarValue> for GroupByScalar { ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v), ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v), ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())), - 
ScalarValue::Int8(None) + ScalarValue::Float32(None) + | ScalarValue::Float64(None) + | ScalarValue::Int8(None) | ScalarValue::Int16(None) | ScalarValue::Int32(None) | ScalarValue::Int64(None) @@ -80,6 +90,8 @@ impl TryFrom<&ScalarValue> for GroupByScalar { impl From<&GroupByScalar> for ScalarValue { fn from(group_by_scalar: &GroupByScalar) -> Self { match group_by_scalar { + GroupByScalar::Float32(v) => ScalarValue::Float32(Some((*v).into())), + GroupByScalar::Float64(v) => ScalarValue::Float64(Some((*v).into())), GroupByScalar::Int8(v) => ScalarValue::Int8(Some(*v)), GroupByScalar::Int16(v) => ScalarValue::Int16(Some(*v)), GroupByScalar::Int32(v) => ScalarValue::Int32(Some(*v)), @@ -101,6 +113,48 @@ mod tests { use crate::error::{DataFusionError, Result}; + macro_rules! scalar_eq_test { + ($TYPE:expr, $VALUE:expr) => {{ + let scalar_value = $TYPE($VALUE); + let a = GroupByScalar::try_from(&scalar_value).unwrap(); + + let scalar_value = $TYPE($VALUE); + let b = GroupByScalar::try_from(&scalar_value).unwrap(); + + assert_eq!(a, b); + }}; + } + + #[test] + fn test_scalar_float_eq() -> Result<()> { + // Test only Scalars with non native Eq, Hash + scalar_eq_test!(ScalarValue::Float32, Some(1.0)); + scalar_eq_test!(ScalarValue::Float64, Some(1.0)); + + Ok(()) + } + + macro_rules! 
scalar_ne_test { + ($TYPE:expr, $LVALUE:expr, $RVALUE:expr) => {{ + let scalar_value = $TYPE($LVALUE); + let a = GroupByScalar::try_from(&scalar_value).unwrap(); + + let scalar_value = $TYPE($RVALUE); + let b = GroupByScalar::try_from(&scalar_value).unwrap(); + + assert_ne!(a, b); + }}; + } + + #[test] + fn test_scalar_float_ne() -> Result<()> { + // Test only Scalars with non native Eq, Hash + scalar_ne_test!(ScalarValue::Float32, Some(1.0), Some(2.0)); + scalar_ne_test!(ScalarValue::Float64, Some(1.0), Some(2.0)); + + Ok(()) + } + #[test] fn from_scalar_holding_none() -> Result<()> { let scalar_value = ScalarValue::Int8(None); @@ -120,14 +174,14 @@ mod tests { #[test] fn from_scalar_unsupported() -> Result<()> { // Use any ScalarValue type not supported by GroupByScalar. - let scalar_value = ScalarValue::Float32(Some(1.1)); + let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string())); let result = GroupByScalar::try_from(&scalar_value); match result { Err(DataFusionError::Internal(error_message)) => assert_eq!( error_message, String::from( - "Cannot convert a ScalarValue with associated DataType Float32" + "Cannot convert a ScalarValue with associated DataType LargeUtf8" ) ), _ => panic!("Unexpected result"), diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index a902584b3919d..26e4ef0efc6ad 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -35,8 +35,8 @@ use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use arrow::{ array::{ - ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, + ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, compute, }; @@ -48,6 +48,7 @@ use super::{ }; use 
ahash::RandomState; use hashbrown::HashMap; +use ordered_float::OrderedFloat; use arrow::array::{TimestampMicrosecondArray, TimestampNanosecondArray}; use async_trait::async_trait; @@ -685,6 +686,14 @@ fn create_batch_from_map( // 2. let mut groups = (0..num_group_expr) .map(|i| match &group_by_values[i] { + GroupByScalar::Float32(n) => { + Arc::new(Float32Array::from(vec![(*n).into()] as Vec<f32>)) + as ArrayRef + } + GroupByScalar::Float64(n) => { + Arc::new(Float64Array::from(vec![(*n).into()] as Vec<f64>)) + as ArrayRef + } GroupByScalar::Int8(n) => { Arc::new(Int8Array::from(vec![*n])) as ArrayRef } @@ -776,6 +785,14 @@ pub(crate) fn create_group_by_values( for i in 0..group_by_keys.len() { let col = &group_by_keys[i]; match col.data_type() { + DataType::Float32 => { + let array = col.as_any().downcast_ref::<Float32Array>().unwrap(); + vec[i] = GroupByScalar::Float32(OrderedFloat::from(array.value(row))) + } + DataType::Float64 => { + let array = col.as_any().downcast_ref::<Float64Array>().unwrap(); + vec[i] = GroupByScalar::Float64(OrderedFloat::from(array.value(row))) + } DataType::UInt8 => { let array = col.as_any().downcast_ref::<UInt8Array>().unwrap(); vec[i] = GroupByScalar::UInt8(array.value(row)) diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs index 7ed06792b0305..362785d41f4bb 100644 --- a/rust/datafusion/src/physical_plan/hash_join.rs +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -19,7 +19,7 @@ //! into a set of partitions. 
use arrow::{ - array::{ArrayRef, UInt64Builder}, + array::{ArrayRef, Float32Array, Float64Array, UInt64Builder}, compute, }; use arrow::{ @@ -393,6 +393,14 @@ pub(crate) fn create_key( vec.clear(); for col in group_by_keys { match col.data_type() { + DataType::Float32 => { + let array = col.as_any().downcast_ref::<Float32Array>().unwrap(); + vec.extend_from_slice(&array.value(row).to_le_bytes()); + } + DataType::Float64 => { + let array = col.as_any().downcast_ref::<Float64Array>().unwrap(); + vec.extend_from_slice(&array.value(row).to_le_bytes()); + } DataType::UInt8 => { let array = col.as_any().downcast_ref::<UInt8Array>().unwrap(); vec.extend_from_slice(&array.value(row).to_le_bytes()); diff --git a/rust/datafusion/tests/aggregate_floats.csv b/rust/datafusion/tests/aggregate_floats.csv new file mode 100644 index 0000000000000..86f5750c58b9a --- /dev/null +++ b/rust/datafusion/tests/aggregate_floats.csv @@ -0,0 +1,16 @@ +c1,c2 +0.00001,0.000000000001 +0.00002,0.000000000002 +0.00002,0.000000000002 +0.00003,0.000000000003 +0.00003,0.000000000003 +0.00003,0.000000000003 +0.00004,0.000000000004 +0.00004,0.000000000004 +0.00004,0.000000000004 +0.00004,0.000000000004 +0.00005,0.000000000005 +0.00005,0.000000000005 +0.00005,0.000000000005 +0.00005,0.000000000005 +0.00005,0.000000000005 \ No newline at end of file diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 2cfe40538c468..80ab70f07f43a 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -345,6 +345,48 @@ async fn csv_query_group_by_int_min_max() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_group_by_float32() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_floats_csv(&mut ctx)?; + + let sql = + "SELECT COUNT(*) as cnt, c1 FROM aggregate_floats GROUP BY c1 ORDER BY cnt DESC"; + let actual = execute(&mut ctx, sql).await; + + let expected = vec![ + vec!["5", "0.00005"], + vec!["4", "0.00004"], + vec!["3", "0.00003"], + vec!["2", "0.00002"], + vec!["1", 
"0.00001"], + ]; + assert_eq!(expected, actual); + + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_float64() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_floats_csv(&mut ctx)?; + + let sql = + "SELECT COUNT(*) as cnt, c2 FROM aggregate_floats GROUP BY c2 ORDER BY cnt DESC"; + let actual = execute(&mut ctx, sql).await; + + let expected = vec![ + vec!["5", "0.000000000005"], + vec!["4", "0.000000000004"], + vec!["3", "0.000000000003"], + vec!["2", "0.000000000002"], + vec!["1", "0.000000000001"], + ]; + assert_eq!(expected, actual); + + Ok(()) +} + #[tokio::test] async fn csv_query_group_by_two_columns() -> Result<()> { let mut ctx = ExecutionContext::new(); @@ -1325,6 +1367,21 @@ fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { Ok(()) } +fn register_aggregate_floats_csv(ctx: &mut ExecutionContext) -> Result<()> { + // It's not possible to use aggregate_test_100, not enough similar values to test grouping on floats + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Float32, false), + Field::new("c2", DataType::Float64, false), + ])); + + ctx.register_csv( + "aggregate_floats", + "tests/aggregate_floats.csv", + CsvReadOptions::new().schema(&schema), + )?; + Ok(()) +} + fn register_alltypes_parquet(ctx: &mut ExecutionContext) { let testdata = arrow::util::test_util::parquet_test_data(); ctx.register_parquet(