From 82af04e954f1bb0bcc7566bf0880798836595b5c Mon Sep 17 00:00:00 2001 From: Kevin Axel Date: Mon, 11 Sep 2023 19:15:40 +0800 Subject: [PATCH 1/4] fix(udf): check udf schema fields num and total records (#12206) Signed-off-by: Kevin Axel --- src/udf/src/error.rs | 3 +++ src/udf/src/external.rs | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/src/udf/src/error.rs b/src/udf/src/error.rs index d787808a32810..f20816ee5b2c0 100644 --- a/src/udf/src/error.rs +++ b/src/udf/src/error.rs @@ -40,6 +40,9 @@ pub enum Error { #[error("UDF service returned no data")] NoReturned, + + #[error("Flight service error: {0}")] + ServiceError(String), } static_assertions::const_assert_eq!(std::mem::size_of::(), 32); diff --git a/src/udf/src/external.rs b/src/udf/src/external.rs index 1a666f4d4e378..585adc7ebec5b 100644 --- a/src/udf/src/external.rs +++ b/src/udf/src/external.rs @@ -49,6 +49,15 @@ impl ArrowFlightUdfClient { let input_num = info.total_records as usize; let full_schema = Schema::try_from(info) .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?; + if input_num > full_schema.fields.len() { + return Err(Error::ServiceError(format!( + "function {:?} schema info not consistency: input_num: {}, total_fields: {}", + id, + input_num, + full_schema.fields.len() + ))); + } + let (input_fields, return_fields) = full_schema.fields.split_at(input_num); let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect(); let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect(); From 2ea80bdc76b785656319db34d44378a11bb59f86 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 12 Sep 2023 16:10:16 +0800 Subject: [PATCH 2/4] fix(udf): check the data type returned from UDF server (#12202) Signed-off-by: Runji Wang --- src/common/src/types/mod.rs | 8 ++++++ src/common/src/types/struct_type.rs | 10 +++++++ src/expr/src/expr/expr_udf.rs | 7 +++++ src/expr/src/table_function/user_defined.rs | 30 +++++++++++++++++++++ 4 files changed, 55 insertions(+) diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 7737d76cd48fc..3d5be4c060084 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -431,6 +431,14 @@ impl DataType { } d } + + /// Compares the datatype with another, ignoring nested field names and metadata. + pub fn equals_datatype(&self, other: &DataType) -> bool { + match (self, other) { + (Self::Struct(s1), Self::Struct(s2)) => s1.equals_datatype(s2), + _ => self == other, + } + } } impl From for PbDataType { diff --git a/src/common/src/types/struct_type.rs b/src/common/src/types/struct_type.rs index 340ffa7857c1d..239f506db8267 100644 --- a/src/common/src/types/struct_type.rs +++ b/src/common/src/types/struct_type.rs @@ -117,6 +117,16 @@ impl StructType { .chain(std::iter::repeat("").take(self.0.field_types.len() - self.0.field_names.len())) .zip_eq_debug(self.0.field_types.iter()) } + + /// Compares the datatype with another, ignoring nested field names and metadata. + pub fn equals_datatype(&self, other: &StructType) -> bool { + if self.0.field_types.len() != other.0.field_types.len() { + return false; + } + (self.0.field_types.iter()) + .zip_eq_fast(other.0.field_types.iter()) + .all(|(a, b)| a.equals_datatype(b)) + } } impl Display for StructType { diff --git a/src/expr/src/expr/expr_udf.rs b/src/expr/src/expr/expr_udf.rs index 3bb1583e26224..d0ecf58c3af98 100644 --- a/src/expr/src/expr/expr_udf.rs +++ b/src/expr/src/expr/expr_udf.rs @@ -103,6 +103,13 @@ impl UdfExpression { }; let mut array = ArrayImpl::try_from(arrow_array)?; array.set_bitmap(array.null_bitmap() & vis); + if !array.data_type().equals_datatype(&self.return_type) { + bail!( + "UDF returned {:?}, but expected {:?}", + array.data_type(), + self.return_type, + ); + } Ok(Arc::new(array)) } } diff --git a/src/expr/src/table_function/user_defined.rs b/src/expr/src/table_function/user_defined.rs index 21d3801cce47d..7b0385854c544 100644 --- a/src/expr/src/table_function/user_defined.rs +++ b/src/expr/src/table_function/user_defined.rs @@ -67,9 +67,39 @@ impl UserDefinedTableFunction { .await? { let output = DataChunk::try_from(&res?)?; + self.check_output(&output)?; yield output; } } + + /// Check if the output chunk is valid. + fn check_output(&self, output: &DataChunk) -> Result<()> { + if output.columns().len() != 2 { + bail!( + "UDF returned {} columns, but expected 2", + output.columns().len() + ); + } + if output.column_at(0).data_type() != DataType::Int32 { + bail!( + "UDF returned {:?} at column 0, but expected {:?}", + output.column_at(0).data_type(), + DataType::Int32, + ); + } + if !output + .column_at(1) + .data_type() + .equals_datatype(&self.return_type) + { + bail!( + "UDF returned {:?} at column 1, but expected {:?}", + output.column_at(1).data_type(), + &self.return_type, + ); + } + Ok(()) + } } #[cfg(not(madsim))] From 1616b82e093794663fdc0086ff918a667797d3ac Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Sat, 16 Sep 2023 12:16:02 +0800 Subject: [PATCH 3/4] fix(udf): handle visibility of input chunks in UDTF (#12357) Signed-off-by: Runji Wang --- e2e_test/udf/udf.slt | 29 ++++++++++++++++ src/common/src/array/arrow.rs | 6 ++-- src/common/src/array/data_chunk.rs | 23 +++++++++++++ src/expr/src/table_function/mod.rs | 4 +-- src/expr/src/table_function/user_defined.rs | 38 ++++++++++++++++----- 5 files changed, 88 insertions(+), 12 deletions(-) diff --git a/e2e_test/udf/udf.slt b/e2e_test/udf/udf.slt index 33579a825832e..110b1c0f373dd 100644 --- a/e2e_test/udf/udf.slt +++ b/e2e_test/udf/udf.slt @@ -224,6 +224,35 @@ select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d9 ---- 192.168.0.14 192.168.0.1 861 8374 +# steaming +# to ensure UDF & UDTF respect visibility + +statement ok +create table t (x int); + +statement ok +create materialized view mv as select gcd(x, x), series(x) from t where x <> 2; + +statement ok +insert into t values (1), (2), (3); + +statement ok +flush; + +query II +select * from mv; +---- +1 0 +3 0 +3 1 +3 2 + +statement ok +drop materialized view mv; + +statement ok +drop table t; + # error handling statement error diff --git a/src/common/src/array/arrow.rs b/src/common/src/array/arrow.rs index 9b4165b608d98..0f89e6b4f53f4 100644 --- a/src/common/src/array/arrow.rs +++ b/src/common/src/array/arrow.rs @@ -27,6 +27,7 @@ use crate::util::iter_util::ZipEqDebug; // Implement bi-directional `From` between `DataChunk` and `arrow_array::RecordBatch`. +// note: DataChunk -> arrow RecordBatch will IGNORE the visibilities. impl TryFrom<&DataChunk> for arrow_array::RecordBatch { type Error = ArrayError; @@ -47,8 +48,9 @@ impl TryFrom<&DataChunk> for arrow_array::RecordBatch { .collect(); let schema = Arc::new(Schema::new(fields)); - - arrow_array::RecordBatch::try_new(schema, columns) + let opts = + arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity())); + arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts) .map_err(|err| ArrayError::ToArrow(err.to_string())) } } diff --git a/src/common/src/array/data_chunk.rs b/src/common/src/array/data_chunk.rs index 8a013fdb6c852..33a3b692415c2 100644 --- a/src/common/src/array/data_chunk.rs +++ b/src/common/src/array/data_chunk.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::borrow::Cow; +use std::fmt::Display; use std::hash::BuildHasher; use std::sync::Arc; use std::{fmt, usize}; @@ -261,6 +263,27 @@ impl DataChunk { } } + /// Convert the chunk to compact format. + /// + /// If the chunk is not compacted, return a new compacted chunk, otherwise return a reference to self. + pub fn compact_cow(&self) -> Cow<'_, Self> { + match &self.vis2 { + Vis::Compact(_) => Cow::Borrowed(self), + Vis::Bitmap(visibility) => { + let cardinality = visibility.count_ones(); + let columns = self + .columns + .iter() + .map(|col| { + let array = col; + array.compact(visibility, cardinality).into() + }) + .collect::>(); + Cow::Owned(Self::new(columns, cardinality)) + } + } + } + pub fn from_protobuf(proto: &PbDataChunk) -> ArrayResult { let mut columns = vec![]; for any_col in proto.get_columns() { diff --git a/src/expr/src/table_function/mod.rs b/src/expr/src/table_function/mod.rs index bf89463adbbdd..594f2831dbc38 100644 --- a/src/expr/src/table_function/mod.rs +++ b/src/expr/src/table_function/mod.rs @@ -49,7 +49,7 @@ pub trait TableFunction: std::fmt::Debug + Sync + Send { /// # Contract of the output /// /// The returned `DataChunk` contains exact two columns: - /// - The first column is an I32Array containing row indexes of input chunk. It should be + /// - The first column is an I32Array containing row indices of input chunk. It should be /// monotonically increasing. /// - The second column is the output values. The data type of the column is `return_type`. /// @@ -80,7 +80,7 @@ pub trait TableFunction: std::fmt::Debug + Sync + Send { /// (You don't need to understand this section to implement a `TableFunction`) /// /// The output of the `TableFunction` is different from the output of the `ProjectSet` executor. - /// `ProjectSet` executor uses the row indexes to stitch multiple table functions and produces + /// `ProjectSet` executor uses the row indices to stitch multiple table functions and produces /// `projected_row_id`. /// /// ## Example diff --git a/src/expr/src/table_function/user_defined.rs b/src/expr/src/table_function/user_defined.rs index 7b0385854c544..813cf23504482 100644 --- a/src/expr/src/table_function/user_defined.rs +++ b/src/expr/src/table_function/user_defined.rs @@ -14,9 +14,10 @@ use std::sync::Arc; +use arrow_array::RecordBatch; use arrow_schema::{Field, Fields, Schema, SchemaRef}; use futures_util::stream; -use risingwave_common::array::DataChunk; +use risingwave_common::array::{DataChunk, I32Array}; use risingwave_common::bail; use risingwave_udf::ArrowFlightUdfClient; @@ -25,6 +26,7 @@ use super::*; #[derive(Debug)] pub struct UserDefinedTableFunction { children: Vec, + #[allow(dead_code)] arg_schema: SchemaRef, return_type: DataType, client: Arc, @@ -49,25 +51,42 @@ impl TableFunction for UserDefinedTableFunction { impl UserDefinedTableFunction { #[try_stream(boxed, ok = DataChunk, error = ExprError)] async fn eval_inner<'a>(&'a self, input: &'a DataChunk) { + // evaluate children expressions let mut columns = Vec::with_capacity(self.children.len()); for c in &self.children { - let val = c.eval_checked(input).await?.as_ref().try_into()?; + let val = c.eval_checked(input).await?; columns.push(val); } + let direct_input = DataChunk::new(columns, input.vis().clone()); + + // compact the input chunk and record the row mapping + let visible_rows = direct_input.vis().iter_ones().collect_vec(); + let compacted_input = direct_input.compact_cow(); + let arrow_input = RecordBatch::try_from(compacted_input.as_ref())?; - let opts = - arrow_array::RecordBatchOptions::default().with_row_count(Some(input.cardinality())); - let input = - arrow_array::RecordBatch::try_new_with_options(self.arg_schema.clone(), columns, &opts) - .expect("failed to build record batch"); + // call UDTF #[for_await] for res in self .client - .call_stream(&self.identifier, stream::once(async { input })) + .call_stream(&self.identifier, stream::once(async { arrow_input })) .await? { let output = DataChunk::try_from(&res?)?; self.check_output(&output)?; + + // we send the compacted input to UDF, so we need to map the row indices back to the original input + let origin_indices = output + .column_at(0) + .as_int32() + .raw_iter() + // we have checked all indices are non-negative + .map(|idx| visible_rows[idx as usize] as i32) + .collect::(); + + let output = DataChunk::new( + vec![origin_indices.into_ref(), output.column_at(1).clone()], + output.vis().clone(), + ); yield output; } } @@ -87,6 +106,9 @@ impl UserDefinedTableFunction { DataType::Int32, ); } + if output.column_at(0).as_int32().raw_iter().any(|i| i < 0) { + bail!("UDF returned negative row index"); + } if !output .column_at(1) .data_type() From d6c2dd000ae58a07adfc44a48f72582b3d11c036 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Mon, 18 Sep 2023 14:19:14 +0800 Subject: [PATCH 4/4] cargo fmt Signed-off-by: Runji Wang --- src/common/src/array/data_chunk.rs | 4 ++-- src/expr/src/table_function/user_defined.rs | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/common/src/array/data_chunk.rs b/src/common/src/array/data_chunk.rs index 33a3b692415c2..0e29c599c64fe 100644 --- a/src/common/src/array/data_chunk.rs +++ b/src/common/src/array/data_chunk.rs @@ -13,7 +13,6 @@ // limitations under the License. use std::borrow::Cow; -use std::fmt::Display; use std::hash::BuildHasher; use std::sync::Arc; use std::{fmt, usize}; @@ -265,7 +264,8 @@ impl DataChunk { /// Convert the chunk to compact format. /// - /// If the chunk is not compacted, return a new compacted chunk, otherwise return a reference to self. + /// If the chunk is not compacted, return a new compacted chunk, otherwise return a reference to + /// self. pub fn compact_cow(&self) -> Cow<'_, Self> { match &self.vis2 { Vis::Compact(_) => Cow::Borrowed(self), diff --git a/src/expr/src/table_function/user_defined.rs b/src/expr/src/table_function/user_defined.rs index 813cf23504482..510754a817c15 100644 --- a/src/expr/src/table_function/user_defined.rs +++ b/src/expr/src/table_function/user_defined.rs @@ -74,7 +74,8 @@ impl UserDefinedTableFunction { let output = DataChunk::try_from(&res?)?; self.check_output(&output)?; - // we send the compacted input to UDF, so we need to map the row indices back to the original input + // we send the compacted input to UDF, so we need to map the row indices back to the + // original input let origin_indices = output .column_at(0) .as_int32()