Skip to content

Commit

Permalink
support LargeList for arrow_cast, support ScalarValue::LargeList (
Browse files Browse the repository at this point in the history
#8290)

* support largelist for arrow_cast

* fix cli

* update tests

* add new_large_list in ScalarValue

* fix ci

* support LargeList in scalar

* modify comment

* support largelist for proto
  • Loading branch information
Weijun-H authored Nov 22, 2023
1 parent 3dbda1e commit f2b0344
Show file tree
Hide file tree
Showing 10 changed files with 546 additions and 44 deletions.
461 changes: 421 additions & 40 deletions datafusion/common/src/scalar.rs

Large diffs are not rendered by default.

14 changes: 13 additions & 1 deletion datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use arrow::compute;
use arrow::compute::{partition, SortColumn, SortOptions};
use arrow::datatypes::{Field, SchemaRef, UInt32Type};
use arrow::record_batch::RecordBatch;
use arrow_array::{Array, ListArray};
use arrow_array::{Array, LargeListArray, ListArray};
use sqlparser::ast::Ident;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
Expand Down Expand Up @@ -349,6 +349,18 @@ pub fn array_into_list_array(arr: ArrayRef) -> ListArray {
)
}

/// Produces a `LargeListArray` with exactly one entry whose contents are all of `arr`.
///
/// For instance, the input `[1, 2, 3]` comes back as `[[1, 2, 3]]`.
pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray {
    // A single length yields the offsets for one list covering every element of `arr`.
    let single_entry = OffsetBuffer::from_lengths(std::iter::once(arr.len()));
    // The child field is named "item", reuses the input's data type, and is nullable.
    let item_field = Field::new("item", arr.data_type().clone(), true);
    LargeListArray::new(Arc::new(item_field), single_entry, arr, None)
}

/// Wrap arrays into a single element `ListArray`.
///
/// Example:
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,7 @@ message ScalarValue{
// Literal Date32 value always has a unit of day
int32 date_32_value = 14;
ScalarTime32Value time32_value = 15;
ScalarListValue large_list_value = 16;
ScalarListValue list_value = 17;
ScalarListValue fixed_size_list_value = 18;

Expand Down
14 changes: 14 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,9 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
Value::Float64Value(v) => Self::Float64(Some(*v)),
Value::Date32Value(v) => Self::Date32(Some(*v)),
// ScalarValue::List, ScalarValue::LargeList and ScalarValue::FixedSizeList
// are serialized using arrow IPC format
Value::ListValue(scalar_list) | Value::FixedSizeListValue(scalar_list) => {
Value::ListValue(scalar_list)
| Value::FixedSizeListValue(scalar_list)
| Value::LargeListValue(scalar_list) => {
let protobuf::ScalarListValue {
ipc_message,
arrow_data,
Expand Down Expand Up @@ -703,6 +705,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
let arr = record_batch.column(0);
match value {
Value::ListValue(_) => Self::List(arr.to_owned()),
Value::LargeListValue(_) => Self::LargeList(arr.to_owned()),
Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()),
_ => unreachable!(),
}
Expand Down
9 changes: 8 additions & 1 deletion datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,9 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
}
// ScalarValue::List, ScalarValue::LargeList and ScalarValue::FixedSizeList
// are serialized using Arrow IPC messages as a single column RecordBatch
ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => {
ScalarValue::List(arr)
| ScalarValue::LargeList(arr)
| ScalarValue::FixedSizeList(arr) => {
// Wrap in a "field_name" column
let batch = RecordBatch::try_from_iter(vec![(
"field_name",
Expand Down Expand Up @@ -1174,6 +1176,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
scalar_list_value,
)),
}),
ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::LargeListValue(
scalar_list_value,
)),
}),
ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::FixedSizeListValue(
scalar_list_value,
Expand Down
30 changes: 30 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ fn round_trip_scalar_values() {
ScalarValue::Utf8(None),
ScalarValue::LargeUtf8(None),
ScalarValue::List(ScalarValue::new_list(&[], &DataType::Boolean)),
ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)),
ScalarValue::Date32(None),
ScalarValue::Boolean(Some(true)),
ScalarValue::Boolean(Some(false)),
Expand Down Expand Up @@ -674,6 +675,16 @@ fn round_trip_scalar_values() {
],
&DataType::Float32,
)),
ScalarValue::LargeList(ScalarValue::new_large_list(
&[
ScalarValue::Float32(Some(-213.1)),
ScalarValue::Float32(None),
ScalarValue::Float32(Some(5.5)),
ScalarValue::Float32(Some(2.0)),
ScalarValue::Float32(Some(1.0)),
],
&DataType::Float32,
)),
ScalarValue::List(ScalarValue::new_list(
&[
ScalarValue::List(ScalarValue::new_list(&[], &DataType::Float32)),
Expand All @@ -690,6 +701,25 @@ fn round_trip_scalar_values() {
],
&DataType::List(new_arc_field("item", DataType::Float32, true)),
)),
ScalarValue::LargeList(ScalarValue::new_large_list(
&[
ScalarValue::LargeList(ScalarValue::new_large_list(
&[],
&DataType::Float32,
)),
ScalarValue::LargeList(ScalarValue::new_large_list(
&[
ScalarValue::Float32(Some(-213.1)),
ScalarValue::Float32(None),
ScalarValue::Float32(Some(5.5)),
ScalarValue::Float32(Some(2.0)),
ScalarValue::Float32(Some(1.0)),
],
&DataType::Float32,
)),
],
&DataType::LargeList(new_arc_field("item", DataType::Float32, true)),
)),
ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::from_iter_primitive::<
Int32Type,
_,
Expand Down
14 changes: 14 additions & 0 deletions datafusion/sql/src/expr/arrow_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ impl<'a> Parser<'a> {
Token::Decimal256 => self.parse_decimal_256(),
Token::Dictionary => self.parse_dictionary(),
Token::List => self.parse_list(),
Token::LargeList => self.parse_large_list(),
tok => Err(make_error(
self.val,
&format!("finding next type, got unexpected '{tok}'"),
Expand All @@ -166,6 +167,16 @@ impl<'a> Parser<'a> {
))))
}

