Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cherrypick fix(udf): handle visibility of input chunks in UDTF (#12357) #12368

Merged
merged 5 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions e2e_test/udf/udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,35 @@ select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d9
----
192.168.0.14 192.168.0.1 861 8374

# steaming
# to ensure UDF & UDTF respect visibility

statement ok
create table t (x int);

statement ok
create materialized view mv as select gcd(x, x), series(x) from t where x <> 2;

statement ok
insert into t values (1), (2), (3);

statement ok
flush;

query II
select * from mv;
----
1 0
3 0
3 1
3 2

statement ok
drop materialized view mv;

statement ok
drop table t;

# error handling

statement error
Expand Down
6 changes: 4 additions & 2 deletions src/common/src/array/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use crate::util::iter_util::ZipEqDebug;

// Implement bi-directional `From` between `DataChunk` and `arrow_array::RecordBatch`.

// note: DataChunk -> arrow RecordBatch will IGNORE the visibilities.
impl TryFrom<&DataChunk> for arrow_array::RecordBatch {
type Error = ArrayError;

Expand All @@ -47,8 +48,9 @@ impl TryFrom<&DataChunk> for arrow_array::RecordBatch {
.collect();

let schema = Arc::new(Schema::new(fields));

arrow_array::RecordBatch::try_new(schema, columns)
let opts =
arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity()));
arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts)
.map_err(|err| ArrayError::ToArrow(err.to_string()))
}
}
Expand Down
23 changes: 23 additions & 0 deletions src/common/src/array/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::hash::BuildHasher;
use std::sync::Arc;
use std::{fmt, usize};
Expand Down Expand Up @@ -261,6 +262,28 @@ impl DataChunk {
}
}

/// Convert the chunk to compact format.
///
/// If the chunk is not compacted, return a new compacted chunk, otherwise return a reference to
/// self.
pub fn compact_cow(&self) -> Cow<'_, Self> {
match &self.vis2 {
Vis::Compact(_) => Cow::Borrowed(self),
Vis::Bitmap(visibility) => {
let cardinality = visibility.count_ones();
let columns = self
.columns
.iter()
.map(|col| {
let array = col;
array.compact(visibility, cardinality).into()
})
.collect::<Vec<_>>();
Cow::Owned(Self::new(columns, cardinality))
}
}
}

pub fn from_protobuf(proto: &PbDataChunk) -> ArrayResult<Self> {
let mut columns = vec![];
for any_col in proto.get_columns() {
Expand Down
8 changes: 8 additions & 0 deletions src/common/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,14 @@ impl DataType {
}
d
}

/// Compares the datatype with another, ignoring nested field names and metadata.
pub fn equals_datatype(&self, other: &DataType) -> bool {
match (self, other) {
(Self::Struct(s1), Self::Struct(s2)) => s1.equals_datatype(s2),
_ => self == other,
}
}
}

impl From<DataType> for PbDataType {
Expand Down
10 changes: 10 additions & 0 deletions src/common/src/types/struct_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ impl StructType {
.chain(std::iter::repeat("").take(self.0.field_types.len() - self.0.field_names.len()))
.zip_eq_debug(self.0.field_types.iter())
}

/// Compares the datatype with another, ignoring nested field names and metadata.
pub fn equals_datatype(&self, other: &StructType) -> bool {
if self.0.field_types.len() != other.0.field_types.len() {
return false;
}
(self.0.field_types.iter())
.zip_eq_fast(other.0.field_types.iter())
.all(|(a, b)| a.equals_datatype(b))
}
}

