Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ScalarValue::Map #11224

Merged
merged 6 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 185 additions & 5 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ pub use struct_builder::ScalarStructBuilder;
///
/// # Nested Types
///
/// `List` / `LargeList` / `FixedSizeList` / `Struct` are represented as a
/// `List` / `LargeList` / `FixedSizeList` / `Struct` / `Map` are represented as a
/// single element array of the corresponding type.
///
/// ## Example: Creating [`ScalarValue::Struct`] using [`ScalarStructBuilder`]
Expand Down Expand Up @@ -247,6 +247,8 @@ pub enum ScalarValue {
/// Represents a single element [`StructArray`] as an [`ArrayRef`]. See
/// [`ScalarValue`] for examples of how to create instances of this type.
Struct(Arc<StructArray>),
/// Represents a single element [`MapArray`] as an [`ArrayRef`].
Map(Arc<MapArray>),
/// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01
Date32(Option<i32>),
/// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01
Expand Down Expand Up @@ -370,6 +372,8 @@ impl PartialEq for ScalarValue {
(LargeList(_), _) => false,
(Struct(v1), Struct(v2)) => v1.eq(v2),
(Struct(_), _) => false,
(Map(v1), Map(v2)) => v1.eq(v2),
(Map(_), _) => false,
(Date32(v1), Date32(v2)) => v1.eq(v2),
(Date32(_), _) => false,
(Date64(v1), Date64(v2)) => v1.eq(v2),
Expand Down Expand Up @@ -502,6 +506,8 @@ impl PartialOrd for ScalarValue {
partial_cmp_struct(struct_arr1, struct_arr2)
}
(Struct(_), _) => None,
(Map(map_arr1), Map(map_arr2)) => partial_cmp_map(map_arr1, map_arr2),
(Map(_), _) => None,
(Date32(v1), Date32(v2)) => v1.partial_cmp(v2),
(Date32(_), _) => None,
(Date64(v1), Date64(v2)) => v1.partial_cmp(v2),
Expand Down Expand Up @@ -631,6 +637,34 @@ fn partial_cmp_struct(s1: &Arc<StructArray>, s2: &Arc<StructArray>) -> Option<Or
Some(Ordering::Equal)
}

fn partial_cmp_map(m1: &Arc<MapArray>, m2: &Arc<MapArray>) -> Option<Ordering> {
if m1.len() != m2.len() {
return None;
}

if m1.data_type() != m2.data_type() {
return None;
}

for col_index in 0..m1.len() {
let arr1 = m1.entries().column(col_index);
let arr2 = m2.entries().column(col_index);

let lt_res = arrow::compute::kernels::cmp::lt(arr1, arr2).ok()?;
let eq_res = arrow::compute::kernels::cmp::eq(arr1, arr2).ok()?;

for j in 0..lt_res.len() {
if lt_res.is_valid(j) && lt_res.value(j) {
return Some(Ordering::Less);
}
if eq_res.is_valid(j) && !eq_res.value(j) {
return Some(Ordering::Greater);
}
}
}
Some(Ordering::Equal)
}

impl Eq for ScalarValue {}

//Float wrapper over f32/f64. Just because we cannot build std::hash::Hash for floats directly we have to do it through type wrapper
Expand Down Expand Up @@ -696,6 +730,9 @@ impl std::hash::Hash for ScalarValue {
Struct(arr) => {
hash_nested_array(arr.to_owned() as ArrayRef, state);
}
Map(arr) => {
hash_nested_array(arr.to_owned() as ArrayRef, state);
}
Date32(v) => v.hash(state),
Date64(v) => v.hash(state),
Time32Second(v) => v.hash(state),
Expand Down Expand Up @@ -1132,6 +1169,7 @@ impl ScalarValue {
ScalarValue::LargeList(arr) => arr.data_type().to_owned(),
ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(),
ScalarValue::Struct(arr) => arr.data_type().to_owned(),
ScalarValue::Map(arr) => arr.data_type().to_owned(),
ScalarValue::Date32(_) => DataType::Date32,
ScalarValue::Date64(_) => DataType::Date64,
ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second),
Expand Down Expand Up @@ -1403,6 +1441,7 @@ impl ScalarValue {
ScalarValue::LargeList(arr) => arr.len() == arr.null_count(),
ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(),
ScalarValue::Struct(arr) => arr.len() == arr.null_count(),
ScalarValue::Map(arr) => arr.len() == arr.null_count(),
ScalarValue::Date32(v) => v.is_none(),
ScalarValue::Date64(v) => v.is_none(),
ScalarValue::Time32Second(v) => v.is_none(),
Expand Down Expand Up @@ -2172,6 +2211,9 @@ impl ScalarValue {
ScalarValue::Struct(arr) => {
Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)?
}
ScalarValue::Map(arr) => {
Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)?
}
ScalarValue::Date32(e) => {
build_array_from_option!(Date32, Date32Array, e, size)
}
Expand Down Expand Up @@ -2802,6 +2844,9 @@ impl ScalarValue {
ScalarValue::Struct(arr) => {
Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index)
}
ScalarValue::Map(arr) => {
Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index)
}
ScalarValue::Date32(val) => {
eq_array_primitive!(array, index, Date32Array, val)?
}
Expand Down Expand Up @@ -2937,6 +2982,7 @@ impl ScalarValue {
ScalarValue::LargeList(arr) => arr.get_array_memory_size(),
ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(),
ScalarValue::Struct(arr) => arr.get_array_memory_size(),
ScalarValue::Map(arr) => arr.get_array_memory_size(),
ScalarValue::Union(vals, fields, _mode) => {
vals.as_ref()
.map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv))
Expand Down Expand Up @@ -3269,6 +3315,12 @@ impl TryFrom<&DataType> for ScalarValue {
.to_owned()
.into(),
),
DataType::Map(fields, sorted) => ScalarValue::Map(
new_null_array(&DataType::Map(fields.to_owned(), sorted.to_owned()), 1)
.as_map()
.to_owned()
.into(),
),
DataType::Union(fields, mode) => {
ScalarValue::Union(None, fields.clone(), *mode)
}
Expand Down Expand Up @@ -3399,6 +3451,43 @@ impl fmt::Display for ScalarValue {
.join(",")
)?
}
ScalarValue::Map(map_arr) => {
if map_arr.null_count() == map_arr.len() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this also have the assert like Struct?

// ScalarValue Map should always have a single element
assert_eq!(map_arr.len(), 1);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this also have the assert like Struct?

// ScalarValue Map should always have a single element
assert_eq!(map_arr.len(), 1);

No, MapArray is StructArray. It could contain more than one element.

or if not, then maybe https://github.com/apache/datafusion/pull/11224/files#diff-49e275af8f09685c7bbc491db8ab3b9479960878f42ac558ec0e3e39570590bdR3583 shoulnd't have it either ? 😅

I forgot to remove this. Thanks for reminding me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I add some tests to cover the Debug function in the latest commit.

write!(f, "NULL")?;
return Ok(());
}

