From 78505b1e33ee85171ed5ef6c6a487f93d0de26f7 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 6 Apr 2023 17:40:40 +0800 Subject: [PATCH] fix(udf): fix missing chunks (#9025) Signed-off-by: Runji Wang Co-authored-by: lmatz --- Cargo.lock | 1 + src/common/src/array/arrow.rs | 18 ++++++++++++++++++ src/udf/Cargo.toml | 1 + src/udf/src/lib.rs | 17 ++++++++++++++++- 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index adac570e684fd..4b0ea408da2de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6794,6 +6794,7 @@ dependencies = [ "arrow-array", "arrow-flight", "arrow-schema", + "arrow-select", "futures-util", "madsim-tokio", "madsim-tonic", diff --git a/src/common/src/array/arrow.rs b/src/common/src/array/arrow.rs index 137a3e06e9494..53102dd60f3b5 100644 --- a/src/common/src/array/arrow.rs +++ b/src/common/src/array/arrow.rs @@ -198,6 +198,11 @@ macro_rules! converts { array.iter().collect() } } + impl From<&[$ArrowType]> for $ArrayType { + fn from(arrays: &[$ArrowType]) -> Self { + arrays.iter().flat_map(|a| a.iter()).collect() + } + } }; // convert values using FromIntoArrow ($ArrayType:ty, $ArrowType:ty, @map) => { @@ -218,6 +223,19 @@ macro_rules! converts { .collect() } } + impl From<&[$ArrowType]> for $ArrayType { + fn from(arrays: &[$ArrowType]) -> Self { + arrays + .iter() + .flat_map(|a| a.iter()) + .map(|o| { + o.map(|v| { + <<$ArrayType as Array>::RefItem<'_> as FromIntoArrow>::from_arrow(v) + }) + }) + .collect() + } + } }; } converts!(BoolArray, arrow_array::BooleanArray); diff --git a/src/udf/Cargo.toml b/src/udf/Cargo.toml index 4296f342c9c82..a8c3ea94d526d 100644 --- a/src/udf/Cargo.toml +++ b/src/udf/Cargo.toml @@ -14,6 +14,7 @@ normal = ["workspace-hack"] arrow-array = "36" arrow-flight = "36" arrow-schema = "36" +arrow-select = "36" futures-util = "0.3.25" thiserror = "1" tokio = { version = "0.2", package = "madsim-tokio", features = ["rt", "macros"] } diff --git a/src/udf/src/lib.rs b/src/udf/src/lib.rs index f35a2f832eb44..0dcb547cf61cb 100644 --- a/src/udf/src/lib.rs +++ b/src/udf/src/lib.rs @@ -72,7 +72,20 @@ impl ArrowFlightUdfClient { /// Call a function. pub async fn call(&self, id: &str, input: RecordBatch) -> Result { let mut output_stream = self.call_stream(id, stream::once(async { input })).await?; - output_stream.next().await.ok_or(Error::NoReturned)? + // TODO: support no output + let head = output_stream.next().await.ok_or(Error::NoReturned)??; + let mut remaining = vec![]; + while let Some(batch) = output_stream.next().await { + remaining.push(batch?); + } + if remaining.is_empty() { + Ok(head) + } else { + Ok(arrow_select::concat::concat_batches( + &head.schema(), + std::iter::once(&head).chain(remaining.iter()), + )?) + } } /// Call a function with streaming input and output. @@ -157,6 +170,8 @@ pub enum Error { expected: String, actual: String, }, + #[error("arrow error: {0}")] + Arrow(#[from] arrow_schema::ArrowError), #[error("UDF service returned no data")] NoReturned, }