Skip to content

Commit

Permalink
Fix ScalarValue handling of NULL values for ListArray (#7969)
Browse files Browse the repository at this point in the history
* Fix try_from_array data type for NULL value in ListArray

* Fix

* Explicitly assert the datatype

* For review
  • Loading branch information
viirya authored Oct 30, 2023
1 parent bb1d7f9 commit 448dff5
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 25 deletions.
125 changes: 100 additions & 25 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1312,10 +1312,11 @@ impl ScalarValue {
Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>(
scalars.into_iter().map(|x| match x {
ScalarValue::List(arr) => {
if arr.as_any().downcast_ref::<NullArray>().is_some() {
// `ScalarValue::List` contains a single element `ListArray`.
let list_arr = as_list_array(&arr);
if list_arr.is_null(0) {
None
} else {
let list_arr = as_list_array(&arr);
let primitive_arr =
list_arr.values().as_primitive::<$ARRAY_TY>();
Some(
Expand All @@ -1339,12 +1340,14 @@ impl ScalarValue {
for scalar in scalars.into_iter() {
match scalar {
ScalarValue::List(arr) => {
if arr.as_any().downcast_ref::<NullArray>().is_some() {
// `ScalarValue::List` contains a single element `ListArray`.
let list_arr = as_list_array(&arr);

if list_arr.is_null(0) {
builder.append(false);
continue;
}

let list_arr = as_list_array(&arr);
let string_arr = $STRING_ARRAY(list_arr.values());

for v in string_arr.iter() {
Expand Down Expand Up @@ -1699,15 +1702,16 @@ impl ScalarValue {

for scalar in scalars {
if let ScalarValue::List(arr) = scalar {
// i.e. NullArray(1)
if arr.as_any().downcast_ref::<NullArray>().is_some() {
// `ScalarValue::List` contains a single element `ListArray`.
let list_arr = as_list_array(&arr);

if list_arr.is_null(0) {
// Repeat previous offset index
offsets.push(0);

// Element is null
valid.append(false);
} else {
let list_arr = as_list_array(&arr);
let arr = list_arr.values().to_owned();
offsets.push(arr.len());
elements.push(arr);
Expand Down Expand Up @@ -2234,28 +2238,20 @@ impl ScalarValue {
}
DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8),
DataType::List(nested_type) => {
DataType::List(_) => {
let list_array = as_list_array(array);
let arr = match list_array.is_null(index) {
true => new_null_array(nested_type.data_type(), 0),
false => {
let nested_array = list_array.value(index);
Arc::new(wrap_into_list_array(nested_array))
}
};
let nested_array = list_array.value(index);
// Produces a single element `ListArray` with the value at `index`.
let arr = Arc::new(wrap_into_list_array(nested_array));

ScalarValue::List(arr)
}
// TODO: There is no test for FixedSizeList now, add it later
DataType::FixedSizeList(nested_type, _len) => {
DataType::FixedSizeList(_, _) => {
let list_array = as_fixed_size_list_array(array)?;
let arr = match list_array.is_null(index) {
true => new_null_array(nested_type.data_type(), 0),
false => {
let nested_array = list_array.value(index);
Arc::new(wrap_into_list_array(nested_array))
}
};
let nested_array = list_array.value(index);
// Produces a single element `ListArray` with the value at `index`.
let arr = Arc::new(wrap_into_list_array(nested_array));

ScalarValue::List(arr)
}
Expand Down Expand Up @@ -2944,8 +2940,15 @@ impl TryFrom<&DataType> for ScalarValue {
index_type.clone(),
Box::new(value_type.as_ref().try_into()?),
),
DataType::List(_) => ScalarValue::List(new_null_array(&DataType::Null, 0)),

// `ScalaValue::List` contains single element `ListArray`.
DataType::List(field) => ScalarValue::List(new_null_array(
&DataType::List(Arc::new(Field::new(
"item",
field.data_type().clone(),
true,
))),
1,
)),
DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()),
DataType::Null => ScalarValue::Null,
_ => {
Expand Down Expand Up @@ -3885,6 +3888,78 @@ mod tests {
);
}

#[test]
fn scalar_try_from_array_list_array_null() {
let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2)]),
None,
]);

let non_null_list_scalar = ScalarValue::try_from_array(&list, 0).unwrap();
let null_list_scalar = ScalarValue::try_from_array(&list, 1).unwrap();

let data_type =
DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));

assert_eq!(non_null_list_scalar.data_type(), data_type.clone());
assert_eq!(null_list_scalar.data_type(), data_type);
}

#[test]
fn scalar_try_from_list() {
let data_type =
DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
let data_type = &data_type;
let scalar: ScalarValue = data_type.try_into().unwrap();

let expected = ScalarValue::List(new_null_array(
&DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
1,
));

assert_eq!(expected, scalar)
}

#[test]
fn scalar_try_from_list_of_list() {
let data_type = DataType::List(Arc::new(Field::new(
"item",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
)));
let data_type = &data_type;
let scalar: ScalarValue = data_type.try_into().unwrap();

let expected = ScalarValue::List(new_null_array(
&DataType::List(Arc::new(Field::new(
"item",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
))),
1,
));

assert_eq!(expected, scalar)
}

#[test]
fn scalar_try_from_not_equal_list_nested_list() {
let list_data_type =
DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
let data_type = &list_data_type;
let list_scalar: ScalarValue = data_type.try_into().unwrap();

let nested_list_data_type = DataType::List(Arc::new(Field::new(
"item",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
)));
let data_type = &nested_list_data_type;
let nested_list_scalar: ScalarValue = data_type.try_into().unwrap();

assert_ne!(list_scalar, nested_list_scalar);
}

#[test]
fn scalar_try_from_dict_datatype() {
let data_type =
Expand Down
11 changes: 11 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,17 @@ AS VALUES
(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10)
;

query TTT
select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from arrays;
----
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })

# arrays table
query ???
select column1, column2, column3 from arrays;
Expand Down

0 comments on commit 448dff5

Please sign in to comment.