Skip to content

Commit

Permalink
ARROW-11221: [Rust] DF Implement GROUP BY support for Float32/Float64
Browse files Browse the repository at this point in the history
Rust doesn't provide Eq, Hash for f32/f64 types inside stdlib, it's why I am using an external library called ordered-float which implements this traits. It's better to use external library instead of implementing own inside this repository.

Closes #9175 from ovr/issue-11221

Authored-by: Dmitry Patsura <zaets28rus@gmail.com>
Signed-off-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
ovr authored and alamb committed Jan 14, 2021
1 parent 96430cc commit 1393188
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 8 deletions.
1 change: 1 addition & 0 deletions rust/datafusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ tokio = { version = "0.2", features = ["macros", "blocking", "rt-core", "rt-thre
log = "^0.4"
md-5 = "^0.9.1"
sha2 = "^0.9.1"
ordered-float = "2.0"

[dev-dependencies]
rand = "0.8"
Expand Down
64 changes: 59 additions & 5 deletions rust/datafusion/src/physical_plan/group_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@

//! Defines scalars used to construct groups, ex. in GROUP BY clauses.
use ordered_float::OrderedFloat;
use std::convert::{From, TryFrom};

use crate::error::{DataFusionError, Result};
use crate::scalar::ScalarValue;

/// Enumeration of types that can be used in a GROUP BY expression (all primitives except
/// for floating point numerics)
/// Enumeration of types that can be used in a GROUP BY expression
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub(crate) enum GroupByScalar {
Float32(OrderedFloat<f32>),
Float64(OrderedFloat<f64>),
UInt8(u8),
UInt16(u16),
UInt32(u32),
Expand All @@ -44,6 +46,12 @@ impl TryFrom<&ScalarValue> for GroupByScalar {

fn try_from(scalar_value: &ScalarValue) -> Result<Self> {
Ok(match scalar_value {
ScalarValue::Float32(Some(v)) => {
GroupByScalar::Float32(OrderedFloat::from(*v))
}
ScalarValue::Float64(Some(v)) => {
GroupByScalar::Float64(OrderedFloat::from(*v))
}
ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v),
ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v),
ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v),
Expand All @@ -53,7 +61,9 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v),
ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v),
ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
ScalarValue::Int8(None)
ScalarValue::Float32(None)
| ScalarValue::Float64(None)
| ScalarValue::Int8(None)
| ScalarValue::Int16(None)
| ScalarValue::Int32(None)
| ScalarValue::Int64(None)
Expand All @@ -80,6 +90,8 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
impl From<&GroupByScalar> for ScalarValue {
fn from(group_by_scalar: &GroupByScalar) -> Self {
match group_by_scalar {
GroupByScalar::Float32(v) => ScalarValue::Float32(Some((*v).into())),
GroupByScalar::Float64(v) => ScalarValue::Float64(Some((*v).into())),
GroupByScalar::Int8(v) => ScalarValue::Int8(Some(*v)),
GroupByScalar::Int16(v) => ScalarValue::Int16(Some(*v)),
GroupByScalar::Int32(v) => ScalarValue::Int32(Some(*v)),
Expand All @@ -101,6 +113,48 @@ mod tests {

use crate::error::{DataFusionError, Result};

macro_rules! scalar_eq_test {
($TYPE:expr, $VALUE:expr) => {{
let scalar_value = $TYPE($VALUE);
let a = GroupByScalar::try_from(&scalar_value).unwrap();

let scalar_value = $TYPE($VALUE);
let b = GroupByScalar::try_from(&scalar_value).unwrap();

assert_eq!(a, b);
}};
}

#[test]
fn test_scalar_ne_non_std() -> Result<()> {
// Test only Scalars with non native Eq, Hash
scalar_eq_test!(ScalarValue::Float32, Some(1.0));
scalar_eq_test!(ScalarValue::Float64, Some(1.0));

Ok(())
}

macro_rules! scalar_ne_test {
($TYPE:expr, $LVALUE:expr, $RVALUE:expr) => {{
let scalar_value = $TYPE($LVALUE);
let a = GroupByScalar::try_from(&scalar_value).unwrap();

let scalar_value = $TYPE($RVALUE);
let b = GroupByScalar::try_from(&scalar_value).unwrap();

assert_ne!(a, b);
}};
}

#[test]
fn test_scalar_eq_non_std() -> Result<()> {
// Test only Scalars with non native Eq, Hash
scalar_ne_test!(ScalarValue::Float32, Some(1.0), Some(2.0));
scalar_ne_test!(ScalarValue::Float64, Some(1.0), Some(2.0));

Ok(())
}

#[test]
fn from_scalar_holding_none() -> Result<()> {
let scalar_value = ScalarValue::Int8(None);
Expand All @@ -120,14 +174,14 @@ mod tests {
#[test]
fn from_scalar_unsupported() -> Result<()> {
// Use any ScalarValue type not supported by GroupByScalar.
let scalar_value = ScalarValue::Float32(Some(1.1));
let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string()));
let result = GroupByScalar::try_from(&scalar_value);

match result {
Err(DataFusionError::Internal(error_message)) => assert_eq!(
error_message,
String::from(
"Cannot convert a ScalarValue with associated DataType Float32"
"Cannot convert a ScalarValue with associated DataType LargeUtf8"
)
),
_ => panic!("Unexpected result"),
Expand Down
21 changes: 19 additions & 2 deletions rust/datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ use arrow::error::{ArrowError, Result as ArrowResult};
use arrow::record_batch::RecordBatch;
use arrow::{
array::{
ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
Int8Array, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
},
compute,
};
Expand All @@ -48,6 +48,7 @@ use super::{
};
use ahash::RandomState;
use hashbrown::HashMap;
use ordered_float::OrderedFloat;

use arrow::array::{TimestampMicrosecondArray, TimestampNanosecondArray};
use async_trait::async_trait;
Expand Down Expand Up @@ -685,6 +686,14 @@ fn create_batch_from_map(
// 2.
let mut groups = (0..num_group_expr)
.map(|i| match &group_by_values[i] {
GroupByScalar::Float32(n) => {
Arc::new(Float32Array::from(vec![(*n).into()] as Vec<f32>))
as ArrayRef
}
GroupByScalar::Float64(n) => {
Arc::new(Float64Array::from(vec![(*n).into()] as Vec<f64>))
as ArrayRef
}
GroupByScalar::Int8(n) => {
Arc::new(Int8Array::from(vec![*n])) as ArrayRef
}
Expand Down Expand Up @@ -776,6 +785,14 @@ pub(crate) fn create_group_by_values(
for i in 0..group_by_keys.len() {
let col = &group_by_keys[i];
match col.data_type() {
DataType::Float32 => {
let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
vec[i] = GroupByScalar::Float32(OrderedFloat::from(array.value(row)))
}
DataType::Float64 => {
let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
vec[i] = GroupByScalar::Float64(OrderedFloat::from(array.value(row)))
}
DataType::UInt8 => {
let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
vec[i] = GroupByScalar::UInt8(array.value(row))
Expand Down
10 changes: 9 additions & 1 deletion rust/datafusion/src/physical_plan/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! into a set of partitions.
use arrow::{
array::{ArrayRef, UInt64Builder},
array::{ArrayRef, Float32Array, Float64Array, UInt64Builder},
compute,
};
use arrow::{
Expand Down Expand Up @@ -393,6 +393,14 @@ pub(crate) fn create_key(
vec.clear();
for col in group_by_keys {
match col.data_type() {
DataType::Float32 => {
let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
vec.extend_from_slice(&array.value(row).to_le_bytes());
}
DataType::Float64 => {
let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
vec.extend_from_slice(&array.value(row).to_le_bytes());
}
DataType::UInt8 => {
let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
vec.extend_from_slice(&array.value(row).to_le_bytes());
Expand Down
16 changes: 16 additions & 0 deletions rust/datafusion/tests/aggregate_floats.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
c1,c2
0.00001,0.000000000001
0.00002,0.000000000002
0.00002,0.000000000002
0.00003,0.000000000003
0.00003,0.000000000003
0.00003,0.000000000003
0.00004,0.000000000004
0.00004,0.000000000004
0.00004,0.000000000004
0.00004,0.000000000004
0.00005,0.000000000005
0.00005,0.000000000005
0.00005,0.000000000005
0.00005,0.000000000005
0.00005,0.000000000005
57 changes: 57 additions & 0 deletions rust/datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,48 @@ async fn csv_query_group_by_int_min_max() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn csv_query_group_by_float32() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_floats_csv(&mut ctx)?;

let sql =
"SELECT COUNT(*) as cnt, c1 FROM aggregate_floats GROUP BY c1 ORDER BY cnt DESC";
let actual = execute(&mut ctx, sql).await;

let expected = vec![
vec!["5", "0.00005"],
vec!["4", "0.00004"],
vec!["3", "0.00003"],
vec!["2", "0.00002"],
vec!["1", "0.00001"],
];
assert_eq!(expected, actual);

Ok(())
}

#[tokio::test]
async fn csv_query_group_by_float64() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_floats_csv(&mut ctx)?;

let sql =
"SELECT COUNT(*) as cnt, c2 FROM aggregate_floats GROUP BY c2 ORDER BY cnt DESC";
let actual = execute(&mut ctx, sql).await;

let expected = vec![
vec!["5", "0.000000000005"],
vec!["4", "0.000000000004"],
vec!["3", "0.000000000003"],
vec!["2", "0.000000000002"],
vec!["1", "0.000000000001"],
];
assert_eq!(expected, actual);

Ok(())
}

#[tokio::test]
async fn csv_query_group_by_two_columns() -> Result<()> {
let mut ctx = ExecutionContext::new();
Expand Down Expand Up @@ -1325,6 +1367,21 @@ fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
Ok(())
}

fn register_aggregate_floats_csv(ctx: &mut ExecutionContext) -> Result<()> {
// It's not possible to use aggregate_test_100, not enought similar values to test grouping on floats
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::Float32, false),
Field::new("c2", DataType::Float64, false),
]));

ctx.register_csv(
"aggregate_floats",
"tests/aggregate_floats.csv",
CsvReadOptions::new().schema(&schema),
)?;
Ok(())
}

fn register_alltypes_parquet(ctx: &mut ExecutionContext) {
let testdata = arrow::util::test_util::parquet_test_data();
ctx.register_parquet(
Expand Down

0 comments on commit 1393188

Please sign in to comment.