write!(
f,
"[{}]",
map_arr
.iter()
.map(|struct_array| {
if let Some(arr) = struct_array {
let mut buffer = VecDeque::new();
for i in 0..arr.len() {
let key =
array_value_to_string(arr.column(0), i).unwrap();
let value =
array_value_to_string(arr.column(1), i).unwrap();
buffer.push_back(format!("{}:{}", key, value));
}
format!(
"{{{}}}",
buffer
.into_iter()
.collect::<Vec<_>>()
.join(",")
.as_str()
)
} else {
"NULL".to_string()
}
})
.collect::<Vec<_>>()
.join(",")
)?
}
ScalarValue::Union(val, _fields, _mode) => match val {
Some((id, val)) => write!(f, "{}:{}", id, val)?,
None => write!(f, "NULL")?,
Expand Down Expand Up @@ -3471,9 +3560,6 @@ impl fmt::Debug for ScalarValue {
ScalarValue::List(_) => write!(f, "List({self})"),
ScalarValue::LargeList(_) => write!(f, "LargeList({self})"),
ScalarValue::Struct(struct_arr) => {
// ScalarValue Struct should always have a single element
assert_eq!(struct_arr.len(), 1);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this removal intentional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, I removed the wrong lines. I'll revert it. Thanks.

let columns = struct_arr.columns();
let fields = struct_arr.fields();

Expand All @@ -3492,6 +3578,35 @@ impl fmt::Debug for ScalarValue {
.join(",")
)
}
ScalarValue::Map(map_arr) => {
// ScalarValue Map should always have a single element
assert_eq!(map_arr.len(), 1);
write!(
f,
"Map([{}])",
map_arr
.iter()
.map(|struct_array| {
if let Some(arr) = struct_array {
let buffer: Vec<String> = (0..arr.len())
.map(|i| {
let key = array_value_to_string(arr.column(0), i)
.unwrap();
let value =
array_value_to_string(arr.column(1), i)
.unwrap();
format!("{key:?}:{value:?}")
})
.collect();
format!("{{{}}}", buffer.join(","))
} else {
"NULL".to_string()
}
})
.collect::<Vec<_>>()
.join(",")
)
}
ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"),
ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"),
ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"),
Expand Down Expand Up @@ -3580,7 +3695,7 @@ mod tests {

use super::*;
use crate::cast::{
as_string_array, as_struct_array, as_uint32_array, as_uint64_array,
as_map_array, as_string_array, as_struct_array, as_uint32_array, as_uint64_array,
};

use crate::assert_batches_eq;
Expand All @@ -3594,6 +3709,31 @@ mod tests {
use chrono::NaiveDate;
use rand::Rng;

#[test]
fn test_scalar_value_from_for_map() {
let string_builder = StringBuilder::new();
let int_builder = Int32Builder::with_capacity(4);
let mut builder = MapBuilder::new(None, string_builder, int_builder);
builder.keys().append_value("joe");
builder.values().append_value(1);
builder.append(true).unwrap();

builder.keys().append_value("blogs");
builder.values().append_value(2);
builder.keys().append_value("foo");
builder.values().append_value(4);
builder.append(true).unwrap();
builder.append(true).unwrap();
builder.append(false).unwrap();

let expected = builder.finish();

let sv = ScalarValue::Map(Arc::new(expected.clone()));
let map_arr = sv.to_array().unwrap();
let actual = as_map_array(&map_arr).unwrap();
assert_eq!(actual, &expected);
}

#[test]
fn test_scalar_value_from_for_struct() {
let boolean = Arc::new(BooleanArray::from(vec![false]));
Expand Down Expand Up @@ -6199,6 +6339,46 @@ mod tests {
assert_batches_eq!(&expected, &[batch]);
}

#[test]
fn test_map_display() {
let string_builder = StringBuilder::new();
let int_builder = Int32Builder::with_capacity(4);
let mut builder = MapBuilder::new(None, string_builder, int_builder);
builder.keys().append_value("joe");
builder.values().append_value(1);
builder.append(true).unwrap();

builder.keys().append_value("blogs");
builder.values().append_value(2);
builder.keys().append_value("foo");
builder.values().append_value(4);
builder.append(true).unwrap();
builder.append(true).unwrap();
builder.append(false).unwrap();

let map_value = ScalarValue::Map(Arc::new(builder.finish()));

assert_eq!(map_value.to_string(), "[{joe:1},{blogs:2,foo:4},{},NULL]");

let ScalarValue::Map(arr) = map_value else {
panic!("Expected map");
};

//verify compared to arrow display
let batch = RecordBatch::try_from_iter(vec![("m", arr as _)]).unwrap();
let expected = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

"+--------------------+",
"| m |",
"+--------------------+",
"| {joe: 1} |",
"| {blogs: 2, foo: 4} |",
"| {} |",
"| |",
"+--------------------+",
];
assert_batches_eq!(&expected, &[batch]);
}

#[test]
fn test_build_timestamp_millisecond_list() {
let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)];
Expand Down
3 changes: 2 additions & 1 deletion datafusion/proto-common/proto/datafusion_common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ message Union{
repeated int32 type_ids = 3;
}

// Used for List/FixedSizeList/LargeList/Struct
// Used for List/FixedSizeList/LargeList/Struct/Map
message ScalarNestedValue {
message Dictionary {
bytes ipc_message = 1;
Expand Down Expand Up @@ -266,6 +266,7 @@ message ScalarValue{
ScalarNestedValue list_value = 17;
ScalarNestedValue fixed_size_list_value = 18;
ScalarNestedValue struct_value = 32;
ScalarNestedValue map_value = 41;

Decimal128 decimal128_value = 20;
Decimal256 decimal256_value = 39;
Expand Down
4 changes: 3 additions & 1 deletion datafusion/proto-common/src/from_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,8 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
Value::ListValue(v)
| Value::FixedSizeListValue(v)
| Value::LargeListValue(v)
| Value::StructValue(v) => {
| Value::StructValue(v)
| Value::MapValue(v) => {
let protobuf::ScalarNestedValue {
ipc_message,
arrow_data,
Expand Down Expand Up @@ -479,6 +480,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
Value::StructValue(_) => {
Self::Struct(arr.as_struct().to_owned().into())
}
Value::MapValue(_) => Self::Map(arr.as_map().to_owned().into()),
_ => unreachable!(),
}
}
Expand Down
14 changes: 14 additions & 0 deletions datafusion/proto-common/src/generated/pbjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6409,6 +6409,9 @@ impl serde::Serialize for ScalarValue {
scalar_value::Value::StructValue(v) => {
struct_ser.serialize_field("structValue", v)?;
}
scalar_value::Value::MapValue(v) => {
struct_ser.serialize_field("mapValue", v)?;
}
scalar_value::Value::Decimal128Value(v) => {
struct_ser.serialize_field("decimal128Value", v)?;
}
Expand Down Expand Up @@ -6525,6 +6528,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
"fixedSizeListValue",
"struct_value",
"structValue",
"map_value",
"mapValue",
"decimal128_value",
"decimal128Value",
"decimal256_value",
Expand Down Expand Up @@ -6586,6 +6591,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
ListValue,
FixedSizeListValue,
StructValue,
MapValue,
Decimal128Value,
Decimal256Value,
Date64Value,
Expand Down Expand Up @@ -6646,6 +6652,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
"listValue" | "list_value" => Ok(GeneratedField::ListValue),
"fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue),
"structValue" | "struct_value" => Ok(GeneratedField::StructValue),
"mapValue" | "map_value" => Ok(GeneratedField::MapValue),
"decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value),
"decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value),
"date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value),
Expand Down Expand Up @@ -6816,6 +6823,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
return Err(serde::de::Error::duplicate_field("structValue"));
}
value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::StructValue)
;
}
GeneratedField::MapValue => {
if value__.is_some() {
return Err(serde::de::Error::duplicate_field("mapValue"));
}
value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::MapValue)
;
}
GeneratedField::Decimal128Value => {
Expand Down
Loading