/// Parses the parenthesised element type that follows the `LargeList` keyword,
/// i.e. the `(<type>)` part of `LargeList(<type>)`.
///
/// Returns a `DataType::LargeList` whose child field is named "item" and is
/// nullable; any error from the token checks or the nested type parse is
/// propagated to the caller.
fn parse_large_list(&mut self) -> Result<DataType> {
    self.expect_token(Token::LParen)?;
    let element_type = self.parse_next_type()?;
    self.expect_token(Token::RParen)?;

    let item_field = Field::new("item", element_type, true);
    Ok(DataType::LargeList(Arc::new(item_field)))
}

/// Parses the next timeunit
fn parse_time_unit(&mut self, context: &str) -> Result<TimeUnit> {
match self.next_token()? {
Expand Down Expand Up @@ -496,6 +507,7 @@ impl<'a> Tokenizer<'a> {
"Date64" => Token::SimpleType(DataType::Date64),

"List" => Token::List,
"LargeList" => Token::LargeList,

"Second" => Token::TimeUnit(TimeUnit::Second),
"Millisecond" => Token::TimeUnit(TimeUnit::Millisecond),
Expand Down Expand Up @@ -585,13 +597,15 @@ enum Token {
Integer(i64),
DoubleQuotedString(String),
List,
LargeList,
}

impl Display for Token {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Token::SimpleType(t) => write!(f, "{t}"),
Token::List => write!(f, "List"),
Token::LargeList => write!(f, "LargeList"),
Token::Timestamp => write!(f, "Timestamp"),
Token::Time32 => write!(f, "Time32"),
Token::Time64 => write!(f, "Time64"),
Expand Down
38 changes: 38 additions & 0 deletions datafusion/sqllogictest/test_files/arrow_typeof.slt
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,41 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some(

statement error Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone
select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))');


## List


query ?
select arrow_cast('1', 'List(Int64)');
----
[1]

query ?
select arrow_cast(make_array(1, 2, 3), 'List(Int64)');
----
[1, 2, 3]

query T
select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'List(Int64)'));
----
List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })


## LargeList


query ?
select arrow_cast('1', 'LargeList(Int64)');
----
[1]

query ?
select arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)');
----
[1, 2, 3]

query T
select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'));
----
LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })

0 comments on commit f2b0344

Please sign in to comment.