From b0b4dadde6ca496d3db068ca5692fcaf6805ea6a Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sun, 30 Jun 2024 00:29:54 +0800 Subject: [PATCH 1/5] tmp --- datafusion/common/src/scalar/mod.rs | 75 +++++++++++++++++++++++++++++ datafusion/sql/src/unparser/expr.rs | 2 +- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 5b9c4a223de6..2f3cc715a22c 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -247,6 +247,8 @@ pub enum ScalarValue { /// Represents a single element [`StructArray`] as an [`ArrayRef`]. See /// [`ScalarValue`] for examples of how to create instances of this type. Struct(Arc), + /// Represents a single element [`MapArray`] as an [`ArrayRef`]. + Map(Arc), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -370,6 +372,8 @@ impl PartialEq for ScalarValue { (LargeList(_), _) => false, (Struct(v1), Struct(v2)) => v1.eq(v2), (Struct(_), _) => false, + (Map(v1), Map(v2)) => v1.eq(v2), + (Map(_), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), @@ -502,6 +506,10 @@ impl PartialOrd for ScalarValue { partial_cmp_struct(struct_arr1, struct_arr2) } (Struct(_), _) => None, + (Map(map_arr1), Map(map_arr2)) => { + partial_cmp_map(map_arr1, map_arr2) + }, + (Map(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), @@ -631,6 +639,34 @@ fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option, m2: &Arc) -> Option { + if m1.len() != m2.len() { + return None; + } + + if m1.data_type() != m2.data_type() { + return None; + } + + for col_index in 0..m1.len() { + let arr1 = m1.column(col_index); + let arr2 = m2.column(col_index); + + let lt_res = arrow::compute::kernels::cmp::lt(arr1, arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(arr1, arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } + } + } + Some(Ordering::Equal) +} + impl Eq for ScalarValue {} //Float wrapper over f32/f64. Just because we cannot build std::hash::Hash for floats directly we have to do it through type wrapper @@ -696,6 +732,9 @@ impl std::hash::Hash for ScalarValue { Struct(arr) => { hash_nested_array(arr.to_owned() as ArrayRef, state); } + Map(arr) => { + hash_nested_array(arr.to_owned() as ArrayRef, state); + } Date32(v) => v.hash(state), Date64(v) => v.hash(state), Time32Second(v) => v.hash(state), @@ -1128,6 +1167,7 @@ impl ScalarValue { ScalarValue::LargeList(arr) => arr.data_type().to_owned(), ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), ScalarValue::Struct(arr) => arr.data_type().to_owned(), + ScalarValue::Map(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1395,6 +1435,7 @@ impl ScalarValue { ScalarValue::LargeList(arr) => arr.len() == arr.null_count(), ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), ScalarValue::Struct(arr) => arr.len() == arr.null_count(), + ScalarValue::Map(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -2161,6 +2202,9 @@ impl ScalarValue { ScalarValue::Struct(arr) => { Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } + ScalarValue::Map(arr) => { + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? + } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) } @@ -2790,6 +2834,9 @@ impl ScalarValue { ScalarValue::Struct(arr) => { Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } + ScalarValue::Map(arr) => { + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) + } ScalarValue::Date32(val) => { eq_array_primitive!(array, index, Date32Array, val)? } @@ -2925,6 +2972,7 @@ impl ScalarValue { ScalarValue::LargeList(arr) => arr.get_array_memory_size(), ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(arr) => arr.get_array_memory_size(), + ScalarValue::Map(arr) => arr.get_array_memory_size(), ScalarValue::Union(vals, fields, _mode) => { vals.as_ref() .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) @@ -3387,6 +3435,29 @@ impl fmt::Display for ScalarValue { .join(",") )? } + ScalarValue::Map(map_arr) => { + // ScalarValue Map should always have a single element + assert_eq!(map_arr.len(), 1); + + if map_arr.null_count() == map_arr.len() { + write!(f, "NULL")?; + return Ok(()); + } + + // TODO: check + write!( + f, + "{{{}}}", + map_arr.iter().map(|struct_array| { + if let Some(arr) = struct_array { + array_value_to_string(arr.column(0), 0)?.to_string() + } + else { + "NULL".to_string() + } + }).collect::>().join(",") + )? + } ScalarValue::Union(val, _fields, _mode) => match val { Some((id, val)) => write!(f, "{}:{}", id, val)?, None => write!(f, "NULL")?, @@ -3480,6 +3551,10 @@ impl fmt::Debug for ScalarValue { .join(",") ) } + ScalarValue::Map(_) => { + // TODO: + write!(f, "Map({self}") + }, ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index ad898de5987a..d101da8dd516 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1484,7 +1484,7 @@ mod tests { data_type: DataType::Decimal128(10, -2), }), r#"CAST(a AS DECIMAL(12,0))"#, - ), + ) ]; for (expr, expected) in tests { From c6f3b4caed58913afb8edbd2c6afc17acda89440 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Wed, 3 Jul 2024 00:20:17 +0800 Subject: [PATCH 2/5] introduce ScalarValue::Map --- datafusion/common/src/scalar/mod.rs | 112 ++++++++++++++---- .../proto/datafusion_common.proto | 3 +- datafusion/proto-common/src/from_proto/mod.rs | 4 +- .../proto-common/src/generated/pbjson.rs | 14 +++ .../proto-common/src/generated/prost.rs | 8 +- datafusion/proto-common/src/to_proto/mod.rs | 8 +- .../src/generated/datafusion_proto_common.rs | 6 +- .../tests/cases/roundtrip_logical_plan.rs | 24 ++++ datafusion/sql/src/unparser/expr.rs | 3 +- 9 files changed, 153 insertions(+), 29 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 2f3cc715a22c..768c446672ac 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -506,9 +506,7 @@ impl PartialOrd for ScalarValue { partial_cmp_struct(struct_arr1, struct_arr2) } (Struct(_), _) => None, - (Map(map_arr1), Map(map_arr2)) => { - partial_cmp_map(map_arr1, map_arr2) - }, + (Map(map_arr1), Map(map_arr2)) => partial_cmp_map(map_arr1, map_arr2), (Map(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, @@ -649,8 +647,8 @@ fn partial_cmp_map(m1: &Arc, m2: &Arc) -> Option { } for col_index in 0..m1.len() { - let arr1 = m1.column(col_index); - let arr2 = m2.column(col_index); + let arr1 = m1.entries().column(col_index); + let arr2 = m2.entries().column(col_index); let lt_res = arrow::compute::kernels::cmp::lt(arr1, arr2).ok()?; let eq_res = arrow::compute::kernels::cmp::eq(arr1, arr2).ok()?; @@ -3305,6 +3303,12 @@ impl TryFrom<&DataType> for ScalarValue { .to_owned() .into(), ), + DataType::Map(fields, sorted) => ScalarValue::Map( + new_null_array(&DataType::Map(fields.to_owned(), sorted.to_owned()), 1) + .as_map() + .to_owned() + .into(), + ), DataType::Union(fields, mode) => { ScalarValue::Union(None, fields.clone(), *mode) } @@ -3444,18 +3448,35 @@ impl fmt::Display for ScalarValue { return Ok(()); } - // TODO: check write!( f, - "{{{}}}", - map_arr.iter().map(|struct_array| { - if let Some(arr) = struct_array { - array_value_to_string(arr.column(0), 0)?.to_string() - } - else { - "NULL".to_string() - } - }).collect::>().join(",") + "[{}]", + map_arr + .iter() + .map(|struct_array| { + if let Some(arr) = struct_array { + let mut buffer = VecDeque::new(); + for i in 0..arr.len() { + let key = + array_value_to_string(arr.column(0), i).unwrap(); + let value = + array_value_to_string(arr.column(1), i).unwrap(); + buffer.push_back(format!("{}:{}", key, value)); + } + format!( + "{{{}}}", + buffer + .into_iter() + .collect::>() + .join(",") + .as_str() + ) + } else { + "NULL".to_string() + } + }) + .collect::>() + .join(",") )? } ScalarValue::Union(val, _fields, _mode) => match val { @@ -3551,10 +3572,35 @@ impl fmt::Debug for ScalarValue { .join(",") ) } - ScalarValue::Map(_) => { - // TODO: - write!(f, "Map({self}") - }, + ScalarValue::Map(map_arr) => { + // ScalarValue Map should always have a single element + assert_eq!(map_arr.len(), 1); + write!( + f, + "Map([{}])", + map_arr + .iter() + .map(|struct_array| { + if let Some(arr) = struct_array { + let buffer: Vec = (0..arr.len()) + .map(|i| { + let key = array_value_to_string(arr.column(0), i) + .unwrap(); + let value = + array_value_to_string(arr.column(1), i) + .unwrap(); + format!("{}:{}", key, value) + }) + .collect(); + format!("{{{}}}", buffer.join(",")) + } else { + "NULL".to_string() + } + }) + .collect::>() + .join(",") + ) + } ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), @@ -3643,7 +3689,7 @@ mod tests { use super::*; use crate::cast::{ - as_string_array, as_struct_array, as_uint32_array, as_uint64_array, + as_string_array, as_struct_array, as_uint32_array, as_uint64_array, as_map_array, }; use crate::assert_batches_eq; @@ -3657,6 +3703,32 @@ mod tests { use chrono::NaiveDate; use rand::Rng; + #[test] + fn test_scalar_value_from_for_map() { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(4); + let mut builder = MapBuilder::new(None, string_builder, int_builder); + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.keys().append_value("blogs"); + builder.values().append_value(2); + builder.keys().append_value("foo"); + builder.values().append_value(4); + builder.append(true).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + + let expected = builder.finish(); + + + let sv = ScalarValue::Map(Arc::new(expected.clone())); + let map_arr = sv.to_array().unwrap(); + let actual = as_map_array(&map_arr).unwrap(); + assert_eq!(actual, &expected); + } + #[test] fn test_scalar_value_from_for_struct() { let boolean = Arc::new(BooleanArray::from(vec![false])); diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 225bb9ddf661..e2a405595fb7 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -164,7 +164,7 @@ message Union{ repeated int32 type_ids = 3; } -// Used for List/FixedSizeList/LargeList/Struct +// Used for List/FixedSizeList/LargeList/Struct/Map message ScalarNestedValue { message Dictionary { bytes ipc_message = 1; @@ -266,6 +266,7 @@ message ScalarValue{ ScalarNestedValue list_value = 17; ScalarNestedValue fixed_size_list_value = 18; ScalarNestedValue struct_value = 32; + ScalarNestedValue map_value = 41; Decimal128 decimal128_value = 20; Decimal256 decimal256_value = 39; diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index de9fede9ee86..df673de4e119 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -380,7 +380,8 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::ListValue(v) | Value::FixedSizeListValue(v) | Value::LargeListValue(v) - | Value::StructValue(v) => { + | Value::StructValue(v) + | Value::MapValue(v) => { let protobuf::ScalarNestedValue { ipc_message, arrow_data, @@ -479,6 +480,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::StructValue(_) => { Self::Struct(arr.as_struct().to_owned().into()) } + Value::MapValue(_) => Self::Map(arr.as_map().to_owned().into()), _ => unreachable!(), } } diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index 3cf34aeb6d01..be3cc58b23df 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -6409,6 +6409,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::StructValue(v) => { struct_ser.serialize_field("structValue", v)?; } + scalar_value::Value::MapValue(v) => { + struct_ser.serialize_field("mapValue", v)?; + } scalar_value::Value::Decimal128Value(v) => { struct_ser.serialize_field("decimal128Value", v)?; } @@ -6525,6 +6528,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "fixedSizeListValue", "struct_value", "structValue", + "map_value", + "mapValue", "decimal128_value", "decimal128Value", "decimal256_value", @@ -6586,6 +6591,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { ListValue, FixedSizeListValue, StructValue, + MapValue, Decimal128Value, Decimal256Value, Date64Value, @@ -6646,6 +6652,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "listValue" | "list_value" => Ok(GeneratedField::ListValue), "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), "structValue" | "struct_value" => Ok(GeneratedField::StructValue), + "mapValue" | "map_value" => Ok(GeneratedField::MapValue), "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), @@ -6816,6 +6823,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("structValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::StructValue) +; + } + GeneratedField::MapValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("mapValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::MapValue) ; } GeneratedField::Decimal128Value => { diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index 57893321e665..b0674ff28d75 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -184,7 +184,7 @@ pub struct Union { #[prost(int32, repeated, tag = "3")] pub type_ids: ::prost::alloc::vec::Vec, } -/// Used for List/FixedSizeList/LargeList/Struct +/// Used for List/FixedSizeList/LargeList/Struct/Map #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarNestedValue { @@ -326,7 +326,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -380,6 +380,8 @@ pub mod scalar_value { FixedSizeListValue(super::ScalarNestedValue), #[prost(message, tag = "32")] StructValue(super::ScalarNestedValue), + #[prost(message, tag = "41")] + MapValue(super::ScalarNestedValue), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), #[prost(message, tag = "39")] @@ -581,7 +583,7 @@ pub struct CsvWriterOptions { /// Optional escape. Defaults to `'\\'` #[prost(string, tag = "10")] pub escape: ::prost::alloc::string::String, - /// Optional flag whether to double quote instead of escaping. Defaults to `true` + /// Optional flag whether to double quotes, instead of escaping. Defaults to `true` #[prost(bool, tag = "11")] pub double_quote: bool, } diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 877043f66809..705a479e0178 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -364,6 +364,9 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { ScalarValue::Struct(arr) => { encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) } + ScalarValue::Map(arr) => { + encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + } ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) } @@ -938,7 +941,7 @@ fn create_proto_scalar protobuf::scalar_value::Value>( Ok(protobuf::ScalarValue { value: Some(value) }) } -// ScalarValue::List / FixedSizeList / LargeList / Struct are serialized using +// ScalarValue::List / FixedSizeList / LargeList / Struct / Map are serialized using // Arrow IPC messages as a single column RecordBatch fn encode_scalar_nested_value( arr: ArrayRef, @@ -992,6 +995,9 @@ fn encode_scalar_nested_value( scalar_list_value, )), }), + ScalarValue::Map(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::MapValue(scalar_list_value)), + }), _ => unreachable!(), } } diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 875fe8992e90..b0674ff28d75 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -184,7 +184,7 @@ pub struct Union { #[prost(int32, repeated, tag = "3")] pub type_ids: ::prost::alloc::vec::Vec, } -/// Used for List/FixedSizeList/LargeList/Struct +/// Used for List/FixedSizeList/LargeList/Struct/Map #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarNestedValue { @@ -326,7 +326,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -380,6 +380,8 @@ pub mod scalar_value { FixedSizeListValue(super::ScalarNestedValue), #[prost(message, tag = "32")] StructValue(super::ScalarNestedValue), + #[prost(message, tag = "41")] + MapValue(super::ScalarNestedValue), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), #[prost(message, tag = "39")] diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 510ebe9a9801..1684757faba9 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1219,6 +1219,30 @@ fn round_trip_scalar_values() { ), ]))) .unwrap(), + ScalarValue::try_from(&DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int32, true), + Field::new("value", DataType::Utf8, false), + ])), + false, + )), + false, + )) + .unwrap(), + ScalarValue::try_from(&DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int32, true), + Field::new("value", DataType::Utf8, true), + ])), + false, + )), + true, + )) + .unwrap(), ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())), ScalarValue::FixedSizeBinary(0, None), ScalarValue::FixedSizeBinary(5, None), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index d101da8dd516..07af4bfeba9c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -952,6 +952,7 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: {v:?}"), + ScalarValue::Map(_) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(..) => not_impl_err!("Unsupported scalar: {v:?}"), } @@ -1484,7 +1485,7 @@ mod tests { data_type: DataType::Decimal128(10, -2), }), r#"CAST(a AS DECIMAL(12,0))"#, - ) + ), ]; for (expr, expected) in tests { From 584d0db924ce3381464f36b6092b6b48ce3ec9c0 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Wed, 3 Jul 2024 01:31:49 +0800 Subject: [PATCH 3/5] add display test --- datafusion/common/src/scalar/mod.rs | 50 ++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 383b12b1ee9a..2004d88b9d99 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -126,7 +126,7 @@ pub use struct_builder::ScalarStructBuilder; /// /// # Nested Types /// -/// `List` / `LargeList` / `FixedSizeList` / `Struct` are represented as a +/// `List` / `LargeList` / `FixedSizeList` / `Struct` / `Map` are represented as a /// single element array of the corresponding type. /// /// ## Example: Creating [`ScalarValue::Struct`] using [`ScalarStructBuilder`] @@ -3452,9 +3452,6 @@ impl fmt::Display for ScalarValue { )? } ScalarValue::Map(map_arr) => { - // ScalarValue Map should always have a single element - assert_eq!(map_arr.len(), 1); - if map_arr.null_count() == map_arr.len() { write!(f, "NULL")?; return Ok(()); @@ -3563,9 +3560,6 @@ impl fmt::Debug for ScalarValue { ScalarValue::List(_) => write!(f, "List({self})"), ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), ScalarValue::Struct(struct_arr) => { - // ScalarValue Struct should always have a single element - assert_eq!(struct_arr.len(), 1); - let columns = struct_arr.columns(); let fields = struct_arr.fields(); @@ -3601,7 +3595,7 @@ impl fmt::Debug for ScalarValue { let value = array_value_to_string(arr.column(1), i) .unwrap(); - format!("{}:{}", key, value) + format!("{key:?}:{value:?}") }) .collect(); format!("{{{}}}", buffer.join(",")) @@ -6346,6 +6340,46 @@ mod tests { assert_batches_eq!(&expected, &[batch]); } + #[test] + fn test_map_display() { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(4); + let mut builder = MapBuilder::new(None, string_builder, int_builder); + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.keys().append_value("blogs"); + builder.values().append_value(2); + builder.keys().append_value("foo"); + builder.values().append_value(4); + builder.append(true).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + + let map_value = ScalarValue::Map(Arc::new(builder.finish())); + + assert_eq!(map_value.to_string(), "[{joe:1},{blogs:2,foo:4},{},NULL]"); + + let ScalarValue::Map(arr) = map_value else { + panic!("Expected map"); + }; + + //verify compared to arrow display + let batch = RecordBatch::try_from_iter(vec![("m", arr as _)]).unwrap(); + let expected = [ + "+--------------------+", + "| m |", + "+--------------------+", + "| {joe: 1} |", + "| {blogs: 2, foo: 4} |", + "| {} |", + "| |", + "+--------------------+", + ]; + assert_batches_eq!(&expected, &[batch]); + } + #[test] fn test_build_timestamp_millisecond_list() { let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)]; From ef213857814c310b38301580b143a0c89b3792d3 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Wed, 3 Jul 2024 01:39:15 +0800 Subject: [PATCH 4/5] cargo fmt --- datafusion/common/src/scalar/mod.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 2004d88b9d99..d1e9db85e18a 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -3695,7 +3695,7 @@ mod tests { use super::*; use crate::cast::{ - as_string_array, as_struct_array, as_uint32_array, as_uint64_array, as_map_array, + as_map_array, as_string_array, as_struct_array, as_uint32_array, as_uint64_array, }; use crate::assert_batches_eq; @@ -3728,7 +3728,6 @@ mod tests { let expected = builder.finish(); - let sv = ScalarValue::Map(Arc::new(expected.clone())); let map_arr = sv.to_array().unwrap(); let actual = as_map_array(&map_arr).unwrap(); From 510d533f67733911a1463cd68ebfda75e232539c Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Wed, 3 Jul 2024 19:55:06 +0800 Subject: [PATCH 5/5] address comments and enhance tests --- datafusion/common/src/scalar/mod.rs | 12 +++++++--- .../tests/cases/roundtrip_logical_plan.rs | 24 ++++++++++++++++++- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index d1e9db85e18a..55ce76c4b939 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -3560,6 +3560,9 @@ impl fmt::Debug for ScalarValue { ScalarValue::List(_) => write!(f, "List({self})"), ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), ScalarValue::Struct(struct_arr) => { + // ScalarValue Struct should always have a single element + assert_eq!(struct_arr.len(), 1); + let columns = struct_arr.columns(); let fields = struct_arr.fields(); @@ -3579,8 +3582,6 @@ impl fmt::Debug for ScalarValue { ) } ScalarValue::Map(map_arr) => { - // ScalarValue Map should always have a single element - assert_eq!(map_arr.len(), 1); write!( f, "Map([{}])", @@ -6298,6 +6299,7 @@ mod tests { .unwrap(); assert_eq!(s.to_string(), "{a:1,b:}"); + assert_eq!(format!("{s:?}"), r#"Struct({a:1,b:})"#); let ScalarValue::Struct(arr) = s else { panic!("Expected struct"); @@ -6340,7 +6342,7 @@ mod tests { } #[test] - fn test_map_display() { + fn test_map_display_and_debug() { let string_builder = StringBuilder::new(); let int_builder = Int32Builder::with_capacity(4); let mut builder = MapBuilder::new(None, string_builder, int_builder); @@ -6359,6 +6361,10 @@ mod tests { let map_value = ScalarValue::Map(Arc::new(builder.finish())); assert_eq!(map_value.to_string(), "[{joe:1},{blogs:2,foo:4},{},NULL]"); + assert_eq!( + format!("{map_value:?}"), + r#"Map([{"joe":"1"},{"blogs":"2","foo":"4"},{},NULL])"# + ); let ScalarValue::Map(arr) = map_value else { panic!("Expected map"); diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index d3bdf3cc7311..7304b024e3a1 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -21,7 +21,9 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use std::vec; -use arrow::array::{ArrayRef, FixedSizeListArray}; +use arrow::array::{ + ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, StringBuilder, +}; use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, @@ -1270,6 +1272,7 @@ fn round_trip_scalar_values() { true, )) .unwrap(), + ScalarValue::Map(Arc::new(create_map_array_test_case())), ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())), ScalarValue::FixedSizeBinary(0, None), ScalarValue::FixedSizeBinary(5, None), @@ -1292,6 +1295,25 @@ fn round_trip_scalar_values() { } } +// create a map array [{joe:1}, {blogs:2, foo:4}, {}, null] for testing +fn create_map_array_test_case() -> MapArray { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(4); + let mut builder = MapBuilder::new(None, string_builder, int_builder); + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.keys().append_value("blogs"); + builder.values().append_value(2); + builder.keys().append_value("foo"); + builder.values().append_value(4); + builder.append(true).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + builder.finish() +} + #[test] fn round_trip_scalar_types() { let should_pass: Vec = vec![