Skip to content

Commit

Permalink
Support try_from_array and eq_array for ScalarValue::Union (#264)
Browse files Browse the repository at this point in the history
  • Loading branch information
joroKr21 authored Sep 2, 2024
1 parent 8a0ca9b commit 0ef8a63
Showing 1 changed file with 126 additions and 4 deletions.
130 changes: 126 additions & 4 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1338,7 +1338,10 @@ impl ScalarValue {
ScalarValue::DurationMillisecond(v) => v.is_none(),
ScalarValue::DurationMicrosecond(v) => v.is_none(),
ScalarValue::DurationNanosecond(v) => v.is_none(),
ScalarValue::Union(v, _, _) => v.is_none(),
ScalarValue::Union(v, _, _) => match v {
Some((_, v)) => v.is_null(),
None => true,
},
ScalarValue::Dictionary(_, v) => v.is_null(),
}
}
Expand Down Expand Up @@ -2500,7 +2503,13 @@ impl ScalarValue {
DataType::Duration(TimeUnit::Nanosecond) => {
typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)?
}

DataType::Union(fields, mode) => {
let array = as_union_array(array);
let ti = array.type_id(index);
let index = array.value_offset(index);
let value = ScalarValue::try_from_array(array.child(ti), index)?;
ScalarValue::Union(Some((ti, Box::new(value))), fields.clone(), *mode)
}
other => {
return _not_impl_err!(
"Can't create a scalar from array of type \"{other:?}\""
Expand Down Expand Up @@ -2727,8 +2736,15 @@ impl ScalarValue {
ScalarValue::DurationNanosecond(val) => {
eq_array_primitive!(array, index, DurationNanosecondArray, val)?
}
ScalarValue::Union(_, _, _) => {
return _not_impl_err!("Union is not supported yet")
ScalarValue::Union(value, _, _) => {
let array = as_union_array(array);
let ti = array.type_id(index);
let index = array.value_offset(index);
if let Some((ti_v, value)) = value {
ti_v == &ti && value.eq_array(array.child(ti), index)?
} else {
array.child(ti).is_null(index)
}
}
ScalarValue::Dictionary(key_type, v) => {
let (values_array, values_index) = match key_type.as_ref() {
Expand Down Expand Up @@ -5092,6 +5108,112 @@ mod tests {
assert_eq!(&array, &expected);
}

#[test]
/// Checks that every slot of a sparse `UnionArray` converts to the expected
/// `ScalarValue::Union` via `try_from_array`, that `eq_array` agrees with the
/// conversion, and that `is_null` reflects the nullness of the wrapped value.
fn test_scalar_union_sparse() {
    let field_a = Arc::new(Field::new("A", DataType::Int32, true));
    let field_b = Arc::new(Field::new("B", DataType::Boolean, true));
    let field_c = Arc::new(Field::new("C", DataType::Utf8, true));
    let fields = UnionFields::from_iter([(0, field_a), (1, field_b), (2, field_c)]);

    // No offsets buffer is passed below, so each child holds six slots, one
    // per entry in `type_ids`. Only slots 0..=2 carry a value; 3..=5 stay null.
    let mut values_a = vec![None; 6];
    values_a[0] = Some(42);
    let mut values_b = vec![None; 6];
    values_b[1] = Some(true);
    let mut values_c = vec![None; 6];
    values_c[2] = Some("foo");
    let children: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(values_a)),
        Arc::new(BooleanArray::from(values_b)),
        Arc::new(StringArray::from(values_c)),
    ];

    let type_ids = ScalarBuffer::from(vec![0, 1, 2, 0, 1, 2]);
    let array: ArrayRef = Arc::new(
        UnionArray::try_new(fields.clone(), type_ids, None, children)
            .expect("UnionArray"),
    );

    // (type id, wrapped value) expected at each union slot, in slot order.
    let expected = [
        (0, ScalarValue::from(42)),
        (1, ScalarValue::from(true)),
        (2, ScalarValue::from("foo")),
        (0, ScalarValue::Int32(None)),
        (1, ScalarValue::Boolean(None)),
        (2, ScalarValue::Utf8(None)),
    ];

    for (i, (ti, value)) in expected.into_iter().enumerate() {
        let is_null = value.is_null();
        let value = Some((ti, Box::new(value)));
        let expected = ScalarValue::Union(value, fields.clone(), UnionMode::Sparse);
        let actual = ScalarValue::try_from_array(&array, i).expect("try_from_array");

        assert_eq!(
            actual, expected,
            "[{i}] {actual} was not equal to {expected}"
        );

        assert!(
            expected.eq_array(&array, i).expect("eq_array"),
            "[{i}] {expected}.eq_array was false"
        );

        // Assert nullness in both directions: a non-null wrapped value must
        // not be reported as null, and a null wrapped value must be.
        assert_eq!(
            actual.is_null(),
            is_null,
            "[{i}] expected {actual}.is_null() to be {is_null}"
        );
    }
}

#[test]
/// Checks that every slot of a dense `UnionArray` converts to the expected
/// `ScalarValue::Union` via `try_from_array`, that `eq_array` agrees with the
/// conversion, and that `is_null` reflects the nullness of the wrapped value.
fn test_scalar_union_dense() {
    let field_a = Arc::new(Field::new("A", DataType::Int32, true));
    let field_b = Arc::new(Field::new("B", DataType::Boolean, true));
    let field_c = Arc::new(Field::new("C", DataType::Utf8, true));
    let fields = UnionFields::from_iter([(0, field_a), (1, field_b), (2, field_c)]);

    // Each child holds two entries: index 0 is a value, index 1 is null.
    let children: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![Some(42), None])),
        Arc::new(BooleanArray::from(vec![Some(true), None])),
        Arc::new(StringArray::from(vec![Some("foo"), None])),
    ];

    let type_ids = ScalarBuffer::from(vec![0, 1, 2, 0, 1, 2]);
    // Union slots 0..=2 point at child index 0 (the values) and slots 3..=5
    // point at child index 1 (the nulls).
    let offsets = ScalarBuffer::from(vec![0, 0, 0, 1, 1, 1]);
    let array: ArrayRef = Arc::new(
        UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)
            .expect("UnionArray"),
    );

    // (type id, wrapped value) expected at each union slot, in slot order.
    let expected = [
        (0, ScalarValue::from(42)),
        (1, ScalarValue::from(true)),
        (2, ScalarValue::from("foo")),
        (0, ScalarValue::Int32(None)),
        (1, ScalarValue::Boolean(None)),
        (2, ScalarValue::Utf8(None)),
    ];

    for (i, (ti, value)) in expected.into_iter().enumerate() {
        let is_null = value.is_null();
        let value = Some((ti, Box::new(value)));
        let expected = ScalarValue::Union(value, fields.clone(), UnionMode::Dense);
        let actual = ScalarValue::try_from_array(&array, i).expect("try_from_array");

        assert_eq!(
            actual, expected,
            "[{i}] {actual} was not equal to {expected}"
        );

        assert!(
            expected.eq_array(&array, i).expect("eq_array"),
            "[{i}] {expected}.eq_array was false"
        );

        // Assert nullness in both directions: a non-null wrapped value must
        // not be reported as null, and a null wrapped value must be.
        assert_eq!(
            actual.is_null(),
            is_null,
            "[{i}] expected {actual}.is_null() to be {is_null}"
        );
    }
}

#[test]
fn test_lists_in_struct() {
let field_a = Arc::new(Field::new("A", DataType::Utf8, false));
Expand Down

0 comments on commit 0ef8a63

Please sign in to comment.