Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-11221: [Rust] DF Implement GROUP BY support for Float32/Float64 #9175

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/datafusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ tokio = { version = "0.2", features = ["macros", "blocking", "rt-core", "rt-thre
log = "^0.4"
md-5 = "^0.9.1"
sha2 = "^0.9.1"
ordered-float = "2.0"

[dev-dependencies]
rand = "0.8"
Expand Down
64 changes: 59 additions & 5 deletions rust/datafusion/src/physical_plan/group_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@

//! Defines scalars used to construct groups, ex. in GROUP BY clauses.

use ordered_float::OrderedFloat;
use std::convert::{From, TryFrom};

use crate::error::{DataFusionError, Result};
use crate::scalar::ScalarValue;

/// Enumeration of types that can be used in a GROUP BY expression (all primitives except
/// for floating point numerics)
/// Enumeration of types that can be used in a GROUP BY expression
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub(crate) enum GroupByScalar {
Float32(OrderedFloat<f32>),
Float64(OrderedFloat<f64>),
UInt8(u8),
UInt16(u16),
UInt32(u32),
Expand All @@ -44,6 +46,12 @@ impl TryFrom<&ScalarValue> for GroupByScalar {

fn try_from(scalar_value: &ScalarValue) -> Result<Self> {
Ok(match scalar_value {
ScalarValue::Float32(Some(v)) => {
GroupByScalar::Float32(OrderedFloat::from(*v))
}
ScalarValue::Float64(Some(v)) => {
GroupByScalar::Float64(OrderedFloat::from(*v))
}
ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v),
ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v),
ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v),
Expand All @@ -53,7 +61,9 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v),
ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v),
ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
ScalarValue::Int8(None)
ScalarValue::Float32(None)
| ScalarValue::Float64(None)
| ScalarValue::Int8(None)
| ScalarValue::Int16(None)
| ScalarValue::Int32(None)
| ScalarValue::Int64(None)
Expand All @@ -80,6 +90,8 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
impl From<&GroupByScalar> for ScalarValue {
fn from(group_by_scalar: &GroupByScalar) -> Self {
match group_by_scalar {
GroupByScalar::Float32(v) => ScalarValue::Float32(Some((*v).into())),
GroupByScalar::Float64(v) => ScalarValue::Float64(Some((*v).into())),
GroupByScalar::Int8(v) => ScalarValue::Int8(Some(*v)),
GroupByScalar::Int16(v) => ScalarValue::Int16(Some(*v)),
GroupByScalar::Int32(v) => ScalarValue::Int32(Some(*v)),
Expand All @@ -101,6 +113,48 @@ mod tests {

use crate::error::{DataFusionError, Result};

/// Asserts that two `GroupByScalar`s built from the same `ScalarValue`
/// variant/value compare equal (exercises the non-native `Eq` impls,
/// e.g. `OrderedFloat`-backed floats).
macro_rules! scalar_eq_test {
    ($TYPE:expr, $VALUE:expr) => {{
        let lhs = GroupByScalar::try_from(&$TYPE($VALUE)).unwrap();
        let rhs = GroupByScalar::try_from(&$TYPE($VALUE)).unwrap();
        assert_eq!(lhs, rhs);
    }};
}

#[test]
// Renamed from `test_scalar_ne_non_std`: the body asserts *equality*
// (via `scalar_eq_test!`), so the old "ne" name was misleading — it was
// swapped with its inequality sibling.
fn test_scalar_eq_floats() -> Result<()> {
    // Only exercise scalars whose Eq/Hash are not native (floats wrapped
    // in OrderedFloat); equal inputs must yield equal GroupByScalars.
    scalar_eq_test!(ScalarValue::Float32, Some(1.0));
    scalar_eq_test!(ScalarValue::Float64, Some(1.0));

    Ok(())
}

/// Asserts that `GroupByScalar`s built from two *different* values of the
/// same `ScalarValue` variant compare unequal (exercises the non-native
/// `Eq` impls, e.g. `OrderedFloat`-backed floats).
macro_rules! scalar_ne_test {
    ($TYPE:expr, $LVALUE:expr, $RVALUE:expr) => {{
        let lhs = GroupByScalar::try_from(&$TYPE($LVALUE)).unwrap();
        let rhs = GroupByScalar::try_from(&$TYPE($RVALUE)).unwrap();
        assert_ne!(lhs, rhs);
    }};
}

#[test]
// Renamed from `test_scalar_eq_non_std`: the body asserts *inequality*
// (via `scalar_ne_test!`), so the old "eq" name was misleading — it was
// swapped with its equality sibling.
fn test_scalar_ne_floats() -> Result<()> {
    // Only exercise scalars whose Eq/Hash are not native (floats wrapped
    // in OrderedFloat); distinct inputs must yield unequal GroupByScalars.
    scalar_ne_test!(ScalarValue::Float32, Some(1.0), Some(2.0));
    scalar_ne_test!(ScalarValue::Float64, Some(1.0), Some(2.0));

    Ok(())
}

#[test]
fn from_scalar_holding_none() -> Result<()> {
let scalar_value = ScalarValue::Int8(None);
Expand All @@ -120,14 +174,14 @@ mod tests {
#[test]
fn from_scalar_unsupported() -> Result<()> {
// Use any ScalarValue type not supported by GroupByScalar.
let scalar_value = ScalarValue::Float32(Some(1.1));
let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string()));
let result = GroupByScalar::try_from(&scalar_value);

match result {
Err(DataFusionError::Internal(error_message)) => assert_eq!(
error_message,
String::from(
"Cannot convert a ScalarValue with associated DataType Float32"
"Cannot convert a ScalarValue with associated DataType LargeUtf8"
)
),
_ => panic!("Unexpected result"),
Expand Down
21 changes: 19 additions & 2 deletions rust/datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ use arrow::error::{ArrowError, Result as ArrowResult};
use arrow::record_batch::RecordBatch;
use arrow::{
array::{
ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
Int8Array, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
},
compute,
};
Expand All @@ -48,6 +48,7 @@ use super::{
};
use ahash::RandomState;
use hashbrown::HashMap;
use ordered_float::OrderedFloat;

use arrow::array::{TimestampMicrosecondArray, TimestampNanosecondArray};
use async_trait::async_trait;
Expand Down Expand Up @@ -685,6 +686,14 @@ fn create_batch_from_map(
// 2.
let mut groups = (0..num_group_expr)
.map(|i| match &group_by_values[i] {
GroupByScalar::Float32(n) => {
Arc::new(Float32Array::from(vec![(*n).into()] as Vec<f32>))
as ArrayRef
}
GroupByScalar::Float64(n) => {
Arc::new(Float64Array::from(vec![(*n).into()] as Vec<f64>))
as ArrayRef
}
GroupByScalar::Int8(n) => {
Arc::new(Int8Array::from(vec![*n])) as ArrayRef
}
Expand Down Expand Up @@ -776,6 +785,14 @@ pub(crate) fn create_group_by_values(
for i in 0..group_by_keys.len() {
let col = &group_by_keys[i];
match col.data_type() {
DataType::Float32 => {
Copy link
Contributor Author

@ovr ovr Jan 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated PR, forget about this place. I am testing it now with real DB example.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest adding an end-to-end sql test in https://github.com/apache/arrow/blob/master/rust/datafusion/tests/sql.rs to make sure the plumbing is all hooked up correctly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alamb I've added tests, but I found a strange bug probably with count.

012c1ac#diff-4b06103cf2132b1ab297fbe8cd42622ecbe1109ea26df7d8b358fa36d739549cR357

Thanks

let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
vec[i] = GroupByScalar::Float32(OrderedFloat::from(array.value(row)))
}
DataType::Float64 => {
let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
vec[i] = GroupByScalar::Float64(OrderedFloat::from(array.value(row)))
}
DataType::UInt8 => {
let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
vec[i] = GroupByScalar::UInt8(array.value(row))
Expand Down
10 changes: 9 additions & 1 deletion rust/datafusion/src/physical_plan/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! into a set of partitions.

use arrow::{
array::{ArrayRef, UInt64Builder},
array::{ArrayRef, Float32Array, Float64Array, UInt64Builder},
compute,
};
use arrow::{
Expand Down Expand Up @@ -393,6 +393,14 @@ pub(crate) fn create_key(
vec.clear();
for col in group_by_keys {
match col.data_type() {
DataType::Float32 => {
let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
vec.extend_from_slice(&array.value(row).to_le_bytes());
}
DataType::Float64 => {
let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
vec.extend_from_slice(&array.value(row).to_le_bytes());
}
DataType::UInt8 => {
let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
vec.extend_from_slice(&array.value(row).to_le_bytes());
Expand Down
16 changes: 16 additions & 0 deletions rust/datafusion/tests/aggregate_floats.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
c1,c2
0.00001,0.000000000001
0.00002,0.000000000002
0.00002,0.000000000002
0.00003,0.000000000003
0.00003,0.000000000003
0.00003,0.000000000003
0.00004,0.000000000004
0.00004,0.000000000004
0.00004,0.000000000004
0.00004,0.000000000004
0.00005,0.000000000005
0.00005,0.000000000005
0.00005,0.000000000005
0.00005,0.000000000005
0.00005,0.000000000005
57 changes: 57 additions & 0 deletions rust/datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,48 @@ async fn csv_query_group_by_int_min_max() -> Result<()> {
Ok(())
}

#[tokio::test]
/// End-to-end check that GROUP BY on a Float32 column produces one group
/// per distinct value, with the expected counts.
async fn csv_query_group_by_float32() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    register_aggregate_floats_csv(&mut ctx)?;

    let sql =
        "SELECT COUNT(*) as cnt, c1 FROM aggregate_floats GROUP BY c1 ORDER BY cnt DESC";
    let results = execute(&mut ctx, sql).await;

    // Five distinct c1 values, ordered by descending group size.
    let expected: Vec<Vec<&str>> = [
        ["5", "0.00005"],
        ["4", "0.00004"],
        ["3", "0.00003"],
        ["2", "0.00002"],
        ["1", "0.00001"],
    ]
    .iter()
    .map(|row| row.to_vec())
    .collect();
    assert_eq!(expected, results);

    Ok(())
}

#[tokio::test]
/// End-to-end check that GROUP BY on a Float64 column produces one group
/// per distinct value, with the expected counts.
async fn csv_query_group_by_float64() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    register_aggregate_floats_csv(&mut ctx)?;

    let sql =
        "SELECT COUNT(*) as cnt, c2 FROM aggregate_floats GROUP BY c2 ORDER BY cnt DESC";
    let results = execute(&mut ctx, sql).await;

    // Five distinct c2 values, ordered by descending group size.
    let expected: Vec<Vec<&str>> = [
        ["5", "0.000000000005"],
        ["4", "0.000000000004"],
        ["3", "0.000000000003"],
        ["2", "0.000000000002"],
        ["1", "0.000000000001"],
    ]
    .iter()
    .map(|row| row.to_vec())
    .collect();
    assert_eq!(expected, results);

    Ok(())
}

#[tokio::test]
async fn csv_query_group_by_two_columns() -> Result<()> {
let mut ctx = ExecutionContext::new();
Expand Down Expand Up @@ -1325,6 +1367,21 @@ fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
Ok(())
}

/// Registers `tests/aggregate_floats.csv` as table `aggregate_floats`
/// with a (Float32, Float64) schema, for float GROUP BY tests.
fn register_aggregate_floats_csv(ctx: &mut ExecutionContext) -> Result<()> {
    // aggregate_test_100 does not contain enough repeated float values to
    // exercise grouping, so a dedicated fixture file is used instead.
    let fields = vec![
        Field::new("c1", DataType::Float32, false),
        Field::new("c2", DataType::Float64, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    ctx.register_csv(
        "aggregate_floats",
        "tests/aggregate_floats.csv",
        CsvReadOptions::new().schema(&schema),
    )?;
    Ok(())
}

fn register_alltypes_parquet(ctx: &mut ExecutionContext) {
let testdata = arrow::util::test_util::parquet_test_data();
ctx.register_parquet(
Expand Down