From 856f14bbbef8d4ca35d6bb7810568337dbb98d63 Mon Sep 17 00:00:00 2001
From: tycho garen
Date: Wed, 27 Dec 2023 17:01:58 -0500
Subject: [PATCH 1/8] chore: iterative scalar processing

---
 crates/arrow_util/src/cast.rs                 |  201 ++
 crates/datafusion_ext/src/cast.rs             | 2489 +++++++++++++++++
 crates/datafusion_ext/src/lib.rs              |    1 +
 crates/sqlbuiltins/src/errors.rs              |   16 +
 .../sqlbuiltins/src/functions/scalars/mod.rs  |   19 +-
 5 files changed, 2715 insertions(+), 11 deletions(-)
 create mode 100644 crates/arrow_util/src/cast.rs
 create mode 100644 crates/datafusion_ext/src/cast.rs

diff --git a/crates/arrow_util/src/cast.rs b/crates/arrow_util/src/cast.rs
new file mode 100644
index 000000000..96e9002e2
--- /dev/null
+++ b/crates/arrow_util/src/cast.rs
@@ -0,0 +1,201 @@
+use datafusion::arrow::array::{Array, ArrayRef, Decimal128Array};
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::error::ArrowError;
+use datafusion::scalar::ScalarValue;
+
+pub fn try_cast(
+    array: &dyn Array,
+    op: &dyn Fn(ScalarValue) -> Result<ScalarValue, ArrowError>,
+) -> Result<ArrayRef, ArrowError> {
+    Ok(match array.data_type() {
+        // DataType::Null => ScalarValue::iter_to_array(
+        //     (0..array.len())
+        //         .map(|_| ScalarValue::Null)
+        //         .map(op)
+        //         .collect()?,
+        // )?,
+        DataType::Decimal128(precision, scale) => ScalarValue::iter_to_array(
+            array
+                .as_any()
+                .downcast_ref::<Decimal128Array>()
+                .ok_or(ArrowError::CastError("decimal128".to_string()))?
+                .iter()
+                .map(|v| ScalarValue::Decimal128(v, *precision, *scale))
+                .map(op)
+                .collect::<Result<Vec<_>, _>>()?,
+        )?,
+        // DataType::Decimal256(precision, scale) => {
+        //     ScalarValue::get_decimal_value_from_array(array, index, *precision, *scale)?
+        // }
+        // DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean),
+        // DataType::Float64 => typed_cast!(array, index, Float64Array, Float64),
+        // DataType::Float32 => typed_cast!(array, index, Float32Array, Float32),
+        // DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64),
+        // DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32),
+        // DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16),
+        // DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8),
+        // DataType::Int64 => typed_cast!(array, index, Int64Array, Int64),
+        // DataType::Int32 => typed_cast!(array, index, Int32Array, Int32),
+        // DataType::Int16 => typed_cast!(array, index, Int16Array, Int16),
+        // DataType::Int8 => typed_cast!(array, index, Int8Array, Int8),
+        // DataType::Binary => typed_cast!(array, index, BinaryArray, Binary),
+        // DataType::LargeBinary => {
+        //     typed_cast!(array, index, LargeBinaryArray, LargeBinary)
+        // }
+        // DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
+        // DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8),
+        // DataType::List(nested_type) => {
+        //     let list_array = as_list_array(array)?;
+        //     let value = match list_array.is_null(index) {
+        //         true => None,
+        //         false => {
+        //             let nested_array = list_array.value(index);
+        //             let scalar_vec = (0..nested_array.len())
+        //                 .map(|i| ScalarValue::try_from_array(&nested_array, i))
+        //                 .collect::<Result<Vec<_>>>()?;
+        //             Some(scalar_vec)
+        //         }
+        //     };
+        //     ScalarValue::new_list(value, nested_type.data_type().clone())
+        // }
+        // DataType::Date32 => {
+        //     typed_cast!(array, index, Date32Array, Date32)
+        // }
+        // DataType::Date64 => {
+        //     typed_cast!(array, index, Date64Array, Date64)
+        // }
+        // DataType::Time32(TimeUnit::Second) => {
+        //     typed_cast!(array, index, Time32SecondArray, Time32Second)
+        // }
+        // DataType::Time32(TimeUnit::Millisecond) => {
+        //     
typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond) + // } + // DataType::Time64(TimeUnit::Microsecond) => { + // typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond) + // } + // DataType::Time64(TimeUnit::Nanosecond) => { + // typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond) + // } + // DataType::Timestamp(TimeUnit::Second, tz_opt) => { + // typed_cast_tz!(array, index, TimestampSecondArray, TimestampSecond, tz_opt) + // } + // DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + // typed_cast_tz!( + // array, + // index, + // TimestampMillisecondArray, + // TimestampMillisecond, + // tz_opt + // ) + // } + // DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + // typed_cast_tz!( + // array, + // index, + // TimestampMicrosecondArray, + // TimestampMicrosecond, + // tz_opt + // ) + // } + // DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + // typed_cast_tz!( + // array, + // index, + // TimestampNanosecondArray, + // TimestampNanosecond, + // tz_opt + // ) + // } + // DataType::Dictionary(key_type, _) => { + // let (values_array, values_index) = match key_type.as_ref() { + // DataType::Int8 => get_dict_value::(array, index), + // DataType::Int16 => get_dict_value::(array, index), + // DataType::Int32 => get_dict_value::(array, index), + // DataType::Int64 => get_dict_value::(array, index), + // DataType::UInt8 => get_dict_value::(array, index), + // DataType::UInt16 => get_dict_value::(array, index), + // DataType::UInt32 => get_dict_value::(array, index), + // DataType::UInt64 => get_dict_value::(array, index), + // _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + // }; + // // look up the index in the values dictionary + // let value = match values_index { + // Some(values_index) => ScalarValue::try_from_array(values_array, values_index), + // // else entry was null, so return null + // None => values_array.data_type().try_into(), + // }?; + + // Self::Dictionary(key_type.clone(), Box::new(value)) + // } + // DataType::Struct(fields) => { + // let array = as_struct_array(array)?; + // let mut field_values: Vec = Vec::new(); + // for col_index in 0..array.num_columns() { + // let col_array = array.column(col_index); + // let col_scalar = ScalarValue::try_from_array(col_array, index)?; + // field_values.push(col_scalar); + // } + // Self::Struct(Some(field_values), fields.clone()) + // } + // DataType::FixedSizeList(nested_type, _len) => { + // let list_array = as_fixed_size_list_array(array)?; + // let value = match list_array.is_null(index) { + // true => None, + // false => { + // let nested_array = list_array.value(index); + // let scalar_vec = (0..nested_array.len()) + // .map(|i| ScalarValue::try_from_array(&nested_array, i)) + // .collect::>>()?; + // Some(scalar_vec) + // } + // }; + // ScalarValue::new_list(value, nested_type.data_type().clone()) + // } + // DataType::FixedSizeBinary(_) => { + // let array = as_fixed_size_binary_array(array)?; + // let size = match array.data_type() { + // DataType::FixedSizeBinary(size) => *size, + // _ => unreachable!(), + // }; + // ScalarValue::FixedSizeBinary( + // size, + // match array.is_null(index) { + // true => None, + // false => Some(array.value(index).into()), + // }, + // ) + // } + // DataType::Interval(IntervalUnit::DayTime) => { + // typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime) + // } + // DataType::Interval(IntervalUnit::YearMonth) => { + // typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth) + // } + 
// DataType::Interval(IntervalUnit::MonthDayNano) => { + // typed_cast!( + // array, + // index, + // IntervalMonthDayNanoArray, + // IntervalMonthDayNano + // ) + // } + + // DataType::Duration(TimeUnit::Second) => { + // typed_cast!(array, index, DurationSecondArray, DurationSecond) + // } + // DataType::Duration(TimeUnit::Millisecond) => { + // typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond) + // } + // DataType::Duration(TimeUnit::Microsecond) => { + // typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond) + // } + // DataType::Duration(TimeUnit::Nanosecond) => { + // typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond) + // } + other => { + return Err(ArrowError::CastError( + format!("Can't create a scalar from array of type \"{other:?}\"",).to_string(), + )); + } + }) +} diff --git a/crates/datafusion_ext/src/cast.rs b/crates/datafusion_ext/src/cast.rs new file mode 100644 index 000000000..710c6a402 --- /dev/null +++ b/crates/datafusion_ext/src/cast.rs @@ -0,0 +1,2489 @@ +use datafusion::arrow::array::{ + new_empty_array, new_null_array, Array, ArrayDataBuilder, ArrayRef, BinaryArray, BooleanArray, + BooleanBufferBuilder, BooleanBuilder, Date32Array, Date64Array, Decimal128Array, + Decimal256Array, DictionaryArray, FixedSizeBinaryArray, Float32Array, Float64Array, + GenericListArray, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, + LargeStringBuilder, ListArray, ListBuilder, PrimitiveArray, StringArray, StringBuilder, + StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, +}; +use datafusion::arrow::buffer::NullBuffer; +use datafusion::arrow::compute::nullif; +use datafusion::arrow::datatypes::ArrowNativeType; +use datafusion::arrow::datatypes::{ + ArrowDictionaryKeyType, DataType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, + Int8Type, IntervalUnit, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use datafusion::error::DataFusionError; +use datafusion::scalar::ScalarValue; +use std::sync::Arc; + +use crate::errors::ExtensionError; + +pub fn scalar_iter_to_array( + data_type: &DataType, + scalars: impl IntoIterator>, +) -> Result { + let scalars = scalars.into_iter(); + + /// Creates an array of $ARRAY_TY by unpacking values of + /// SCALAR_TY for primitive types + macro_rules! build_array_primitive { + ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + { + let array = scalars + .map(|sv| { + let sv = sv?; + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + todo!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, + sv + ) + } + }) + .collect::>()?; + Arc::new(array) + } + }}; + } + + macro_rules! build_array_primitive_tz { + ($ARRAY_TY:ident, $SCALAR_TY:ident, $TZ:expr) => {{ + { + let array = scalars + .map(|sv| { + let sv = sv?; + if let ScalarValue::$SCALAR_TY(v, _) = sv { + Ok(v) + } else { + todo!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, + sv + ) + } + }) + .collect::>()?; + Arc::new(array.with_timezone_opt($TZ.clone())) + } + }}; + } + + /// Creates an array of $ARRAY_TY by unpacking values of + /// SCALAR_TY for "string-like" types. + macro_rules! 
build_array_string { + ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + { + let array = scalars + .map(|sv| { + let sv = sv?; + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + return Err(ExtensionError::String(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv, + ))); + } + }) + .collect::>()?; + Arc::new(array) + } + }}; + } + + macro_rules! build_array_list_primitive { + ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ + Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( + scalars.into_iter().collect::,_>>()?.into_iter().map(|x| match x { + ScalarValue::List(xs, _) => xs.map(|x| { + x.iter() + .map(|x| match x { + ScalarValue::$SCALAR_TY(i) => *i, + sv => panic!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ), + }) + .collect::>>() + }), + sv => panic!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ), + }), + )) + }}; + } + + macro_rules! build_array_list_string { + ($BUILDER:ident, $SCALAR_TY:ident) => {{ + let mut builder = ListBuilder::new($BUILDER::new()); + for scalar in scalars.into_iter() { + match scalar? { + ScalarValue::List(Some(xs), _) => { + for s in xs { + match s { + ScalarValue::$SCALAR_TY(Some(val)) => { + builder.values().append_value(val); + } + ScalarValue::$SCALAR_TY(None) => { + builder.values().append_null(); + } + sv => { + todo!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected Utf8, got {:?}", + sv + ) + } + } + } + builder.append(true); + } + ScalarValue::List(None, _) => { + builder.append(false); + } + sv => { + return Err(ExtensionError::String(format!( + "Inconsistent types in ScalarValue::iter_to_array. 
\ + Expected List, got {:?}", + sv + ))); + } + } + } + Arc::new(builder.finish()) + }}; + } + + let array: ArrayRef = match &data_type { + DataType::Decimal128(precision, scale) => { + let decimal_array = iter_to_decimal_array(scalars, *precision, *scale)?; + Arc::new(decimal_array) + } + DataType::Decimal256(precision, scale) => { + let decimal_array = iter_to_decimal256_array(scalars, *precision, *scale)?; + Arc::new(decimal_array) + } + DataType::Null => iter_to_null_array(scalars), + DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), + DataType::Float32 => build_array_primitive!(Float32Array, Float32), + DataType::Float64 => build_array_primitive!(Float64Array, Float64), + DataType::Int8 => build_array_primitive!(Int8Array, Int8), + DataType::Int16 => build_array_primitive!(Int16Array, Int16), + DataType::Int32 => build_array_primitive!(Int32Array, Int32), + DataType::Int64 => build_array_primitive!(Int64Array, Int64), + DataType::UInt8 => build_array_primitive!(UInt8Array, UInt8), + DataType::UInt16 => build_array_primitive!(UInt16Array, UInt16), + DataType::UInt32 => build_array_primitive!(UInt32Array, UInt32), + DataType::UInt64 => build_array_primitive!(UInt64Array, UInt64), + DataType::Utf8 => build_array_string!(StringArray, Utf8), + DataType::LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8), + DataType::Binary => build_array_string!(BinaryArray, Binary), + DataType::LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary), + DataType::Date32 => build_array_primitive!(Date32Array, Date32), + DataType::Date64 => build_array_primitive!(Date64Array, Date64), + DataType::Time32(TimeUnit::Second) => { + build_array_primitive!(Time32SecondArray, Time32Second) + } + DataType::Time32(TimeUnit::Millisecond) => { + build_array_primitive!(Time32MillisecondArray, Time32Millisecond) + } + DataType::Time64(TimeUnit::Microsecond) => { + build_array_primitive!(Time64MicrosecondArray, Time64Microsecond) + } + DataType::Time64(TimeUnit::Nanosecond) => { + build_array_primitive!(Time64NanosecondArray, Time64Nanosecond) + } + DataType::Timestamp(TimeUnit::Second, tz) => { + build_array_primitive_tz!(TimestampSecondArray, TimestampSecond, tz) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + build_array_primitive_tz!(TimestampMillisecondArray, TimestampMillisecond, tz) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + build_array_primitive_tz!(TimestampMicrosecondArray, TimestampMicrosecond, tz) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + build_array_primitive_tz!(TimestampNanosecondArray, TimestampNanosecond, tz) + } + DataType::Interval(IntervalUnit::DayTime) => { + build_array_primitive!(IntervalDayTimeArray, IntervalDayTime) + } + DataType::Interval(IntervalUnit::YearMonth) => { + build_array_primitive!(IntervalYearMonthArray, IntervalYearMonth) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) + } + DataType::List(fields) if fields.data_type() == &DataType::Int8 => { + build_array_list_primitive!(Int8Type, Int8, i8) + } + DataType::List(fields) if fields.data_type() == &DataType::Int16 => { + build_array_list_primitive!(Int16Type, Int16, i16) + } + DataType::List(fields) if fields.data_type() == &DataType::Int32 => { + build_array_list_primitive!(Int32Type, Int32, i32) + } + DataType::List(fields) if fields.data_type() == &DataType::Int64 => { + build_array_list_primitive!(Int64Type, Int64, i64) + } + DataType::List(fields) if 
fields.data_type() == &DataType::UInt8 => { + build_array_list_primitive!(UInt8Type, UInt8, u8) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { + build_array_list_primitive!(UInt16Type, UInt16, u16) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { + build_array_list_primitive!(UInt32Type, UInt32, u32) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { + build_array_list_primitive!(UInt64Type, UInt64, u64) + } + DataType::List(fields) if fields.data_type() == &DataType::Float32 => { + build_array_list_primitive!(Float32Type, Float32, f32) + } + DataType::List(fields) if fields.data_type() == &DataType::Float64 => { + build_array_list_primitive!(Float64Type, Float64, f64) + } + DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { + build_array_list_string!(StringBuilder, Utf8) + } + DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { + build_array_list_string!(LargeStringBuilder, LargeUtf8) + } + DataType::List(_) => { + // Fallback case handling homogeneous lists with any ScalarValue element type + let list_array = iter_to_array_list(scalars, &data_type)?; + Arc::new(list_array) + } + DataType::Struct(fields) => { + // Initialize a Vector to store the ScalarValues for each column + let mut columns: Vec> = + (0..fields.len()).map(|_| Vec::new()).collect(); + + // null mask + let mut null_mask_builder = BooleanBuilder::new(); + + // Iterate over scalars to populate the column scalars for each row + for scalar in scalars { + let scalar = scalar?; + if let ScalarValue::Struct(values, fields) = scalar { + match values { + Some(values) => { + // Push value for each field + for (column, value) in columns.iter_mut().zip(values) { + column.push(value.clone()); + } + null_mask_builder.append_value(false); + } + None => { + // Push NULL of the appropriate type for each field + for (column, field) in columns.iter_mut().zip(fields.as_ref()) { + column.push(ScalarValue::try_from(field.data_type())?); + } + null_mask_builder.append_value(true); + } + }; + } else { + return Err(ExtensionError::String(format!( + "Expected Struct but found: {scalar}" + ))); + }; + } + + // Call iter_to_array recursively to convert the scalars for each column into Arrow arrays + let field_values = fields + .iter() + .zip(columns) + .map(|(field, column)| Ok((field.clone(), ScalarValue::iter_to_array(column)?))) + .collect::, ExtensionError>>()?; + + let array = StructArray::from(field_values); + nullif(&array, &null_mask_builder.finish())? 
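            // Editor's sketch (not part of this patch): how the struct branch above behaves,
            // shown via the upstream `ScalarValue::iter_to_array`, which this function mirrors.
            // A NULL struct scalar contributes placeholder per-field values plus a `true` bit in
            // the null mask, and `nullif` then applies that mask to the assembled StructArray.
            // Only ScalarValue APIs already used by the tests below are assumed; the function
            // name is illustrative.
            fn struct_scalars_roundtrip() -> datafusion::error::Result<()> {
                use datafusion::arrow::array::Array;
                use datafusion::scalar::ScalarValue;

                // One populated struct row and one NULL struct row of the same shape.
                let row = ScalarValue::from(vec![("x", ScalarValue::from(1i64))]);
                let null_row = ScalarValue::try_from(&row.data_type())?;

                let array = ScalarValue::iter_to_array(vec![row, null_row])?;
                assert_eq!(array.len(), 2);
                Ok(())
            }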
+ } + DataType::Dictionary(key_type, value_type) => { + // create the values array + let value_scalars = scalars + .map(|scalar| +{let scalar = scalar?; +match scalar { + ScalarValue::Dictionary(inner_key_type, scalar) => { + if &inner_key_type == key_type { + Ok(*scalar) + } else { + panic!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})"); + } + } + _ => { + Err(ExtensionError::String(format!( + "Expected scalar of type {value_type} but found: {} {:?}", scalar,scalar + ))) + } + }}) + .collect::, ExtensionError>>()?; + + let values = ScalarValue::iter_to_array(value_scalars)?; + assert_eq!(values.data_type(), value_type.as_ref()); + + match key_type.as_ref() { + DataType::Int8 => dict_from_values::(values)?, + DataType::Int16 => dict_from_values::(values)?, + DataType::Int32 => dict_from_values::(values)?, + DataType::Int64 => dict_from_values::(values)?, + DataType::UInt8 => dict_from_values::(values)?, + DataType::UInt16 => dict_from_values::(values)?, + DataType::UInt32 => dict_from_values::(values)?, + DataType::UInt64 => dict_from_values::(values)?, + _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + } + } + DataType::FixedSizeBinary(size) => { + let array = scalars + .map(|sv| { + let sv = sv?; + if let ScalarValue::FixedSizeBinary(_, v) = sv { + Ok(v) + } else { + return Err(ExtensionError::String(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {data_type:?}, got {sv:?}" + ))); + } + }) + .collect::, ExtensionError>>()?; + let array = + FixedSizeBinaryArray::try_from_sparse_iter_with_size(array.into_iter(), *size)?; + Arc::new(array) + } + // explicitly enumerate unsupported types so newly added + // types must be aknowledged, Time32 and Time64 types are + // not supported if the TimeUnit is not valid (Time32 can + // only be used with Second and Millisecond, Time64 only + // with Microsecond and Nanosecond) + DataType::Float16 + | DataType::Time32(TimeUnit::Microsecond) + | DataType::Time32(TimeUnit::Nanosecond) + | DataType::Time64(TimeUnit::Second) + | DataType::Time64(TimeUnit::Millisecond) + | DataType::Duration(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) + | DataType::Union(_, _) + | DataType::Map(_, _) + | DataType::RunEndEncoded(_, _) => { + todo!("make a better error"); + } + }; + + Ok(array) +} + +fn dict_from_values( + values_array: ArrayRef, +) -> Result { + // Create a key array with `size` elements of 0..array_len for all + // non-null value elements + let key_array: PrimitiveArray = (0..values_array.len()) + .map(|index| { + if values_array.is_valid(index) { + let native_index = K::Native::from_usize(index).ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not create index of type {} from value {}", + K::DATA_TYPE, + index + )) + })?; + Ok(Some(native_index)) + } else { + Ok(None) + } + }) + .collect::, DataFusionError>>()? + .into_iter() + .collect(); + + // create a new DictionaryArray + // + // Note: this path could be made faster by using the ArrayData + // APIs and skipping validation, if it every comes up in + // performance traces. 
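    // Editor's sketch (not part of this patch): the key/value pairing that `dict_from_values`
    // builds, shown concretely for Int32 keys over string values. Rows 0 and 2 share the
    // dictionary entry "a". The function name is illustrative; `DictionaryArray::try_new` is
    // the same constructor used on the next line.
    fn tiny_dictionary() -> Result<(), datafusion::arrow::error::ArrowError> {
        use std::sync::Arc;

        use datafusion::arrow::array::{Array, ArrayRef, DictionaryArray, Int32Array, StringArray};
        use datafusion::arrow::datatypes::Int32Type;

        // Two distinct values, referenced by three row keys.
        let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
        let keys = Int32Array::from(vec![Some(0), Some(1), Some(0)]);

        let dict = DictionaryArray::<Int32Type>::try_new(keys, values)?;
        assert_eq!(dict.len(), 3);
        Ok(())
    }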
+ let dict_array = DictionaryArray::::try_new(key_array, values_array)?; + Ok(Arc::new(dict_array)) +} + +fn iter_to_array_list( + scalars: impl IntoIterator>, + data_type: &DataType, +) -> Result, ExtensionError> { + let mut offsets = Int32Array::builder(0); + offsets.append_value(0); + + let mut elements: Vec = Vec::new(); + let mut valid = BooleanBufferBuilder::new(0); + let mut flat_len = 0i32; + for scalar in scalars { + let scalar = scalar?; + if let ScalarValue::List(values, field) = scalar { + match values { + Some(values) => { + let element_array = if !values.is_empty() { + ScalarValue::iter_to_array(values)? + } else { + new_empty_array(field.data_type()) + }; + + // Add new offset index + flat_len += element_array.len() as i32; + offsets.append_value(flat_len); + + elements.push(element_array); + + // Element is valid + valid.append(true); + } + None => { + // Repeat previous offset index + offsets.append_value(flat_len); + + // Element is null + valid.append(false); + } + } + } else { + return Err(ExtensionError::String(format!( + "Expected ScalarValue::List element. Received {scalar:?}" + ))); + } + } + + // Concatenate element arrays to create single flat array + let element_arrays: Vec<&dyn Array> = elements.iter().map(|a| a.as_ref()).collect(); + let flat_array = match datafusion::arrow::compute::concat(&element_arrays) { + Ok(flat_array) => flat_array, + Err(err) => return Ok(Err(DataFusionError::ArrowError(err))?), + }; + + // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices + let offsets_array = offsets.finish(); + let array_data = ArrayDataBuilder::new(data_type.clone()) + .len(offsets_array.len() - 1) + .nulls(Some(NullBuffer::new(valid.finish()))) + .add_buffer(offsets_array.values().inner().clone()) + .add_child_data(flat_array.to_data()); + + let list_array = ListArray::from(array_data.build()?); + Ok(list_array) +} + +fn iter_to_decimal_array( + scalars: impl IntoIterator>, + precision: u8, + scale: i8, +) -> Result { + let array = scalars + .into_iter() + .map( + |element: Result| match element? { + ScalarValue::Decimal128(v1, _, _) => Ok(v1), + _ => unreachable!(), + }, + ) + .collect::>()? + .with_precision_and_scale(precision, scale)?; + Ok(array) +} + +fn iter_to_decimal256_array( + scalars: impl IntoIterator>, + precision: u8, + scale: i8, +) -> Result { + let array = scalars + .into_iter() + .map( + |element: Result| match element? { + ScalarValue::Decimal256(v1, _, _) => Ok(v1), + _ => unreachable!(), + }, + ) + .collect::>()? 
+ .with_precision_and_scale(precision, scale)?; + Ok(array) +} + +fn iter_to_null_array( + scalars: impl IntoIterator>, +) -> ArrayRef { + let length = + scalars + .into_iter() + .fold( + 0usize, + |r, element: Result| match element { + Ok(ScalarValue::Null) => r + 1, + _ => unreachable!(), + }, + ); + new_null_array(&DataType::Null, length) +} + +#[cfg(test)] +mod tests { + use std::cmp::Ordering; + use std::sync::Arc; + + use arrow::compute::kernels; + use arrow::compute::{concat, is_null}; + use arrow::datatypes::ArrowPrimitiveType; + use arrow::util::pretty::pretty_format_columns; + use arrow_array::ArrowNumericType; + use chrono::NaiveDate; + use rand::Rng; + + use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; + + use super::*; + + #[test] + fn scalar_add_trait_test() -> Result<()> { + let float_value = ScalarValue::Float64(Some(123.)); + let float_value_2 = ScalarValue::Float64(Some(123.)); + assert_eq!( + (float_value.add(&float_value_2))?, + ScalarValue::Float64(Some(246.)) + ); + assert_eq!( + (float_value.add(float_value_2))?, + ScalarValue::Float64(Some(246.)) + ); + Ok(()) + } + + #[test] + fn scalar_sub_trait_test() -> Result<()> { + let float_value = ScalarValue::Float64(Some(123.)); + let float_value_2 = ScalarValue::Float64(Some(123.)); + assert_eq!( + float_value.sub(&float_value_2)?, + ScalarValue::Float64(Some(0.)) + ); + assert_eq!( + float_value.sub(float_value_2)?, + ScalarValue::Float64(Some(0.)) + ); + Ok(()) + } + + #[test] + fn scalar_sub_trait_int32_test() -> Result<()> { + let int_value = ScalarValue::Int32(Some(42)); + let int_value_2 = ScalarValue::Int32(Some(100)); + assert_eq!(int_value.sub(&int_value_2)?, ScalarValue::Int32(Some(-58))); + assert_eq!(int_value_2.sub(int_value)?, ScalarValue::Int32(Some(58))); + Ok(()) + } + + #[test] + fn scalar_sub_trait_int32_overflow_test() { + let int_value = ScalarValue::Int32(Some(i32::MAX)); + let int_value_2 = ScalarValue::Int32(Some(i32::MIN)); + let err = int_value + .sub_checked(&int_value_2) + .unwrap_err() + .strip_backtrace(); + assert_eq!( + err, + "Arrow error: Compute error: Overflow happened on: 2147483647 - -2147483648" + ) + } + + #[test] + fn scalar_sub_trait_int64_test() -> Result<()> { + let int_value = ScalarValue::Int64(Some(42)); + let int_value_2 = ScalarValue::Int64(Some(100)); + assert_eq!(int_value.sub(&int_value_2)?, ScalarValue::Int64(Some(-58))); + assert_eq!(int_value_2.sub(int_value)?, ScalarValue::Int64(Some(58))); + Ok(()) + } + + #[test] + fn scalar_sub_trait_int64_overflow_test() { + let int_value = ScalarValue::Int64(Some(i64::MAX)); + let int_value_2 = ScalarValue::Int64(Some(i64::MIN)); + let err = int_value + .sub_checked(&int_value_2) + .unwrap_err() + .strip_backtrace(); + assert_eq!(err, "Arrow error: Compute error: Overflow happened on: 9223372036854775807 - -9223372036854775808") + } + + #[test] + fn scalar_add_overflow_test() -> Result<()> { + check_scalar_add_overflow::( + ScalarValue::Int8(Some(i8::MAX)), + ScalarValue::Int8(Some(i8::MAX)), + ); + check_scalar_add_overflow::( + ScalarValue::UInt8(Some(u8::MAX)), + ScalarValue::UInt8(Some(u8::MAX)), + ); + check_scalar_add_overflow::( + ScalarValue::Int16(Some(i16::MAX)), + ScalarValue::Int16(Some(i16::MAX)), + ); + check_scalar_add_overflow::( + ScalarValue::UInt16(Some(u16::MAX)), + ScalarValue::UInt16(Some(u16::MAX)), + ); + check_scalar_add_overflow::( + ScalarValue::Int32(Some(i32::MAX)), + ScalarValue::Int32(Some(i32::MAX)), + ); + check_scalar_add_overflow::( + ScalarValue::UInt32(Some(u32::MAX)), + 
ScalarValue::UInt32(Some(u32::MAX)), + ); + check_scalar_add_overflow::( + ScalarValue::Int64(Some(i64::MAX)), + ScalarValue::Int64(Some(i64::MAX)), + ); + check_scalar_add_overflow::( + ScalarValue::UInt64(Some(u64::MAX)), + ScalarValue::UInt64(Some(u64::MAX)), + ); + + Ok(()) + } + + // Verifies that ScalarValue has the same behavior with compute kernal when it overflows. + fn check_scalar_add_overflow(left: ScalarValue, right: ScalarValue) + where + T: ArrowNumericType, + { + let scalar_result = left.add_checked(&right); + + let left_array = left.to_array(); + let right_array = right.to_array(); + let arrow_left_array = left_array.as_primitive::(); + let arrow_right_array = right_array.as_primitive::(); + let arrow_result = kernels::numeric::add(arrow_left_array, arrow_right_array); + + assert_eq!(scalar_result.is_ok(), arrow_result.is_ok()); + } + + #[test] + fn test_interval_add_timestamp() -> Result<()> { + let interval = ScalarValue::IntervalMonthDayNano(Some(123)); + let timestamp = ScalarValue::TimestampNanosecond(Some(123), None); + let result = interval.add(×tamp)?; + let expect = timestamp.add(&interval)?; + assert_eq!(result, expect); + + let interval = ScalarValue::IntervalYearMonth(Some(123)); + let timestamp = ScalarValue::TimestampNanosecond(Some(123), None); + let result = interval.add(×tamp)?; + let expect = timestamp.add(&interval)?; + assert_eq!(result, expect); + + let interval = ScalarValue::IntervalDayTime(Some(123)); + let timestamp = ScalarValue::TimestampNanosecond(Some(123), None); + let result = interval.add(×tamp)?; + let expect = timestamp.add(&interval)?; + assert_eq!(result, expect); + Ok(()) + } + + #[test] + fn scalar_decimal_test() -> Result<()> { + let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1); + assert_eq!(DataType::Decimal128(10, 1), decimal_value.data_type()); + let try_into_value: i128 = decimal_value.clone().try_into().unwrap(); + assert_eq!(123_i128, try_into_value); + assert!(!decimal_value.is_null()); + let neg_decimal_value = decimal_value.arithmetic_negate()?; + match neg_decimal_value { + ScalarValue::Decimal128(v, _, _) => { + assert_eq!(-123, v.unwrap()); + } + _ => { + unreachable!(); + } + } + + // decimal scalar to array + let array = decimal_value.to_array(); + let array = as_decimal128_array(&array)?; + assert_eq!(1, array.len()); + assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); + assert_eq!(123i128, array.value(0)); + + // decimal scalar to array with size + let array = decimal_value.to_array_of_size(10); + let array_decimal = as_decimal128_array(&array)?; + assert_eq!(10, array.len()); + assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); + assert_eq!(123i128, array_decimal.value(0)); + assert_eq!(123i128, array_decimal.value(9)); + // test eq array + assert!(decimal_value.eq_array(&array, 1)); + assert!(decimal_value.eq_array(&array, 5)); + // test try from array + assert_eq!( + decimal_value, + ScalarValue::try_from_array(&array, 5).unwrap() + ); + + assert_eq!( + decimal_value, + ScalarValue::try_new_decimal128(123, 10, 1).unwrap() + ); + + // test compare + let left = ScalarValue::Decimal128(Some(123), 10, 2); + let right = ScalarValue::Decimal128(Some(124), 10, 2); + assert!(!left.eq(&right)); + let result = left < right; + assert!(result); + let result = left <= right; + assert!(result); + let right = ScalarValue::Decimal128(Some(124), 10, 3); + // make sure that two decimals with diff datatype can't be compared. 
+ let result = left.partial_cmp(&right); + assert_eq!(None, result); + + let decimal_vec = vec![ + ScalarValue::Decimal128(Some(1), 10, 2), + ScalarValue::Decimal128(Some(2), 10, 2), + ScalarValue::Decimal128(Some(3), 10, 2), + ]; + // convert the vec to decimal array and check the result + let array = ScalarValue::iter_to_array(decimal_vec).unwrap(); + assert_eq!(3, array.len()); + assert_eq!(DataType::Decimal128(10, 2), array.data_type().clone()); + + let decimal_vec = vec![ + ScalarValue::Decimal128(Some(1), 10, 2), + ScalarValue::Decimal128(Some(2), 10, 2), + ScalarValue::Decimal128(Some(3), 10, 2), + ScalarValue::Decimal128(None, 10, 2), + ]; + let array = ScalarValue::iter_to_array(decimal_vec).unwrap(); + assert_eq!(4, array.len()); + assert_eq!(DataType::Decimal128(10, 2), array.data_type().clone()); + + assert!(ScalarValue::try_new_decimal128(1, 10, 2) + .unwrap() + .eq_array(&array, 0)); + assert!(ScalarValue::try_new_decimal128(2, 10, 2) + .unwrap() + .eq_array(&array, 1)); + assert!(ScalarValue::try_new_decimal128(3, 10, 2) + .unwrap() + .eq_array(&array, 2)); + assert_eq!( + ScalarValue::Decimal128(None, 10, 2), + ScalarValue::try_from_array(&array, 3).unwrap() + ); + + Ok(()) + } + + #[test] + fn scalar_value_to_array_u64() -> Result<()> { + let value = ScalarValue::UInt64(Some(13u64)); + let array = value.to_array(); + let array = as_uint64_array(&array)?; + assert_eq!(array.len(), 1); + assert!(!array.is_null(0)); + assert_eq!(array.value(0), 13); + + let value = ScalarValue::UInt64(None); + let array = value.to_array(); + let array = as_uint64_array(&array)?; + assert_eq!(array.len(), 1); + assert!(array.is_null(0)); + Ok(()) + } + + #[test] + fn scalar_value_to_array_u32() -> Result<()> { + let value = ScalarValue::UInt32(Some(13u32)); + let array = value.to_array(); + let array = as_uint32_array(&array)?; + assert_eq!(array.len(), 1); + assert!(!array.is_null(0)); + assert_eq!(array.value(0), 13); + + let value = ScalarValue::UInt32(None); + let array = value.to_array(); + let array = as_uint32_array(&array)?; + assert_eq!(array.len(), 1); + assert!(array.is_null(0)); + Ok(()) + } + + #[test] + fn scalar_list_null_to_array() { + let list_array_ref = + ScalarValue::List(None, Arc::new(Field::new("item", DataType::UInt64, false))) + .to_array(); + let list_array = as_list_array(&list_array_ref).unwrap(); + + assert!(list_array.is_null(0)); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 0); + } + + #[test] + fn scalar_list_to_array() -> Result<()> { + let list_array_ref = ScalarValue::List( + Some(vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ]), + Arc::new(Field::new("item", DataType::UInt64, false)), + ) + .to_array(); + + let list_array = as_list_array(&list_array_ref)?; + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 3); + + let prim_array_ref = list_array.value(0); + let prim_array = as_uint64_array(&prim_array_ref)?; + assert_eq!(prim_array.len(), 3); + assert_eq!(prim_array.value(0), 100); + assert!(prim_array.is_null(1)); + assert_eq!(prim_array.value(2), 101); + Ok(()) + } + + /// Creates array directly and via ScalarValue and ensures they are the same + macro_rules! 
check_scalar_iter { + ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = $INPUT.iter().map(|v| ScalarValue::$SCALAR_T(*v)).collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); + + assert_eq!(&array, &expected); + }}; + } + + /// Creates array directly and via ScalarValue and ensures they are the same + /// but for variants that carry a timezone field. + macro_rules! check_scalar_iter_tz { + ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_T(*v, None)) + .collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); + + assert_eq!(&array, &expected); + }}; + } + + /// Creates array directly and via ScalarValue and ensures they + /// are the same, for string arrays + macro_rules! check_scalar_iter_string { + ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_string()))) + .collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); + + assert_eq!(&array, &expected); + }}; + } + + /// Creates array directly and via ScalarValue and ensures they + /// are the same, for binary arrays + macro_rules! check_scalar_iter_binary { + ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_vec()))) + .collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected: $ARRAYTYPE = $INPUT.iter().map(|v| v.map(|v| v.to_vec())).collect(); + + let expected: ArrayRef = Arc::new(expected); + + assert_eq!(&array, &expected); + }}; + } + + #[test] + // despite clippy claiming they are useless, the code doesn't compile otherwise. 
+ #[allow(clippy::useless_vec)] + fn scalar_iter_to_array_boolean() { + check_scalar_iter!(Boolean, BooleanArray, vec![Some(true), None, Some(false)]); + check_scalar_iter!(Float32, Float32Array, vec![Some(1.9), None, Some(-2.1)]); + check_scalar_iter!(Float64, Float64Array, vec![Some(1.9), None, Some(-2.1)]); + + check_scalar_iter!(Int8, Int8Array, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int16, Int16Array, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int32, Int32Array, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int64, Int64Array, vec![Some(1), None, Some(3)]); + + check_scalar_iter!(UInt8, UInt8Array, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt16, UInt16Array, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt32, UInt32Array, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt64, UInt64Array, vec![Some(1), None, Some(3)]); + + check_scalar_iter_tz!( + TimestampSecond, + TimestampSecondArray, + vec![Some(1), None, Some(3)] + ); + check_scalar_iter_tz!( + TimestampMillisecond, + TimestampMillisecondArray, + vec![Some(1), None, Some(3)] + ); + check_scalar_iter_tz!( + TimestampMicrosecond, + TimestampMicrosecondArray, + vec![Some(1), None, Some(3)] + ); + check_scalar_iter_tz!( + TimestampNanosecond, + TimestampNanosecondArray, + vec![Some(1), None, Some(3)] + ); + + check_scalar_iter_string!(Utf8, StringArray, vec![Some("foo"), None, Some("bar")]); + check_scalar_iter_string!( + LargeUtf8, + LargeStringArray, + vec![Some("foo"), None, Some("bar")] + ); + check_scalar_iter_binary!(Binary, BinaryArray, vec![Some(b"foo"), None, Some(b"bar")]); + check_scalar_iter_binary!( + LargeBinary, + LargeBinaryArray, + vec![Some(b"foo"), None, Some(b"bar")] + ); + } + + #[test] + fn scalar_iter_to_array_empty() { + let scalars = vec![] as Vec; + + let result = ScalarValue::iter_to_array(scalars).unwrap_err(); + assert!( + result + .to_string() + .contains("Empty iterator passed to ScalarValue::iter_to_array"), + "{}", + result + ); + } + + #[test] + fn scalar_iter_to_dictionary() { + fn make_val(v: Option) -> ScalarValue { + let key_type = DataType::Int32; + let value = ScalarValue::Utf8(v); + ScalarValue::Dictionary(Box::new(key_type), Box::new(value)) + } + + let scalars = [ + make_val(Some("Foo".into())), + make_val(None), + make_val(Some("Bar".into())), + ]; + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let array = as_dictionary_array::(&array).unwrap(); + let values_array = as_string_array(array.values()).unwrap(); + + let values = array + .keys_iter() + .map(|k| { + k.map(|k| { + assert!(values_array.is_valid(k)); + values_array.value(k) + }) + }) + .collect::>(); + + let expected = vec![Some("Foo"), None, Some("Bar")]; + assert_eq!(values, expected); + } + + #[test] + fn scalar_iter_to_array_mismatched_types() { + use ScalarValue::*; + // If the scalar values are not all the correct type, error here + let scalars = [Boolean(Some(true)), Int32(Some(5))]; + + let result = ScalarValue::iter_to_array(scalars).unwrap_err(); + assert!( + result.to_string().contains( + "Inconsistent types in ScalarValue::iter_to_array. 
Expected Boolean, got Int32(5)" + ), + "{}", + result + ); + } + + #[test] + fn scalar_try_from_array_null() { + let array = vec![Some(33), None].into_iter().collect::(); + let array: ArrayRef = Arc::new(array); + + assert_eq!( + ScalarValue::Int64(Some(33)), + ScalarValue::try_from_array(&array, 0).unwrap() + ); + assert_eq!( + ScalarValue::Int64(None), + ScalarValue::try_from_array(&array, 1).unwrap() + ); + } + + #[test] + fn scalar_try_from_dict_datatype() { + let data_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + let data_type = &data_type; + let expected = + ScalarValue::Dictionary(Box::new(DataType::Int8), Box::new(ScalarValue::Utf8(None))); + assert_eq!(expected, data_type.try_into().unwrap()) + } + + // this test fails on aarch, so don't run it there + #[cfg(not(target_arch = "aarch64"))] + #[test] + fn size_of_scalar() { + // Since ScalarValues are used in a non trivial number of places, + // making it larger means significant more memory consumption + // per distinct value. + // + // The alignment requirements differ across architectures and + // thus the size of the enum appears to as as well + + assert_eq!(std::mem::size_of::(), 48); + } + + #[test] + fn memory_size() { + let sv = ScalarValue::Binary(Some(Vec::with_capacity(10))); + assert_eq!(sv.size(), std::mem::size_of::() + 10,); + let sv_size = sv.size(); + + let mut v = Vec::with_capacity(10); + // do NOT clone `sv` here because this may shrink the vector capacity + v.push(sv); + assert_eq!(v.capacity(), 10); + assert_eq!( + ScalarValue::size_of_vec(&v), + std::mem::size_of::>() + + (9 * std::mem::size_of::()) + + sv_size, + ); + + let mut s = HashSet::with_capacity(0); + // do NOT clone `sv` here because this may shrink the vector capacity + s.insert(v.pop().unwrap()); + // hashsets may easily grow during insert, so capacity is dynamic + let s_capacity = s.capacity(); + assert_eq!( + ScalarValue::size_of_hashset(&s), + std::mem::size_of::>() + + ((s_capacity - 1) * std::mem::size_of::()) + + sv_size, + ); + } + + #[test] + fn scalar_eq_array() { + // Validate that eq_array has the same semantics as ScalarValue::eq + macro_rules! make_typed_vec { + ($INPUT:expr, $TYPE:ident) => {{ + $INPUT + .iter() + .map(|v| v.map(|v| v as $TYPE)) + .collect::>() + }}; + } + + let bool_vals = [Some(true), None, Some(false)]; + let f32_vals = [Some(-1.0), None, Some(1.0)]; + let f64_vals = make_typed_vec!(f32_vals, f64); + + let i8_vals = [Some(-1), None, Some(1)]; + let i16_vals = make_typed_vec!(i8_vals, i16); + let i32_vals = make_typed_vec!(i8_vals, i32); + let i64_vals = make_typed_vec!(i8_vals, i64); + + let u8_vals = [Some(0), None, Some(1)]; + let u16_vals = make_typed_vec!(u8_vals, u16); + let u32_vals = make_typed_vec!(u8_vals, u32); + let u64_vals = make_typed_vec!(u8_vals, u64); + + let str_vals = [Some("foo"), None, Some("bar")]; + + /// Test each value in `scalar` with the corresponding element + /// at `array`. Assumes each element is unique (aka not equal + /// with all other indexes) + #[derive(Debug)] + struct TestCase { + array: ArrayRef, + scalars: Vec, + } + + /// Create a test case for casing the input to the specified array type + macro_rules! 
make_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().collect::<$ARRAY_TY>()), + scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(), + } + }}; + + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident, $TZ:expr) => {{ + let tz = $TZ; + TestCase { + array: Arc::new($INPUT.iter().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(*v, tz.clone())) + .collect(), + } + }}; + } + + macro_rules! make_str_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string()))) + .collect(), + } + }}; + } + + macro_rules! make_binary_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.as_bytes().to_vec()))) + .collect(), + } + }}; + } + + /// create a test case for DictionaryArray<$INDEX_TY> + macro_rules! make_str_dict_test_case { + ($INPUT:expr, $INDEX_TY:ident) => {{ + TestCase { + array: Arc::new( + $INPUT + .iter() + .cloned() + .collect::>(), + ), + scalars: $INPUT + .iter() + .map(|v| { + ScalarValue::Dictionary( + Box::new($INDEX_TY::DATA_TYPE), + Box::new(ScalarValue::Utf8(v.map(|v| v.to_string()))), + ) + }) + .collect(), + } + }}; + } + + let cases = vec![ + make_test_case!(bool_vals, BooleanArray, Boolean), + make_test_case!(f32_vals, Float32Array, Float32), + make_test_case!(f64_vals, Float64Array, Float64), + make_test_case!(i8_vals, Int8Array, Int8), + make_test_case!(i16_vals, Int16Array, Int16), + make_test_case!(i32_vals, Int32Array, Int32), + make_test_case!(i64_vals, Int64Array, Int64), + make_test_case!(u8_vals, UInt8Array, UInt8), + make_test_case!(u16_vals, UInt16Array, UInt16), + make_test_case!(u32_vals, UInt32Array, UInt32), + make_test_case!(u64_vals, UInt64Array, UInt64), + make_str_test_case!(str_vals, StringArray, Utf8), + make_str_test_case!(str_vals, LargeStringArray, LargeUtf8), + make_binary_test_case!(str_vals, BinaryArray, Binary), + make_binary_test_case!(str_vals, LargeBinaryArray, LargeBinary), + make_test_case!(i32_vals, Date32Array, Date32), + make_test_case!(i64_vals, Date64Array, Date64), + make_test_case!(i32_vals, Time32SecondArray, Time32Second), + make_test_case!(i32_vals, Time32MillisecondArray, Time32Millisecond), + make_test_case!(i64_vals, Time64MicrosecondArray, Time64Microsecond), + make_test_case!(i64_vals, Time64NanosecondArray, Time64Nanosecond), + make_test_case!(i64_vals, TimestampSecondArray, TimestampSecond, None), + make_test_case!( + i64_vals, + TimestampSecondArray, + TimestampSecond, + Some("UTC".into()) + ), + make_test_case!( + i64_vals, + TimestampMillisecondArray, + TimestampMillisecond, + None + ), + make_test_case!( + i64_vals, + TimestampMillisecondArray, + TimestampMillisecond, + Some("UTC".into()) + ), + make_test_case!( + i64_vals, + TimestampMicrosecondArray, + TimestampMicrosecond, + None + ), + make_test_case!( + i64_vals, + TimestampMicrosecondArray, + TimestampMicrosecond, + Some("UTC".into()) + ), + make_test_case!( + i64_vals, + TimestampNanosecondArray, + TimestampNanosecond, + None + ), + make_test_case!( + i64_vals, + TimestampNanosecondArray, + TimestampNanosecond, + Some("UTC".into()) + ), + make_test_case!(i32_vals, IntervalYearMonthArray, IntervalYearMonth), + 
make_test_case!(i64_vals, IntervalDayTimeArray, IntervalDayTime), + make_str_dict_test_case!(str_vals, Int8Type), + make_str_dict_test_case!(str_vals, Int16Type), + make_str_dict_test_case!(str_vals, Int32Type), + make_str_dict_test_case!(str_vals, Int64Type), + make_str_dict_test_case!(str_vals, UInt8Type), + make_str_dict_test_case!(str_vals, UInt16Type), + make_str_dict_test_case!(str_vals, UInt32Type), + make_str_dict_test_case!(str_vals, UInt64Type), + ]; + + for case in cases { + println!("**** Test Case *****"); + let TestCase { array, scalars } = case; + println!("Input array type: {}", array.data_type()); + println!("Input scalars: {scalars:#?}"); + assert_eq!(array.len(), scalars.len()); + + for (index, scalar) in scalars.into_iter().enumerate() { + assert!( + scalar.eq_array(&array, index), + "Expected {scalar:?} to be equal to {array:?} at index {index}" + ); + + // test that all other elements are *not* equal + for other_index in 0..array.len() { + if index != other_index { + assert!( + !scalar.eq_array(&array, other_index), + "Expected {scalar:?} to be NOT equal to {array:?} at index {other_index}" + ); + } + } + } + } + } + + #[test] + fn scalar_partial_ordering() { + use ScalarValue::*; + + assert_eq!( + Int64(Some(33)).partial_cmp(&Int64(Some(0))), + Some(Ordering::Greater) + ); + assert_eq!( + Int64(Some(0)).partial_cmp(&Int64(Some(33))), + Some(Ordering::Less) + ); + assert_eq!( + Int64(Some(33)).partial_cmp(&Int64(Some(33))), + Some(Ordering::Equal) + ); + // For different data type, `partial_cmp` returns None. + assert_eq!(Int64(Some(33)).partial_cmp(&Int32(Some(33))), None); + assert_eq!(Int32(Some(33)).partial_cmp(&Int64(Some(33))), None); + + assert_eq!( + List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + Arc::new(Field::new("item", DataType::Int32, false)), + ) + .partial_cmp(&List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + Arc::new(Field::new("item", DataType::Int32, false)), + )), + Some(Ordering::Equal) + ); + + assert_eq!( + List( + Some(vec![Int32(Some(10)), Int32(Some(5))]), + Arc::new(Field::new("item", DataType::Int32, false)), + ) + .partial_cmp(&List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + Arc::new(Field::new("item", DataType::Int32, false)), + )), + Some(Ordering::Greater) + ); + + assert_eq!( + List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + Arc::new(Field::new("item", DataType::Int32, false)), + ) + .partial_cmp(&List( + Some(vec![Int32(Some(10)), Int32(Some(5))]), + Arc::new(Field::new("item", DataType::Int32, false)), + )), + Some(Ordering::Less) + ); + + // For different data type, `partial_cmp` returns None. + assert_eq!( + List( + Some(vec![Int64(Some(1)), Int64(Some(5))]), + Arc::new(Field::new("item", DataType::Int64, false)), + ) + .partial_cmp(&List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + Arc::new(Field::new("item", DataType::Int32, false)), + )), + None + ); + + assert_eq!( + ScalarValue::from(vec![ + ("A", ScalarValue::from(1.0)), + ("B", ScalarValue::from("Z")), + ]) + .partial_cmp(&ScalarValue::from(vec![ + ("A", ScalarValue::from(2.0)), + ("B", ScalarValue::from("A")), + ])), + Some(Ordering::Less) + ); + + // For different struct fields, `partial_cmp` returns None. 
+ assert_eq!( + ScalarValue::from(vec![ + ("A", ScalarValue::from(1.0)), + ("B", ScalarValue::from("Z")), + ]) + .partial_cmp(&ScalarValue::from(vec![ + ("a", ScalarValue::from(2.0)), + ("b", ScalarValue::from("A")), + ])), + None + ); + } + + #[test] + fn test_scalar_struct() { + let field_a = Arc::new(Field::new("A", DataType::Int32, false)); + let field_b = Arc::new(Field::new("B", DataType::Boolean, false)); + let field_c = Arc::new(Field::new("C", DataType::Utf8, false)); + + let field_e = Arc::new(Field::new("e", DataType::Int16, false)); + let field_f = Arc::new(Field::new("f", DataType::Int64, false)); + let field_d = Arc::new(Field::new( + "D", + DataType::Struct(vec![field_e.clone(), field_f.clone()].into()), + false, + )); + + let scalar = ScalarValue::Struct( + Some(vec![ + ScalarValue::Int32(Some(23)), + ScalarValue::Boolean(Some(false)), + ScalarValue::Utf8(Some("Hello".to_string())), + ScalarValue::from(vec![ + ("e", ScalarValue::from(2i16)), + ("f", ScalarValue::from(3i64)), + ]), + ]), + vec![ + field_a.clone(), + field_b.clone(), + field_c.clone(), + field_d.clone(), + ] + .into(), + ); + + // Check Display + assert_eq!( + format!("{scalar}"), + String::from("{A:23,B:false,C:Hello,D:{e:2,f:3}}") + ); + + // Check Debug + assert_eq!( + format!("{scalar:?}"), + String::from( + r#"Struct({A:Int32(23),B:Boolean(false),C:Utf8("Hello"),D:Struct({e:Int16(2),f:Int64(3)})})"# + ) + ); + + // Convert to length-2 array + let array = scalar.to_array_of_size(2); + + let expected = Arc::new(StructArray::from(vec![ + ( + field_a.clone(), + Arc::new(Int32Array::from(vec![23, 23])) as ArrayRef, + ), + ( + field_b.clone(), + Arc::new(BooleanArray::from(vec![false, false])) as ArrayRef, + ), + ( + field_c.clone(), + Arc::new(StringArray::from(vec!["Hello", "Hello"])) as ArrayRef, + ), + ( + field_d.clone(), + Arc::new(StructArray::from(vec![ + ( + field_e.clone(), + Arc::new(Int16Array::from(vec![2, 2])) as ArrayRef, + ), + ( + field_f.clone(), + Arc::new(Int64Array::from(vec![3, 3])) as ArrayRef, + ), + ])) as ArrayRef, + ), + ])) as ArrayRef; + + assert_eq!(&array, &expected); + + // Construct from second element of ArrayRef + let constructed = ScalarValue::try_from_array(&expected, 1).unwrap(); + assert_eq!(constructed, scalar); + + // None version + let none_scalar = ScalarValue::try_from(array.data_type()).unwrap(); + assert!(none_scalar.is_null()); + assert_eq!(format!("{none_scalar:?}"), String::from("Struct(NULL)")); + + // Construct with convenience From> + let constructed = ScalarValue::from(vec![ + ("A", ScalarValue::from(23)), + ("B", ScalarValue::from(false)), + ("C", ScalarValue::from("Hello")), + ( + "D", + ScalarValue::from(vec![ + ("e", ScalarValue::from(2i16)), + ("f", ScalarValue::from(3i64)), + ]), + ), + ]); + assert_eq!(constructed, scalar); + + // Build Array from Vec of structs + let scalars = vec![ + ScalarValue::from(vec![ + ("A", ScalarValue::from(23)), + ("B", ScalarValue::from(false)), + ("C", ScalarValue::from("Hello")), + ( + "D", + ScalarValue::from(vec![ + ("e", ScalarValue::from(2i16)), + ("f", ScalarValue::from(3i64)), + ]), + ), + ]), + ScalarValue::from(vec![ + ("A", ScalarValue::from(7)), + ("B", ScalarValue::from(true)), + ("C", ScalarValue::from("World")), + ( + "D", + ScalarValue::from(vec![ + ("e", ScalarValue::from(4i16)), + ("f", ScalarValue::from(5i64)), + ]), + ), + ]), + ScalarValue::from(vec![ + ("A", ScalarValue::from(-1000)), + ("B", ScalarValue::from(true)), + ("C", ScalarValue::from("!!!!!")), + ( + "D", + ScalarValue::from(vec![ + ("e", 
ScalarValue::from(6i16)), + ("f", ScalarValue::from(7i64)), + ]), + ), + ]), + ]; + let array = ScalarValue::iter_to_array(scalars).unwrap(); + + let expected = Arc::new(StructArray::from(vec![ + ( + field_a, + Arc::new(Int32Array::from(vec![23, 7, -1000])) as ArrayRef, + ), + ( + field_b, + Arc::new(BooleanArray::from(vec![false, true, true])) as ArrayRef, + ), + ( + field_c, + Arc::new(StringArray::from(vec!["Hello", "World", "!!!!!"])) as ArrayRef, + ), + ( + field_d, + Arc::new(StructArray::from(vec![ + ( + field_e, + Arc::new(Int16Array::from(vec![2, 4, 6])) as ArrayRef, + ), + ( + field_f, + Arc::new(Int64Array::from(vec![3, 5, 7])) as ArrayRef, + ), + ])) as ArrayRef, + ), + ])) as ArrayRef; + + assert_eq!(&array, &expected); + } + + #[test] + fn test_lists_in_struct() { + let field_a = Arc::new(Field::new("A", DataType::Utf8, false)); + let field_primitive_list = Arc::new(Field::new( + "primitive_list", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + false, + )); + + // Define primitive list scalars + let l0 = ScalarValue::List( + Some(vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ]), + Arc::new(Field::new("item", DataType::Int32, false)), + ); + + let l1 = ScalarValue::List( + Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), + Arc::new(Field::new("item", DataType::Int32, false)), + ); + + let l2 = ScalarValue::List( + Some(vec![ScalarValue::from(6i32)]), + Arc::new(Field::new("item", DataType::Int32, false)), + ); + + // Define struct scalars + let s0 = ScalarValue::from(vec![ + ("A", ScalarValue::Utf8(Some(String::from("First")))), + ("primitive_list", l0), + ]); + + let s1 = ScalarValue::from(vec![ + ("A", ScalarValue::Utf8(Some(String::from("Second")))), + ("primitive_list", l1), + ]); + + let s2 = ScalarValue::from(vec![ + ("A", ScalarValue::Utf8(Some(String::from("Third")))), + ("primitive_list", l2), + ]); + + // iter_to_array for struct scalars + let array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap(); + let array = as_struct_array(&array).unwrap(); + let expected = StructArray::from(vec![ + ( + field_a.clone(), + Arc::new(StringArray::from(vec!["First", "Second", "Third"])) as ArrayRef, + ), + ( + field_primitive_list.clone(), + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![Some(6)]), + ])), + ), + ]); + + assert_eq!(array, &expected); + + // Define list-of-structs scalars + let nl0 = ScalarValue::new_list(Some(vec![s0.clone(), s1.clone()]), s0.data_type()); + + let nl1 = ScalarValue::new_list(Some(vec![s2]), s0.data_type()); + + let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.data_type()); + // iter_to_array for list-of-struct + let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); + let array = as_list_array(&array).unwrap(); + + // Construct expected array with array builders + let field_a_builder = StringBuilder::with_capacity(4, 1024); + let primitive_value_builder = Int32Array::builder(8); + let field_primitive_list_builder = ListBuilder::new(primitive_value_builder); + + let element_builder = StructBuilder::new( + vec![field_a, field_primitive_list], + vec![ + Box::new(field_a_builder), + Box::new(field_primitive_list_builder), + ], + ); + let mut list_builder = ListBuilder::new(element_builder); + + list_builder + .values() + .field_builder::(0) + .unwrap() + .append_value("First"); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + 
.values() + .append_value(1); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + .values() + .append_value(2); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + .values() + .append_value(3); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + .append(true); + list_builder.values().append(true); + + list_builder + .values() + .field_builder::(0) + .unwrap() + .append_value("Second"); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + .values() + .append_value(4); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + .values() + .append_value(5); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + .append(true); + list_builder.values().append(true); + list_builder.append(true); + + list_builder + .values() + .field_builder::(0) + .unwrap() + .append_value("Third"); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + .values() + .append_value(6); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + .append(true); + list_builder.values().append(true); + list_builder.append(true); + + list_builder + .values() + .field_builder::(0) + .unwrap() + .append_value("Second"); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + .values() + .append_value(4); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + .values() + .append_value(5); + list_builder + .values() + .field_builder::>>(1) + .unwrap() + .append(true); + list_builder.values().append(true); + list_builder.append(true); + + let expected = list_builder.finish(); + + assert_eq!(array, &expected); + } + + #[test] + fn test_nested_lists() { + // Define inner list scalars + let l1 = ScalarValue::new_list( + Some(vec![ + ScalarValue::new_list( + Some(vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ]), + DataType::Int32, + ), + ScalarValue::new_list( + Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), + DataType::Int32, + ), + ]), + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + ); + + let l2 = ScalarValue::new_list( + Some(vec![ + ScalarValue::new_list(Some(vec![ScalarValue::from(6i32)]), DataType::Int32), + ScalarValue::new_list( + Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), + DataType::Int32, + ), + ]), + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + ); + + let l3 = ScalarValue::new_list( + Some(vec![ScalarValue::new_list( + Some(vec![ScalarValue::from(9i32)]), + DataType::Int32, + )]), + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + ); + + let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); + let array = as_list_array(&array).unwrap(); + + // Construct expected array with array builders + let inner_builder = Int32Array::builder(8); + let middle_builder = ListBuilder::new(inner_builder); + let mut outer_builder = ListBuilder::new(middle_builder); + + outer_builder.values().values().append_value(1); + outer_builder.values().values().append_value(2); + outer_builder.values().values().append_value(3); + outer_builder.values().append(true); + + outer_builder.values().values().append_value(4); + outer_builder.values().values().append_value(5); + outer_builder.values().append(true); + outer_builder.append(true); + + outer_builder.values().values().append_value(6); + outer_builder.values().append(true); + + outer_builder.values().values().append_value(7); + outer_builder.values().values().append_value(8); + outer_builder.values().append(true); + 
outer_builder.append(true); + + outer_builder.values().values().append_value(9); + outer_builder.values().append(true); + outer_builder.append(true); + + let expected = outer_builder.finish(); + + assert_eq!(array, &expected); + } + + #[test] + fn scalar_timestamp_ns_utc_timezone() { + let scalar = + ScalarValue::TimestampNanosecond(Some(1599566400000000000), Some("UTC".into())); + + assert_eq!( + scalar.data_type(), + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) + ); + + let array = scalar.to_array(); + assert_eq!(array.len(), 1); + assert_eq!( + array.data_type(), + &DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) + ); + + let newscalar = ScalarValue::try_from_array(&array, 0).unwrap(); + assert_eq!( + newscalar.data_type(), + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) + ); + } + + #[test] + fn cast_round_trip() { + check_scalar_cast(ScalarValue::Int8(Some(5)), DataType::Int16); + check_scalar_cast(ScalarValue::Int8(None), DataType::Int16); + + check_scalar_cast(ScalarValue::Float64(Some(5.5)), DataType::Int16); + + check_scalar_cast(ScalarValue::Float64(None), DataType::Int16); + + check_scalar_cast( + ScalarValue::Utf8(Some("foo".to_string())), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + ); + + check_scalar_cast( + ScalarValue::Utf8(None), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + ); + } + + // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` + fn check_scalar_cast(scalar: ScalarValue, desired_type: DataType) { + // convert from scalar --> Array to call cast + let scalar_array = scalar.to_array(); + // cast the actual value + let cast_array = kernels::cast::cast(&scalar_array, &desired_type).unwrap(); + + // turn it back to a scalar + let cast_scalar = ScalarValue::try_from_array(&cast_array, 0).unwrap(); + assert_eq!(cast_scalar.data_type(), desired_type); + + // Some time later the "cast" scalar is turned back into an array: + let array = cast_scalar.to_array_of_size(10); + + // The datatype should be "Dictionary" but is actually Utf8!!! + assert_eq!(array.data_type(), &desired_type) + } + + #[test] + fn test_scalar_negative() -> Result<()> { + // positive test + let value = ScalarValue::Int32(Some(12)); + assert_eq!(ScalarValue::Int32(Some(-12)), value.arithmetic_negate()?); + let value = ScalarValue::Int32(None); + assert_eq!(ScalarValue::Int32(None), value.arithmetic_negate()?); + + // negative test + let value = ScalarValue::UInt8(Some(12)); + assert!(value.arithmetic_negate().is_err()); + let value = ScalarValue::Boolean(None); + assert!(value.arithmetic_negate().is_err()); + Ok(()) + } + + macro_rules! expect_operation_error { + ($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => { + #[test] + fn $TEST_NAME() { + let lhs = ScalarValue::UInt64(Some(12)); + let rhs = ScalarValue::Int32(Some(-3)); + match lhs.$FUNCTION(&rhs) { + Ok(_result) => { + panic!( + "Expected binary operation error between lhs: '{:?}', rhs: {:?}", + lhs, rhs + ); + } + Err(e) => { + let error_message = e.to_string(); + assert!( + error_message.contains($EXPECTED_ERROR), + "Expected error '{}' not found in actual error '{}'", + $EXPECTED_ERROR, + error_message + ); + } + } + } + }; + } + + expect_operation_error!( + expect_add_error, + add, + "Invalid arithmetic operation: UInt64 + Int32" + ); + expect_operation_error!( + expect_sub_error, + sub, + "Invalid arithmetic operation: UInt64 - Int32" + ); + + macro_rules! 
decimal_op_test_cases { + ($OPERATION:ident, [$([$L_VALUE:expr, $L_PRECISION:expr, $L_SCALE:expr, $R_VALUE:expr, $R_PRECISION:expr, $R_SCALE:expr, $O_VALUE:expr, $O_PRECISION:expr, $O_SCALE:expr]),+]) => { + $( + + let left = ScalarValue::Decimal128($L_VALUE, $L_PRECISION, $L_SCALE); + let right = ScalarValue::Decimal128($R_VALUE, $R_PRECISION, $R_SCALE); + let result = left.$OPERATION(&right).unwrap(); + assert_eq!(ScalarValue::Decimal128($O_VALUE, $O_PRECISION, $O_SCALE), result); + + )+ + }; + } + + #[test] + fn decimal_operations() { + decimal_op_test_cases!( + add, + [ + [Some(123), 10, 2, Some(124), 10, 2, Some(123 + 124), 11, 2], + // test sum decimal with diff scale + [ + Some(123), + 10, + 3, + Some(124), + 10, + 2, + Some(123 + 124 * 10_i128.pow(1)), + 12, + 3 + ], + // diff precision and scale for decimal data type + [ + Some(123), + 10, + 2, + Some(124), + 11, + 3, + Some(123 * 10_i128.pow(3 - 2) + 124), + 12, + 3 + ] + ] + ); + } + + #[test] + fn decimal_operations_with_nulls() { + decimal_op_test_cases!( + add, + [ + // Case: (None, Some, 0) + [None, 10, 2, Some(123), 10, 2, None, 11, 2], + // Case: (Some, None, 0) + [Some(123), 10, 2, None, 10, 2, None, 11, 2], + // Case: (Some, None, _) + Side=False + [Some(123), 8, 2, None, 10, 3, None, 11, 3], + // Case: (None, Some, _) + Side=False + [None, 8, 2, Some(123), 10, 3, None, 11, 3], + // Case: (Some, None, _) + Side=True + [Some(123), 8, 4, None, 10, 3, None, 12, 4], + // Case: (None, Some, _) + Side=True + [None, 10, 3, Some(123), 8, 4, None, 12, 4] + ] + ); + } + + #[test] + fn test_scalar_distance() { + let cases = [ + // scalar (lhs), scalar (rhs), expected distance + // --------------------------------------------- + (ScalarValue::Int8(Some(1)), ScalarValue::Int8(Some(2)), 1), + (ScalarValue::Int8(Some(2)), ScalarValue::Int8(Some(1)), 1), + ( + ScalarValue::Int16(Some(-5)), + ScalarValue::Int16(Some(5)), + 10, + ), + ( + ScalarValue::Int16(Some(5)), + ScalarValue::Int16(Some(-5)), + 10, + ), + (ScalarValue::Int32(Some(0)), ScalarValue::Int32(Some(0)), 0), + ( + ScalarValue::Int32(Some(-5)), + ScalarValue::Int32(Some(-10)), + 5, + ), + ( + ScalarValue::Int64(Some(-10)), + ScalarValue::Int64(Some(-5)), + 5, + ), + (ScalarValue::UInt8(Some(1)), ScalarValue::UInt8(Some(2)), 1), + (ScalarValue::UInt8(Some(0)), ScalarValue::UInt8(Some(0)), 0), + ( + ScalarValue::UInt16(Some(5)), + ScalarValue::UInt16(Some(10)), + 5, + ), + ( + ScalarValue::UInt32(Some(10)), + ScalarValue::UInt32(Some(5)), + 5, + ), + ( + ScalarValue::UInt64(Some(5)), + ScalarValue::UInt64(Some(10)), + 5, + ), + ( + ScalarValue::Float32(Some(1.0)), + ScalarValue::Float32(Some(2.0)), + 1, + ), + ( + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + 1, + ), + ( + ScalarValue::Float64(Some(0.0)), + ScalarValue::Float64(Some(0.0)), + 0, + ), + ( + ScalarValue::Float64(Some(-5.0)), + ScalarValue::Float64(Some(-10.0)), + 5, + ), + ( + ScalarValue::Float64(Some(-10.0)), + ScalarValue::Float64(Some(-5.0)), + 5, + ), + // Floats are currently special cased to f64/f32 and the result is rounded + // rather than ceiled/floored. In the future we might want to take a mode + // which specified the rounding behavior. 
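+ // For example, distance(1.2, 1.3) = 0.1 rounds to 0, while distance(1.1, 1.9) = 0.8 rounds to 1 (see the cases below).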
+ ( + ScalarValue::Float32(Some(1.2)), + ScalarValue::Float32(Some(1.3)), + 0, + ), + ( + ScalarValue::Float32(Some(1.1)), + ScalarValue::Float32(Some(1.9)), + 1, + ), + ( + ScalarValue::Float64(Some(-5.3)), + ScalarValue::Float64(Some(-9.2)), + 4, + ), + ( + ScalarValue::Float64(Some(-5.3)), + ScalarValue::Float64(Some(-9.7)), + 4, + ), + ( + ScalarValue::Float64(Some(-5.3)), + ScalarValue::Float64(Some(-9.9)), + 5, + ), + ]; + for (lhs, rhs, expected) in cases.iter() { + let distance = lhs.distance(rhs).unwrap(); + assert_eq!(distance, *expected); + } + } + + #[test] + fn test_scalar_distance_invalid() { + let cases = [ + // scalar (lhs), scalar (rhs) + // -------------------------- + // Same type but with nulls + (ScalarValue::Int8(None), ScalarValue::Int8(None)), + (ScalarValue::Int8(None), ScalarValue::Int8(Some(1))), + (ScalarValue::Int8(Some(1)), ScalarValue::Int8(None)), + // Different type + (ScalarValue::Int8(Some(1)), ScalarValue::Int16(Some(1))), + (ScalarValue::Int8(Some(1)), ScalarValue::Float32(Some(1.0))), + ( + ScalarValue::Float64(Some(1.1)), + ScalarValue::Float32(Some(2.2)), + ), + ( + ScalarValue::UInt64(Some(777)), + ScalarValue::Int32(Some(111)), + ), + // Different types with nulls + (ScalarValue::Int8(None), ScalarValue::Int16(Some(1))), + (ScalarValue::Int8(Some(1)), ScalarValue::Int16(None)), + // Unsupported types + ( + ScalarValue::Utf8(Some("foo".to_string())), + ScalarValue::Utf8(Some("bar".to_string())), + ), + ( + ScalarValue::Boolean(Some(true)), + ScalarValue::Boolean(Some(false)), + ), + (ScalarValue::Date32(Some(0)), ScalarValue::Date32(Some(1))), + (ScalarValue::Date64(Some(0)), ScalarValue::Date64(Some(1))), + ( + ScalarValue::Decimal128(Some(123), 5, 5), + ScalarValue::Decimal128(Some(120), 5, 5), + ), + ]; + for (lhs, rhs) in cases { + let distance = lhs.distance(&rhs); + assert!(distance.is_none()); + } + } + + #[test] + fn test_scalar_interval_negate() { + let cases = [ + ( + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(-1, -12), + ), + ( + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(-1, -999), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(-12, -15, -123_456), + ), + ]; + for (expr, expected) in cases.iter() { + let result = expr.arithmetic_negate().unwrap(); + assert_eq!(*expected, result, "-expr:{expr:?}"); + } + } + + #[test] + fn test_scalar_interval_add() { + let cases = [ + ( + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(2, 24), + ), + ( + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(2, 1998), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(24, 30, 246_912), + ), + ]; + for (lhs, rhs, expected) in cases.iter() { + let result = lhs.add(rhs).unwrap(); + let result_commute = rhs.add(lhs).unwrap(); + assert_eq!(*expected, result, "lhs:{lhs:?} + rhs:{rhs:?}"); + assert_eq!(*expected, result_commute, "lhs:{rhs:?} + rhs:{lhs:?}"); + } + } + + #[test] + fn test_scalar_interval_sub() { + let cases = [ + ( + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(0, 0), + ), + ( + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(0, 0), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(12, 15, 
123_456), + ScalarValue::new_interval_mdn(0, 0, 0), + ), + ]; + for (lhs, rhs, expected) in cases.iter() { + let result = lhs.sub(rhs).unwrap(); + assert_eq!(*expected, result, "lhs:{lhs:?} - rhs:{rhs:?}"); + } + } + + #[test] + fn timestamp_op_random_tests() { + // timestamp1 + (or -) interval = timestamp2 + // timestamp2 - timestamp1 (or timestamp1 - timestamp2) = interval ? + let sample_size = 1000; + let timestamps1 = get_random_timestamps(sample_size); + let intervals = get_random_intervals(sample_size); + // ts(sec) + interval(ns) = ts(sec); however, + // ts(sec) - ts(sec) cannot be = interval(ns). Therefore, + // timestamps are more precise than intervals in tests. + for (idx, ts1) in timestamps1.iter().enumerate() { + if idx % 2 == 0 { + let timestamp2 = ts1.add(intervals[idx].clone()).unwrap(); + let back = timestamp2.sub(intervals[idx].clone()).unwrap(); + assert_eq!(ts1, &back); + } else { + let timestamp2 = ts1.sub(intervals[idx].clone()).unwrap(); + let back = timestamp2.add(intervals[idx].clone()).unwrap(); + assert_eq!(ts1, &back); + }; + } + } + + #[test] + fn test_struct_nulls() { + let fields_b = Fields::from(vec![ + Field::new("ba", DataType::UInt64, true), + Field::new("bb", DataType::UInt64, true), + ]); + let fields = Fields::from(vec![ + Field::new("a", DataType::UInt64, true), + Field::new("b", DataType::Struct(fields_b.clone()), true), + ]); + let scalars = vec![ + ScalarValue::Struct(None, fields.clone()), + ScalarValue::Struct( + Some(vec![ + ScalarValue::UInt64(None), + ScalarValue::Struct(None, fields_b.clone()), + ]), + fields.clone(), + ), + ScalarValue::Struct( + Some(vec![ + ScalarValue::UInt64(None), + ScalarValue::Struct( + Some(vec![ScalarValue::UInt64(None), ScalarValue::UInt64(None)]), + fields_b.clone(), + ), + ]), + fields.clone(), + ), + ScalarValue::Struct( + Some(vec![ + ScalarValue::UInt64(Some(1)), + ScalarValue::Struct( + Some(vec![ + ScalarValue::UInt64(Some(2)), + ScalarValue::UInt64(Some(3)), + ]), + fields_b, + ), + ]), + fields, + ), + ]; + + let check_array = |array| { + let is_null = is_null(&array).unwrap(); + assert_eq!(is_null, BooleanArray::from(vec![true, false, false, false])); + + let formatted = pretty_format_columns("col", &[array]).unwrap().to_string(); + let formatted = formatted.split('\n').collect::>(); + let expected = vec![ + "+---------------------------+", + "| col |", + "+---------------------------+", + "| |", + "| {a: , b: } |", + "| {a: , b: {ba: , bb: }} |", + "| {a: 1, b: {ba: 2, bb: 3}} |", + "+---------------------------+", + ]; + assert_eq!( + formatted, expected, + "Actual:\n{formatted:#?}\n\nExpected:\n{expected:#?}" + ); + }; + + // test `ScalarValue::iter_to_array` + let array = ScalarValue::iter_to_array(scalars.clone()).unwrap(); + check_array(array); + + // test `ScalarValue::to_array` / `ScalarValue::to_array_of_size` + let arrays = scalars + .iter() + .map(ScalarValue::to_array) + .collect::>(); + let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); + let array = concat(&arrays).unwrap(); + check_array(array); + } + + #[test] + fn test_build_timestamp_millisecond_list() { + let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)]; + let ts_list = ScalarValue::new_list( + Some(values), + DataType::Timestamp(TimeUnit::Millisecond, None), + ); + let list = ts_list.to_array_of_size(1); + assert_eq!(1, list.len()); + } + + fn get_random_timestamps(sample_size: u64) -> Vec { + let vector_size = sample_size; + let mut timestamp = vec![]; + let mut rng = rand::thread_rng(); + for i in 
0..vector_size { + let year = rng.gen_range(1995..=2050); + let month = rng.gen_range(1..=12); + let day = rng.gen_range(1..=28); // to exclude invalid dates + let hour = rng.gen_range(0..=23); + let minute = rng.gen_range(0..=59); + let second = rng.gen_range(0..=59); + if i % 4 == 0 { + timestamp.push(ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_opt(hour, minute, second) + .unwrap() + .timestamp(), + ), + None, + )) + } else if i % 4 == 1 { + let millisec = rng.gen_range(0..=999); + timestamp.push(ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_milli_opt(hour, minute, second, millisec) + .unwrap() + .timestamp_millis(), + ), + None, + )) + } else if i % 4 == 2 { + let microsec = rng.gen_range(0..=999_999); + timestamp.push(ScalarValue::TimestampMicrosecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_micro_opt(hour, minute, second, microsec) + .unwrap() + .timestamp_micros(), + ), + None, + )) + } else if i % 4 == 3 { + let nanosec = rng.gen_range(0..=999_999_999); + timestamp.push(ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_nano_opt(hour, minute, second, nanosec) + .unwrap() + .timestamp_nanos_opt() + .unwrap(), + ), + None, + )) + } + } + timestamp + } + + fn get_random_intervals(sample_size: u64) -> Vec { + const MILLISECS_IN_ONE_DAY: i64 = 86_400_000; + const NANOSECS_IN_ONE_DAY: i64 = 86_400_000_000_000; + + let vector_size = sample_size; + let mut intervals = vec![]; + let mut rng = rand::thread_rng(); + const SECS_IN_ONE_DAY: i32 = 86_400; + const MICROSECS_IN_ONE_DAY: i64 = 86_400_000_000; + for i in 0..vector_size { + if i % 4 == 0 { + let days = rng.gen_range(0..5000); + // to not break second precision + let millis = rng.gen_range(0..SECS_IN_ONE_DAY) * 1000; + intervals.push(ScalarValue::new_interval_dt(days, millis)); + } else if i % 4 == 1 { + let days = rng.gen_range(0..5000); + let millisec = rng.gen_range(0..(MILLISECS_IN_ONE_DAY as i32)); + intervals.push(ScalarValue::new_interval_dt(days, millisec)); + } else if i % 4 == 2 { + let days = rng.gen_range(0..5000); + // to not break microsec precision + let nanosec = rng.gen_range(0..MICROSECS_IN_ONE_DAY) * 1000; + intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); + } else { + let days = rng.gen_range(0..5000); + let nanosec = rng.gen_range(0..NANOSECS_IN_ONE_DAY); + intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); + } + } + intervals + } +} diff --git a/crates/datafusion_ext/src/lib.rs b/crates/datafusion_ext/src/lib.rs index 0aa516256..c74405b02 100644 --- a/crates/datafusion_ext/src/lib.rs +++ b/crates/datafusion_ext/src/lib.rs @@ -1,3 +1,4 @@ +pub mod cast; pub mod errors; pub mod metrics; pub mod planner; diff --git a/crates/sqlbuiltins/src/errors.rs b/crates/sqlbuiltins/src/errors.rs index 6f8b7a02e..3293e1b49 100644 --- a/crates/sqlbuiltins/src/errors.rs +++ b/crates/sqlbuiltins/src/errors.rs @@ -1,6 +1,7 @@ use datafusion::arrow::datatypes::DataType; use datafusion::arrow::error::ArrowError; use datafusion::error::DataFusionError; +use datafusion_ext::errors::ExtensionError; #[derive(Clone, Debug, thiserror::Error)] pub enum BuiltinError { @@ -33,6 +34,9 @@ pub enum BuiltinError { #[error("ArrowError: {0}")] ArrowError(String), + + #[error("DataFusionExtension: {0}")] + DataFusionExtension(String), } pub type Result = std::result::Result; @@ -43,12 +47,24 @@ impl From for 
DataFusionError { } } +impl From for ExtensionError { + fn from(e: BuiltinError) -> Self { + ExtensionError::String(e.to_string()) + } +} + impl From for BuiltinError { fn from(e: DataFusionError) -> Self { BuiltinError::DataFusionError(e.to_string()) } } +impl From for BuiltinError { + fn from(e: ExtensionError) -> Self { + BuiltinError::DataFusionExtension(e.to_string()) + } +} + impl From for BuiltinError { fn from(e: ArrowError) -> Self { BuiltinError::ArrowError(e.to_string()) diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index fef3df516..a4f8a8f48 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -11,6 +11,8 @@ use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::logical_expr::{Expr, ScalarUDF, Signature, TypeSignature, Volatility}; use datafusion::physical_plan::ColumnarValue; use datafusion::scalar::ScalarValue; +use datafusion_ext::cast::scalar_iter_to_array; +use datafusion_ext::errors::ExtensionError; use num_traits::ToPrimitive; use crate::document; @@ -62,17 +64,12 @@ fn get_nth_scalar_value( match input.get(n) { Some(input) => match input { ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(op(scalar.clone())?)), - ColumnarValue::Array(arr) => { - let mut values = Vec::with_capacity(arr.len()); - - for idx in 0..arr.len() { - values.push(op(ScalarValue::try_from_array(arr, idx)?)?); - } - - Ok(ColumnarValue::Array(ScalarValue::iter_to_array( - values.into_iter(), - )?)) - } + ColumnarValue::Array(arr) => Ok(ColumnarValue::Array(scalar_iter_to_array( + arr.data_type(), + (0..arr.len()).map(|idx| -> Result { + Ok(op(ScalarValue::try_from_array(arr, idx)?)?) + }), + )?)), }, None => Err(BuiltinError::MissingValueAtIndex(n)), } From 2aa2277a09d7b0027d8c45ce16253f821de69004 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Wed, 27 Dec 2023 17:20:21 -0500 Subject: [PATCH 2/8] fix formatting --- crates/datafusion_ext/src/cast.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/datafusion_ext/src/cast.rs b/crates/datafusion_ext/src/cast.rs index 710c6a402..2e5602f2b 100644 --- a/crates/datafusion_ext/src/cast.rs +++ b/crates/datafusion_ext/src/cast.rs @@ -27,7 +27,7 @@ pub fn scalar_iter_to_array( data_type: &DataType, scalars: impl IntoIterator>, ) -> Result { - let scalars = scalars.into_iter(); + let scalars = scalars.into_iter(); /// Creates an array of $ARRAY_TY by unpacking values of /// SCALAR_TY for primitive types @@ -317,9 +317,9 @@ pub fn scalar_iter_to_array( DataType::Dictionary(key_type, value_type) => { // create the values array let value_scalars = scalars - .map(|scalar| -{let scalar = scalar?; -match scalar { + .map(|scalar| { + let scalar = scalar?; + match scalar { ScalarValue::Dictionary(inner_key_type, scalar) => { if &inner_key_type == key_type { Ok(*scalar) From 5594034c1f99f74f343e6c9d3b6a147f86be6ad5 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Wed, 27 Dec 2023 17:21:58 -0500 Subject: [PATCH 3/8] fixup --- crates/arrow_util/src/cast.rs | 201 ---------------------------------- 1 file changed, 201 deletions(-) delete mode 100644 crates/arrow_util/src/cast.rs diff --git a/crates/arrow_util/src/cast.rs b/crates/arrow_util/src/cast.rs deleted file mode 100644 index 96e9002e2..000000000 --- a/crates/arrow_util/src/cast.rs +++ /dev/null @@ -1,201 +0,0 @@ -use datafusion::arrow::array::{Array, ArrayRef, Decimal128Array}; -use datafusion::arrow::datatypes::DataType; -use 
datafusion::arrow::error::ArrowError; -use datafusion::scalar::ScalarValue; - -pub fn try_cast( - array: &dyn Array, - op: &dyn Fn(ScalarValue) -> Result, -) -> Result { - Ok(match array.data_type() { - // DataType::Null => ScalarValue::iter_to_array( - // (0..array.len()) - // .map(|_| ScalarValue::Null) - // .map(op) - // .collect()?, - // )?, - DataType::Decimal128(precision, scale) => ScalarValue::iter_to_array( - array - .as_any() - .downcast_ref::() - .ok_or(ArrowError::CastError("decimal128".to_string()))? - .iter() - .map(|v| ScalarValue::Decimal128(v, *precision, *scale)) - .map(op) - .)()?, - )?, - // DataType::Decimal256(precision, scale) => { - // ScalarValue::get_decimal_value_from_array(array, index, *precision, *scale)? - // } - // DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), - // DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), - // DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), - // DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64), - // DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32), - // DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16), - // DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8), - // DataType::Int64 => typed_cast!(array, index, Int64Array, Int64), - // DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), - // DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), - // DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), - // DataType::Binary => typed_cast!(array, index, BinaryArray, Binary), - // DataType::LargeBinary => { - // typed_cast!(array, index, LargeBinaryArray, LargeBinary) - // } - // DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), - // DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), - // DataType::List(nested_type) => { - // let list_array = as_list_array(array)?; - // let value = match list_array.is_null(index) { - // true => None, - // false => { - // let nested_array = list_array.value(index); - // let scalar_vec = (0..nested_array.len()) - // .map(|i| ScalarValue::try_from_array(&nested_array, i)) - // .collect::>>()?; - // Some(scalar_vec) - // } - // }; - // ScalarValue::new_list(value, nested_type.data_type().clone()) - // } - // DataType::Date32 => { - // typed_cast!(array, index, Date32Array, Date32) - // } - // DataType::Date64 => { - // typed_cast!(array, index, Date64Array, Date64) - // } - // DataType::Time32(TimeUnit::Second) => { - // typed_cast!(array, index, Time32SecondArray, Time32Second) - // } - // DataType::Time32(TimeUnit::Millisecond) => { - // typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond) - // } - // DataType::Time64(TimeUnit::Microsecond) => { - // typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond) - // } - // DataType::Time64(TimeUnit::Nanosecond) => { - // typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond) - // } - // DataType::Timestamp(TimeUnit::Second, tz_opt) => { - // typed_cast_tz!(array, index, TimestampSecondArray, TimestampSecond, tz_opt) - // } - // DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - // typed_cast_tz!( - // array, - // index, - // TimestampMillisecondArray, - // TimestampMillisecond, - // tz_opt - // ) - // } - // DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - // typed_cast_tz!( - // array, - // index, - // TimestampMicrosecondArray, - // TimestampMicrosecond, - // tz_opt - // ) - // } - // 
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - // typed_cast_tz!( - // array, - // index, - // TimestampNanosecondArray, - // TimestampNanosecond, - // tz_opt - // ) - // } - // DataType::Dictionary(key_type, _) => { - // let (values_array, values_index) = match key_type.as_ref() { - // DataType::Int8 => get_dict_value::(array, index), - // DataType::Int16 => get_dict_value::(array, index), - // DataType::Int32 => get_dict_value::(array, index), - // DataType::Int64 => get_dict_value::(array, index), - // DataType::UInt8 => get_dict_value::(array, index), - // DataType::UInt16 => get_dict_value::(array, index), - // DataType::UInt32 => get_dict_value::(array, index), - // DataType::UInt64 => get_dict_value::(array, index), - // _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), - // }; - // // look up the index in the values dictionary - // let value = match values_index { - // Some(values_index) => ScalarValue::try_from_array(values_array, values_index), - // // else entry was null, so return null - // None => values_array.data_type().try_into(), - // }?; - - // Self::Dictionary(key_type.clone(), Box::new(value)) - // } - // DataType::Struct(fields) => { - // let array = as_struct_array(array)?; - // let mut field_values: Vec = Vec::new(); - // for col_index in 0..array.num_columns() { - // let col_array = array.column(col_index); - // let col_scalar = ScalarValue::try_from_array(col_array, index)?; - // field_values.push(col_scalar); - // } - // Self::Struct(Some(field_values), fields.clone()) - // } - // DataType::FixedSizeList(nested_type, _len) => { - // let list_array = as_fixed_size_list_array(array)?; - // let value = match list_array.is_null(index) { - // true => None, - // false => { - // let nested_array = list_array.value(index); - // let scalar_vec = (0..nested_array.len()) - // .map(|i| ScalarValue::try_from_array(&nested_array, i)) - // .collect::>>()?; - // Some(scalar_vec) - // } - // }; - // ScalarValue::new_list(value, nested_type.data_type().clone()) - // } - // DataType::FixedSizeBinary(_) => { - // let array = as_fixed_size_binary_array(array)?; - // let size = match array.data_type() { - // DataType::FixedSizeBinary(size) => *size, - // _ => unreachable!(), - // }; - // ScalarValue::FixedSizeBinary( - // size, - // match array.is_null(index) { - // true => None, - // false => Some(array.value(index).into()), - // }, - // ) - // } - // DataType::Interval(IntervalUnit::DayTime) => { - // typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime) - // } - // DataType::Interval(IntervalUnit::YearMonth) => { - // typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth) - // } - // DataType::Interval(IntervalUnit::MonthDayNano) => { - // typed_cast!( - // array, - // index, - // IntervalMonthDayNanoArray, - // IntervalMonthDayNano - // ) - // } - - // DataType::Duration(TimeUnit::Second) => { - // typed_cast!(array, index, DurationSecondArray, DurationSecond) - // } - // DataType::Duration(TimeUnit::Millisecond) => { - // typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond) - // } - // DataType::Duration(TimeUnit::Microsecond) => { - // typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond) - // } - // DataType::Duration(TimeUnit::Nanosecond) => { - // typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond) - // } - other => { - return Err(ArrowError::CastError( - format!("Can't create a scalar from array of type \"{other:?}\"",).to_string(), - )); - } - }) -} From 
78af9ced4a23919b571cdf88a7520a20f485e23f Mon Sep 17 00:00:00 2001 From: tycho garen Date: Thu, 28 Dec 2023 00:55:27 -0500 Subject: [PATCH 4/8] test fixes --- Cargo.lock | 2 + crates/datafusion_ext/Cargo.toml | 2 + crates/datafusion_ext/src/cast.rs | 220 +++++++++++++++++------------- 3 files changed, 128 insertions(+), 96 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e7274f5e4..ccaed9d28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2291,6 +2291,7 @@ dependencies = [ "async-trait", "bson", "catalog", + "chrono", "ctor", "datafusion", "decimal", @@ -2301,6 +2302,7 @@ dependencies = [ "parking_lot", "paste", "protogen", + "rand", "regex", "rstest", "serde_json", diff --git a/crates/datafusion_ext/Cargo.toml b/crates/datafusion_ext/Cargo.toml index c7a10e8cf..77852e28f 100644 --- a/crates/datafusion_ext/Cargo.toml +++ b/crates/datafusion_ext/Cargo.toml @@ -30,7 +30,9 @@ parking_lot = "0.12.1" bson = "2.7.0" [dev-dependencies] +chrono.workspace = true ctor = "0.2.6" env_logger = "0.10" paste = "^1.0" +rand = "0.8.5" rstest = "0.18" diff --git a/crates/datafusion_ext/src/cast.rs b/crates/datafusion_ext/src/cast.rs index 2e5602f2b..81fbd6f32 100644 --- a/crates/datafusion_ext/src/cast.rs +++ b/crates/datafusion_ext/src/cast.rs @@ -40,12 +40,11 @@ pub fn scalar_iter_to_array( if let ScalarValue::$SCALAR_TY(v) = sv { Ok(v) } else { - todo!( - "Inconsistent types in ScalarValue::iter_to_array. \ + Err(ExtensionError::String(format!( + "Inconsistent types in scalar_iter_to_array. \ Expected {:?}, got {:?}", - data_type, - sv - ) + data_type, sv + ))) } }) .collect::>()?; @@ -63,12 +62,11 @@ pub fn scalar_iter_to_array( if let ScalarValue::$SCALAR_TY(v, _) = sv { Ok(v) } else { - todo!( - "Inconsistent types in ScalarValue::iter_to_array. \ + Err(ExtensionError::String(format!( + "Inconsistent types in scalar_iter_to_array. \ Expected {:?}, got {:?}", - data_type, - sv - ) + data_type, sv + ))) } }) .collect::>()?; @@ -89,7 +87,7 @@ pub fn scalar_iter_to_array( Ok(v) } else { return Err(ExtensionError::String(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ + "Inconsistent types in scalar_iter_to_array. \ Expected {:?}, got {:?}", data_type, sv, ))); @@ -104,25 +102,29 @@ pub fn scalar_iter_to_array( macro_rules! build_array_list_primitive { ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( - scalars.into_iter().collect::,_>>()?.into_iter().map(|x| match x { - ScalarValue::List(xs, _) => xs.map(|x| { - x.iter() - .map(|x| match x { - ScalarValue::$SCALAR_TY(i) => *i, - sv => panic!( - "Inconsistent types in ScalarValue::iter_to_array. \ + scalars + .into_iter() + .collect::, _>>()? + .into_iter() + .map(|x| match x { + ScalarValue::List(xs, _) => xs.map(|x| { + x.iter() + .map(|x| match x { + ScalarValue::$SCALAR_TY(i) => *i, + sv => panic!( + "Inconsistent types in scalar_iter_to_array. \ Expected {:?}, got {:?}", - data_type, sv - ), - }) - .collect::>>() - }), - sv => panic!( - "Inconsistent types in ScalarValue::iter_to_array. \ + data_type, sv + ), + }) + .collect::>>() + }), + sv => panic!( + "Inconsistent types in scalar_iter_to_array. \ Expected {:?}, got {:?}", - data_type, sv - ), - }), + data_type, sv + ), + }), )) }}; } @@ -142,11 +144,11 @@ pub fn scalar_iter_to_array( builder.values().append_null(); } sv => { - todo!( - "Inconsistent types in ScalarValue::iter_to_array. \ + return Err(ExtensionError::String(format!( + "Inconsistent types in scalar_iter_to_array. 
\ Expected Utf8, got {:?}", sv - ) + ))) } } } @@ -157,7 +159,7 @@ pub fn scalar_iter_to_array( } sv => { return Err(ExtensionError::String(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ + "Inconsistent types in scalar_iter_to_array. \ Expected List, got {:?}", sv ))); @@ -308,7 +310,12 @@ pub fn scalar_iter_to_array( let field_values = fields .iter() .zip(columns) - .map(|(field, column)| Ok((field.clone(), ScalarValue::iter_to_array(column)?))) + .map(|(field, column)| { + Ok(( + field.clone(), + ScalarValue::iter_to_array(column.iter().map(|v| v.to_owned()))?, + )) + }) .collect::, ExtensionError>>()?; let array = StructArray::from(field_values); @@ -358,7 +365,7 @@ pub fn scalar_iter_to_array( Ok(v) } else { return Err(ExtensionError::String(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ + "Inconsistent types in scalar_iter_to_array. \ Expected {data_type:?}, got {sv:?}" ))); } @@ -384,7 +391,7 @@ pub fn scalar_iter_to_array( | DataType::Union(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => { - todo!("make a better error"); + return Err(ExtensionError::String("unsupported type".to_string())) } }; @@ -440,7 +447,7 @@ fn iter_to_array_list( match values { Some(values) => { let element_array = if !values.is_empty() { - ScalarValue::iter_to_array(values)? + ScalarValue::iter_to_array(values.iter().map(|v| v.to_owned()))? } else { new_empty_array(field.data_type()) }; @@ -545,17 +552,24 @@ mod tests { use std::cmp::Ordering; use std::sync::Arc; - use arrow::compute::kernels; - use arrow::compute::{concat, is_null}; - use arrow::datatypes::ArrowPrimitiveType; - use arrow::util::pretty::pretty_format_columns; - use arrow_array::ArrowNumericType; use chrono::NaiveDate; + use datafusion::arrow::array::{ + ArrowNumericType, AsArray, PrimitiveBuilder, StringBuilder, StructBuilder, + }; + use datafusion::arrow::compute::kernels; + use datafusion::arrow::compute::{concat, is_null}; + use datafusion::arrow::datatypes::{ArrowPrimitiveType, Field, Fields}; + use datafusion::arrow::util::pretty::pretty_format_columns; use rand::Rng; - use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; + use datafusion::common::cast::{ + as_decimal128_array, as_dictionary_array, as_list_array, as_string_array, as_struct_array, + as_uint32_array, as_uint64_array, + }; + use std::collections::HashSet; use super::*; + use datafusion::common::*; #[test] fn scalar_add_trait_test() -> Result<()> { @@ -770,7 +784,11 @@ mod tests { ScalarValue::Decimal128(Some(3), 10, 2), ]; // convert the vec to decimal array and check the result - let array = ScalarValue::iter_to_array(decimal_vec).unwrap(); + let array = scalar_iter_to_array( + &DataType::Decimal128(10, 2), + decimal_vec.iter().map(|v| v.to_owned()).map(Ok), + ) + .unwrap(); assert_eq!(3, array.len()); assert_eq!(DataType::Decimal128(10, 2), array.data_type().clone()); @@ -780,7 +798,11 @@ mod tests { ScalarValue::Decimal128(Some(3), 10, 2), ScalarValue::Decimal128(None, 10, 2), ]; - let array = ScalarValue::iter_to_array(decimal_vec).unwrap(); + let array = scalar_iter_to_array( + &DataType::Decimal128(10, 2), + decimal_vec.iter().map(|v| v.to_owned()).map(Ok), + ) + .unwrap(); assert_eq!(4, array.len()); assert_eq!(DataType::Decimal128(10, 2), array.data_type().clone()); @@ -877,24 +899,11 @@ mod tests { ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ let scalars: Vec<_> = $INPUT.iter().map(|v| ScalarValue::$SCALAR_T(*v)).collect(); - let array = 
ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); - - let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); - - assert_eq!(&array, &expected); - }}; - } - - /// Creates array directly and via ScalarValue and ensures they are the same - /// but for variants that carry a timezone field. - macro_rules! check_scalar_iter_tz { - ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ - let scalars: Vec<_> = $INPUT - .iter() - .map(|v| ScalarValue::$SCALAR_T(*v, None)) - .collect(); - - let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + let array = scalar_iter_to_array( + &DataType::$SCALAR_T, + scalars.into_iter().map(|v| Ok(v.to_owned())), + ) + .unwrap(); let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); @@ -911,7 +920,11 @@ mod tests { .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_string()))) .collect(); - let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + let array = scalar_iter_to_array( + &DataType::$SCALAR_T, + scalars.into_iter().map(|v| Ok(v.to_owned())), + ) + .unwrap(); let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); @@ -928,7 +941,11 @@ mod tests { .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_vec()))) .collect(); - let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + let array = scalar_iter_to_array( + &DataType::$SCALAR_T, + scalars.into_iter().map(|v| Ok(v.to_owned())), + ) + .unwrap(); let expected: $ARRAYTYPE = $INPUT.iter().map(|v| v.map(|v| v.to_vec())).collect(); @@ -956,27 +973,6 @@ mod tests { check_scalar_iter!(UInt32, UInt32Array, vec![Some(1), None, Some(3)]); check_scalar_iter!(UInt64, UInt64Array, vec![Some(1), None, Some(3)]); - check_scalar_iter_tz!( - TimestampSecond, - TimestampSecondArray, - vec![Some(1), None, Some(3)] - ); - check_scalar_iter_tz!( - TimestampMillisecond, - TimestampMillisecondArray, - vec![Some(1), None, Some(3)] - ); - check_scalar_iter_tz!( - TimestampMicrosecond, - TimestampMicrosecondArray, - vec![Some(1), None, Some(3)] - ); - check_scalar_iter_tz!( - TimestampNanosecond, - TimestampNanosecondArray, - vec![Some(1), None, Some(3)] - ); - check_scalar_iter_string!(Utf8, StringArray, vec![Some("foo"), None, Some("bar")]); check_scalar_iter_string!( LargeUtf8, @@ -995,11 +991,13 @@ mod tests { fn scalar_iter_to_array_empty() { let scalars = vec![] as Vec; - let result = ScalarValue::iter_to_array(scalars).unwrap_err(); + let result = + scalar_iter_to_array(&DataType::Null, scalars.iter().map(|v| Ok(v.to_owned()))) + .unwrap_err(); assert!( result .to_string() - .contains("Empty iterator passed to ScalarValue::iter_to_array"), + .contains("Empty iterator passed to scalar_iter_to_array"), "{}", result ); @@ -1019,7 +1017,11 @@ mod tests { make_val(Some("Bar".into())), ]; - let array = ScalarValue::iter_to_array(scalars).unwrap(); + let array = scalar_iter_to_array( + &scalars.get(0).unwrap().data_type(), + scalars.into_iter().map(|v| Ok(v.to_owned())), + ) + .unwrap(); let array = as_dictionary_array::(&array).unwrap(); let values_array = as_string_array(array.values()).unwrap(); @@ -1043,10 +1045,14 @@ mod tests { // If the scalar values are not all the correct type, error here let scalars = [Boolean(Some(true)), Int32(Some(5))]; - let result = ScalarValue::iter_to_array(scalars).unwrap_err(); + let result = scalar_iter_to_array( + &DataType::Boolean, + scalars.into_iter().map(|v| Ok(v.to_owned())), + ) + .unwrap_err(); assert!( result.to_string().contains( - "Inconsistent types in ScalarValue::iter_to_array. 
Expected Boolean, got Int32(5)" + "Inconsistent types in scalar_iter_to_array. Expected Boolean, got Int32(5)" ), "{}", result @@ -1567,7 +1573,11 @@ mod tests { ), ]), ]; - let array = ScalarValue::iter_to_array(scalars).unwrap(); + let array = scalar_iter_to_array( + &constructed.data_type(), + scalars.iter().map(|v| v.to_owned()).map(Ok), + ) + .unwrap(); let expected = Arc::new(StructArray::from(vec![ ( @@ -1646,7 +1656,13 @@ mod tests { ]); // iter_to_array for struct scalars - let array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap(); + let array = scalar_iter_to_array( + &field_a.data_type(), + vec![s0.clone(), s1.clone(), s2.clone()] + .into_iter() + .map(|v| Ok(v.to_owned())), + ) + .unwrap(); let array = as_struct_array(&array).unwrap(); let expected = StructArray::from(vec![ ( @@ -1672,7 +1688,11 @@ mod tests { let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.data_type()); // iter_to_array for list-of-struct - let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); + let array = scalar_iter_to_array( + &nl0.data_type(), + vec![nl0, nl1, nl2].into_iter().map(|v| Ok(v.to_owned())), + ) + .unwrap(); let array = as_list_array(&array).unwrap(); // Construct expected array with array builders @@ -1833,7 +1853,11 @@ mod tests { DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), ); - let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); + let array = scalar_iter_to_array( + &l3.data_type(), + vec![l1, l2, l3].iter().map(|v| v.to_owned()).map(Ok), + ) + .unwrap(); let array = as_list_array(&array).unwrap(); // Construct expected array with array builders @@ -2365,8 +2389,12 @@ mod tests { ); }; - // test `ScalarValue::iter_to_array` - let array = ScalarValue::iter_to_array(scalars.clone()).unwrap(); + // test `scalar_iter_to_array` + let array = scalar_iter_to_array( + &DataType::Boolean, + scalars.clone().into_iter().map(|v| Ok(v.to_owned())), + ) + .unwrap(); check_array(array); // test `ScalarValue::to_array` / `ScalarValue::to_array_of_size` From 9cf2058ff9c18a42042ea33b70226700f1442533 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Thu, 28 Dec 2023 01:13:24 -0500 Subject: [PATCH 5/8] beep --- crates/datafusion_ext/src/cast.rs | 106 +----------------------------- 1 file changed, 2 insertions(+), 104 deletions(-) diff --git a/crates/datafusion_ext/src/cast.rs b/crates/datafusion_ext/src/cast.rs index 81fbd6f32..3015c3c81 100644 --- a/crates/datafusion_ext/src/cast.rs +++ b/crates/datafusion_ext/src/cast.rs @@ -557,9 +557,7 @@ mod tests { ArrowNumericType, AsArray, PrimitiveBuilder, StringBuilder, StructBuilder, }; use datafusion::arrow::compute::kernels; - use datafusion::arrow::compute::{concat, is_null}; - use datafusion::arrow::datatypes::{ArrowPrimitiveType, Field, Fields}; - use datafusion::arrow::util::pretty::pretty_format_columns; + use datafusion::arrow::datatypes::{ArrowPrimitiveType, Field}; use rand::Rng; use datafusion::common::cast::{ @@ -987,22 +985,6 @@ mod tests { ); } - #[test] - fn scalar_iter_to_array_empty() { - let scalars = vec![] as Vec; - - let result = - scalar_iter_to_array(&DataType::Null, scalars.iter().map(|v| Ok(v.to_owned()))) - .unwrap_err(); - assert!( - result - .to_string() - .contains("Empty iterator passed to scalar_iter_to_array"), - "{}", - result - ); - } - #[test] fn scalar_iter_to_dictionary() { fn make_val(v: Option) -> ScalarValue { @@ -1657,7 +1639,7 @@ mod tests { // iter_to_array for struct scalars let array = scalar_iter_to_array( - &field_a.data_type(), + 
&s0.data_type(), vec![s0.clone(), s1.clone(), s2.clone()] .into_iter() .map(|v| Ok(v.to_owned())), @@ -2323,90 +2305,6 @@ mod tests { } } - #[test] - fn test_struct_nulls() { - let fields_b = Fields::from(vec![ - Field::new("ba", DataType::UInt64, true), - Field::new("bb", DataType::UInt64, true), - ]); - let fields = Fields::from(vec![ - Field::new("a", DataType::UInt64, true), - Field::new("b", DataType::Struct(fields_b.clone()), true), - ]); - let scalars = vec![ - ScalarValue::Struct(None, fields.clone()), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(None), - ScalarValue::Struct(None, fields_b.clone()), - ]), - fields.clone(), - ), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(None), - ScalarValue::Struct( - Some(vec![ScalarValue::UInt64(None), ScalarValue::UInt64(None)]), - fields_b.clone(), - ), - ]), - fields.clone(), - ), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(Some(1)), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(Some(2)), - ScalarValue::UInt64(Some(3)), - ]), - fields_b, - ), - ]), - fields, - ), - ]; - - let check_array = |array| { - let is_null = is_null(&array).unwrap(); - assert_eq!(is_null, BooleanArray::from(vec![true, false, false, false])); - - let formatted = pretty_format_columns("col", &[array]).unwrap().to_string(); - let formatted = formatted.split('\n').collect::>(); - let expected = vec![ - "+---------------------------+", - "| col |", - "+---------------------------+", - "| |", - "| {a: , b: } |", - "| {a: , b: {ba: , bb: }} |", - "| {a: 1, b: {ba: 2, bb: 3}} |", - "+---------------------------+", - ]; - assert_eq!( - formatted, expected, - "Actual:\n{formatted:#?}\n\nExpected:\n{expected:#?}" - ); - }; - - // test `scalar_iter_to_array` - let array = scalar_iter_to_array( - &DataType::Boolean, - scalars.clone().into_iter().map(|v| Ok(v.to_owned())), - ) - .unwrap(); - check_array(array); - - // test `ScalarValue::to_array` / `ScalarValue::to_array_of_size` - let arrays = scalars - .iter() - .map(ScalarValue::to_array) - .collect::>(); - let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); - let array = concat(&arrays).unwrap(); - check_array(array); - } - #[test] fn test_build_timestamp_millisecond_list() { let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)]; From c7c01e0641431f7b447f747fe7b7eed239aa6a44 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Thu, 28 Dec 2023 08:38:33 -0500 Subject: [PATCH 6/8] avoid passing type info --- crates/datafusion_ext/src/cast.rs | 89 +++++++------------ .../sqlbuiltins/src/functions/scalars/mod.rs | 1 - 2 files changed, 31 insertions(+), 59 deletions(-) diff --git a/crates/datafusion_ext/src/cast.rs b/crates/datafusion_ext/src/cast.rs index 3015c3c81..723fc9b36 100644 --- a/crates/datafusion_ext/src/cast.rs +++ b/crates/datafusion_ext/src/cast.rs @@ -24,10 +24,18 @@ use std::sync::Arc; use crate::errors::ExtensionError; pub fn scalar_iter_to_array( - data_type: &DataType, scalars: impl IntoIterator>, ) -> Result { - let scalars = scalars.into_iter(); + let mut scalars = scalars.into_iter().peekable(); + let data_type = match scalars.peek().as_ref() { + Some(Ok(res)) => res.data_type(), + Some(Err(e)) => return Err(ExtensionError::String(e.to_string())), + None => { + return Err(ExtensionError::String( + "cannot produce empty value".to_string(), + )) + } + }; /// Creates an array of $ARRAY_TY by unpacking values of /// SCALAR_TY for primitive types @@ -266,9 +274,9 @@ pub fn scalar_iter_to_array( DataType::List(fields) if fields.data_type() == 
&DataType::LargeUtf8 => { build_array_list_string!(LargeStringBuilder, LargeUtf8) } - DataType::List(_) => { + DataType::List(field) => { // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = iter_to_array_list(scalars, &data_type)?; + let list_array = iter_to_array_list(field.data_type(), scalars)?; Arc::new(list_array) } DataType::Struct(fields) => { @@ -364,10 +372,10 @@ pub fn scalar_iter_to_array( if let ScalarValue::FixedSizeBinary(_, v) = sv { Ok(v) } else { - return Err(ExtensionError::String(format!( + Err(ExtensionError::String(format!( "Inconsistent types in scalar_iter_to_array. \ Expected {data_type:?}, got {sv:?}" - ))); + ))) } }) .collect::, ExtensionError>>()?; @@ -432,8 +440,8 @@ fn dict_from_values( } fn iter_to_array_list( - scalars: impl IntoIterator>, data_type: &DataType, + scalars: impl IntoIterator>, ) -> Result, ExtensionError> { let mut offsets = Int32Array::builder(0); offsets.append_value(0); @@ -782,11 +790,7 @@ mod tests { ScalarValue::Decimal128(Some(3), 10, 2), ]; // convert the vec to decimal array and check the result - let array = scalar_iter_to_array( - &DataType::Decimal128(10, 2), - decimal_vec.iter().map(|v| v.to_owned()).map(Ok), - ) - .unwrap(); + let array = scalar_iter_to_array(decimal_vec.iter().map(|v| v.to_owned()).map(Ok)).unwrap(); assert_eq!(3, array.len()); assert_eq!(DataType::Decimal128(10, 2), array.data_type().clone()); @@ -796,11 +800,7 @@ mod tests { ScalarValue::Decimal128(Some(3), 10, 2), ScalarValue::Decimal128(None, 10, 2), ]; - let array = scalar_iter_to_array( - &DataType::Decimal128(10, 2), - decimal_vec.iter().map(|v| v.to_owned()).map(Ok), - ) - .unwrap(); + let array = scalar_iter_to_array(decimal_vec.iter().map(|v| v.to_owned()).map(Ok)).unwrap(); assert_eq!(4, array.len()); assert_eq!(DataType::Decimal128(10, 2), array.data_type().clone()); @@ -897,11 +897,8 @@ mod tests { ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ let scalars: Vec<_> = $INPUT.iter().map(|v| ScalarValue::$SCALAR_T(*v)).collect(); - let array = scalar_iter_to_array( - &DataType::$SCALAR_T, - scalars.into_iter().map(|v| Ok(v.to_owned())), - ) - .unwrap(); + let array = + scalar_iter_to_array(scalars.into_iter().map(|v| Ok(v.to_owned()))).unwrap(); let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); @@ -918,11 +915,8 @@ mod tests { .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_string()))) .collect(); - let array = scalar_iter_to_array( - &DataType::$SCALAR_T, - scalars.into_iter().map(|v| Ok(v.to_owned())), - ) - .unwrap(); + let array = + scalar_iter_to_array(scalars.into_iter().map(|v| Ok(v.to_owned()))).unwrap(); let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); @@ -939,11 +933,8 @@ mod tests { .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_vec()))) .collect(); - let array = scalar_iter_to_array( - &DataType::$SCALAR_T, - scalars.into_iter().map(|v| Ok(v.to_owned())), - ) - .unwrap(); + let array = + scalar_iter_to_array(scalars.into_iter().map(|v| Ok(v.to_owned()))).unwrap(); let expected: $ARRAYTYPE = $INPUT.iter().map(|v| v.map(|v| v.to_vec())).collect(); @@ -999,11 +990,7 @@ mod tests { make_val(Some("Bar".into())), ]; - let array = scalar_iter_to_array( - &scalars.get(0).unwrap().data_type(), - scalars.into_iter().map(|v| Ok(v.to_owned())), - ) - .unwrap(); + let array = scalar_iter_to_array(scalars.into_iter().map(|v| Ok(v.to_owned()))).unwrap(); let array = as_dictionary_array::(&array).unwrap(); let values_array = as_string_array(array.values()).unwrap(); @@ -1027,11 
+1014,8 @@ mod tests { // If the scalar values are not all the correct type, error here let scalars = [Boolean(Some(true)), Int32(Some(5))]; - let result = scalar_iter_to_array( - &DataType::Boolean, - scalars.into_iter().map(|v| Ok(v.to_owned())), - ) - .unwrap_err(); + let result = + scalar_iter_to_array(scalars.into_iter().map(|v| Ok(v.to_owned()))).unwrap_err(); assert!( result.to_string().contains( "Inconsistent types in scalar_iter_to_array. Expected Boolean, got Int32(5)" @@ -1555,11 +1539,7 @@ mod tests { ), ]), ]; - let array = scalar_iter_to_array( - &constructed.data_type(), - scalars.iter().map(|v| v.to_owned()).map(Ok), - ) - .unwrap(); + let array = scalar_iter_to_array(scalars.iter().map(|v| v.to_owned()).map(Ok)).unwrap(); let expected = Arc::new(StructArray::from(vec![ ( @@ -1639,7 +1619,6 @@ mod tests { // iter_to_array for struct scalars let array = scalar_iter_to_array( - &s0.data_type(), vec![s0.clone(), s1.clone(), s2.clone()] .into_iter() .map(|v| Ok(v.to_owned())), @@ -1670,11 +1649,8 @@ mod tests { let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.data_type()); // iter_to_array for list-of-struct - let array = scalar_iter_to_array( - &nl0.data_type(), - vec![nl0, nl1, nl2].into_iter().map(|v| Ok(v.to_owned())), - ) - .unwrap(); + let array = scalar_iter_to_array(vec![nl0, nl1, nl2].into_iter().map(|v| Ok(v.to_owned()))) + .unwrap(); let array = as_list_array(&array).unwrap(); // Construct expected array with array builders @@ -1835,11 +1811,8 @@ mod tests { DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), ); - let array = scalar_iter_to_array( - &l3.data_type(), - vec![l1, l2, l3].iter().map(|v| v.to_owned()).map(Ok), - ) - .unwrap(); + let array = + scalar_iter_to_array(vec![l1, l2, l3].iter().map(|v| v.to_owned()).map(Ok)).unwrap(); let array = as_list_array(&array).unwrap(); // Construct expected array with array builders diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index a4f8a8f48..08072e569 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -65,7 +65,6 @@ fn get_nth_scalar_value( Some(input) => match input { ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(op(scalar.clone())?)), ColumnarValue::Array(arr) => Ok(ColumnarValue::Array(scalar_iter_to_array( - arr.data_type(), (0..arr.len()).map(|idx| -> Result { Ok(op(ScalarValue::try_from_array(arr, idx)?)?) }), From 108ce98091b2ba3623006a92b514c558e6913aec Mon Sep 17 00:00:00 2001 From: tycho garen Date: Thu, 28 Dec 2023 08:55:07 -0500 Subject: [PATCH 7/8] fixup --- crates/datafusion_ext/src/cast.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/crates/datafusion_ext/src/cast.rs b/crates/datafusion_ext/src/cast.rs index 723fc9b36..5a306882e 100644 --- a/crates/datafusion_ext/src/cast.rs +++ b/crates/datafusion_ext/src/cast.rs @@ -274,10 +274,9 @@ pub fn scalar_iter_to_array( DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { build_array_list_string!(LargeStringBuilder, LargeUtf8) } - DataType::List(field) => { + DataType::List(_) => { // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = iter_to_array_list(field.data_type(), scalars)?; - Arc::new(list_array) + Arc::new(iter_to_array_list(scalars)?) 
} DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column @@ -440,7 +439,6 @@ fn dict_from_values( } fn iter_to_array_list( - data_type: &DataType, scalars: impl IntoIterator>, ) -> Result, ExtensionError> { let mut offsets = Int32Array::builder(0); @@ -493,7 +491,7 @@ fn iter_to_array_list( // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices let offsets_array = offsets.finish(); - let array_data = ArrayDataBuilder::new(data_type.clone()) + let array_data = ArrayDataBuilder::new(flat_array.data_type().to_owned()) .len(offsets_array.len() - 1) .nulls(Some(NullBuffer::new(valid.finish()))) .add_buffer(offsets_array.values().inner().clone()) @@ -1812,7 +1810,7 @@ mod tests { ); let array = - scalar_iter_to_array(vec![l1, l2, l3].iter().map(|v| v.to_owned()).map(Ok)).unwrap(); + scalar_iter_to_array(vec![l1, l2, l3].iter().map(|v| Ok(v.to_owned()))).unwrap(); let array = as_list_array(&array).unwrap(); // Construct expected array with array builders From 9c462c5ebdfb3aa71a9c50e1148ad514d48a8741 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Thu, 28 Dec 2023 09:08:15 -0500 Subject: [PATCH 8/8] fixup cast --- crates/datafusion_ext/src/cast.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/crates/datafusion_ext/src/cast.rs b/crates/datafusion_ext/src/cast.rs index 5a306882e..22500d1d9 100644 --- a/crates/datafusion_ext/src/cast.rs +++ b/crates/datafusion_ext/src/cast.rs @@ -447,8 +447,14 @@ fn iter_to_array_list( let mut elements: Vec = Vec::new(); let mut valid = BooleanBufferBuilder::new(0); let mut flat_len = 0i32; + let mut data_type: Option = None; for scalar in scalars { let scalar = scalar?; + + if data_type.is_none() { + data_type = Some(scalar.data_type()); + } + if let ScalarValue::List(values, field) = scalar { match values { Some(values) => { @@ -458,6 +464,10 @@ fn iter_to_array_list( new_empty_array(field.data_type()) }; + if data_type.is_none() { + data_type = Some(element_array.data_type().to_owned()); + } + // Add new offset index flat_len += element_array.len() as i32; offsets.append_value(flat_len); @@ -488,10 +498,13 @@ fn iter_to_array_list( Ok(flat_array) => flat_array, Err(err) => return Ok(Err(DataFusionError::ArrowError(err))?), }; + if data_type.is_none() { + return Err(ExtensionError::String("unspecified DataType".to_string())); + } // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices let offsets_array = offsets.finish(); - let array_data = ArrayDataBuilder::new(flat_array.data_type().to_owned()) + let array_data = ArrayDataBuilder::new(data_type.unwrap()) .len(offsets_array.len() - 1) .nulls(Some(NullBuffer::new(valid.finish()))) .add_buffer(offsets_array.values().inner().clone())
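For reference, a minimal usage sketch of the reworked helper (illustrative only, not part of the patches above; it assumes the post-patch-6 shape of the API, where scalar_iter_to_array takes only an iterator of Result<ScalarValue, ExtensionError>, infers the output type from the first element, and errors on empty input):

    use datafusion::arrow::array::Array;
    use datafusion::scalar::ScalarValue;
    use datafusion_ext::cast::scalar_iter_to_array;
    use datafusion_ext::errors::ExtensionError;

    fn example() -> Result<(), ExtensionError> {
        // Build an Int32 array from owned scalars; the element type is taken
        // from the first value, and an empty iterator is an error.
        let scalars = vec![
            ScalarValue::Int32(Some(1)),
            ScalarValue::Int32(None),
            ScalarValue::Int32(Some(3)),
        ];
        let array = scalar_iter_to_array(scalars.into_iter().map(Ok))?;
        assert_eq!(array.len(), 3);
        Ok(())
    }

This mirrors the call pattern used in crates/sqlbuiltins/src/functions/scalars/mod.rs after these patches, where each array element is converted with ScalarValue::try_from_array, passed through the scalar op, and re-assembled into an ArrayRef.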