impl Display for StructType {
Expand Down
7 changes: 7 additions & 0 deletions src/expr/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ impl UdfExpression {
};
let mut array = ArrayImpl::try_from(arrow_array)?;
array.set_bitmap(array.null_bitmap() & vis);
if !array.data_type().equals_datatype(&self.return_type) {
bail!(
"UDF returned {:?}, but expected {:?}",
array.data_type(),
self.return_type,
);
}
Ok(Arc::new(array))
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/expr/src/table_function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub trait TableFunction: std::fmt::Debug + Sync + Send {
/// # Contract of the output
///
/// The returned `DataChunk` contains exact two columns:
/// - The first column is an I32Array containing row indexes of input chunk. It should be
/// - The first column is an I32Array containing row indices of input chunk. It should be
/// monotonically increasing.
/// - The second column is the output values. The data type of the column is `return_type`.
///
Expand Down Expand Up @@ -80,7 +80,7 @@ pub trait TableFunction: std::fmt::Debug + Sync + Send {
/// (You don't need to understand this section to implement a `TableFunction`)
///
/// The output of the `TableFunction` is different from the output of the `ProjectSet` executor.
/// `ProjectSet` executor uses the row indexes to stitch multiple table functions and produces
/// `ProjectSet` executor uses the row indices to stitch multiple table functions and produces
/// `projected_row_id`.
///
/// ## Example
Expand Down
69 changes: 61 additions & 8 deletions src/expr/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_schema::{Field, Fields, Schema, SchemaRef};
use futures_util::stream;
use risingwave_common::array::DataChunk;
use risingwave_common::array::{DataChunk, I32Array};
use risingwave_common::bail;
use risingwave_udf::ArrowFlightUdfClient;

Expand All @@ -25,6 +26,7 @@ use super::*;
#[derive(Debug)]
pub struct UserDefinedTableFunction {
children: Vec<BoxedExpression>,
#[allow(dead_code)]
arg_schema: SchemaRef,
return_type: DataType,
client: Arc<ArrowFlightUdfClient>,
Expand All @@ -49,27 +51,78 @@ impl TableFunction for UserDefinedTableFunction {
impl UserDefinedTableFunction {
#[try_stream(boxed, ok = DataChunk, error = ExprError)]
async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
// evaluate children expressions
let mut columns = Vec::with_capacity(self.children.len());
for c in &self.children {
let val = c.eval_checked(input).await?.as_ref().try_into()?;
let val = c.eval_checked(input).await?;
columns.push(val);
}
let direct_input = DataChunk::new(columns, input.vis().clone());

// compact the input chunk and record the row mapping
let visible_rows = direct_input.vis().iter_ones().collect_vec();
let compacted_input = direct_input.compact_cow();
let arrow_input = RecordBatch::try_from(compacted_input.as_ref())?;

let opts =
arrow_array::RecordBatchOptions::default().with_row_count(Some(input.cardinality()));
let input =
arrow_array::RecordBatch::try_new_with_options(self.arg_schema.clone(), columns, &opts)
.expect("failed to build record batch");
// call UDTF
#[for_await]
for res in self
.client
.call_stream(&self.identifier, stream::once(async { input }))
.call_stream(&self.identifier, stream::once(async { arrow_input }))
.await?
{
let output = DataChunk::try_from(&res?)?;
self.check_output(&output)?;

// we send the compacted input to UDF, so we need to map the row indices back to the
// original input
let origin_indices = output
.column_at(0)
.as_int32()
.raw_iter()
// we have checked all indices are non-negative
.map(|idx| visible_rows[idx as usize] as i32)
.collect::<I32Array>();

let output = DataChunk::new(
vec![origin_indices.into_ref(), output.column_at(1).clone()],
output.vis().clone(),
);
yield output;
}
}

/// Check if the output chunk is valid.
fn check_output(&self, output: &DataChunk) -> Result<()> {
if output.columns().len() != 2 {
bail!(
"UDF returned {} columns, but expected 2",
output.columns().len()
);
}
if output.column_at(0).data_type() != DataType::Int32 {
bail!(
"UDF returned {:?} at column 0, but expected {:?}",
output.column_at(0).data_type(),
DataType::Int32,
);
}
if output.column_at(0).as_int32().raw_iter().any(|i| i < 0) {
bail!("UDF returned negative row index");
}
if !output
.column_at(1)
.data_type()
.equals_datatype(&self.return_type)
{
bail!(
"UDF returned {:?} at column 1, but expected {:?}",
output.column_at(1).data_type(),
&self.return_type,
);
}
Ok(())
}
}

#[cfg(not(madsim))]
Expand Down
3 changes: 3 additions & 0 deletions src/udf/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ pub enum Error {

#[error("UDF service returned no data")]
NoReturned,

#[error("Flight service error: {0}")]
ServiceError(String),
}

static_assertions::const_assert_eq!(std::mem::size_of::<Error>(), 32);
Expand Down
9 changes: 9 additions & 0 deletions src/udf/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ impl ArrowFlightUdfClient {
let input_num = info.total_records as usize;
let full_schema = Schema::try_from(info)
.map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?;
if input_num > full_schema.fields.len() {
return Err(Error::ServiceError(format!(
"function {:?} schema info not consistency: input_num: {}, total_fields: {}",
id,
input_num,
full_schema.fields.len()
)));
}

let (input_fields, return_fields) = full_schema.fields.split_at(input_num);
let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect();
let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect();
Expand Down
Loading