diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index d0e155f784546..2508d90f9632c 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -443,6 +443,14 @@ impl DataType { } d } + + /// Compares the datatype with another, ignoring nested field names and metadata. + pub fn equals_datatype(&self, other: &DataType) -> bool { + match (self, other) { + (Self::Struct(s1), Self::Struct(s2)) => s1.equals_datatype(s2), + _ => self == other, + } + } } impl From for PbDataType { diff --git a/src/common/src/types/struct_type.rs b/src/common/src/types/struct_type.rs index 340ffa7857c1d..239f506db8267 100644 --- a/src/common/src/types/struct_type.rs +++ b/src/common/src/types/struct_type.rs @@ -117,6 +117,16 @@ impl StructType { .chain(std::iter::repeat("").take(self.0.field_types.len() - self.0.field_names.len())) .zip_eq_debug(self.0.field_types.iter()) } + + /// Compares the datatype with another, ignoring nested field names and metadata. + pub fn equals_datatype(&self, other: &StructType) -> bool { + if self.0.field_types.len() != other.0.field_types.len() { + return false; + } + (self.0.field_types.iter()) + .zip_eq_fast(other.0.field_types.iter()) + .all(|(a, b)| a.equals_datatype(b)) + } } impl Display for StructType { diff --git a/src/expr/src/expr/expr_udf.rs b/src/expr/src/expr/expr_udf.rs index 3bb1583e26224..d0ecf58c3af98 100644 --- a/src/expr/src/expr/expr_udf.rs +++ b/src/expr/src/expr/expr_udf.rs @@ -103,6 +103,13 @@ impl UdfExpression { }; let mut array = ArrayImpl::try_from(arrow_array)?; array.set_bitmap(array.null_bitmap() & vis); + if !array.data_type().equals_datatype(&self.return_type) { + bail!( + "UDF returned {:?}, but expected {:?}", + array.data_type(), + self.return_type, + ); + } Ok(Arc::new(array)) } } diff --git a/src/expr/src/table_function/user_defined.rs b/src/expr/src/table_function/user_defined.rs index 21d3801cce47d..7b0385854c544 100644 --- a/src/expr/src/table_function/user_defined.rs +++ b/src/expr/src/table_function/user_defined.rs @@ -67,9 +67,39 @@ impl UserDefinedTableFunction { .await? { let output = DataChunk::try_from(&res?)?; + self.check_output(&output)?; yield output; } } + + /// Check if the output chunk is valid. + fn check_output(&self, output: &DataChunk) -> Result<()> { + if output.columns().len() != 2 { + bail!( + "UDF returned {} columns, but expected 2", + output.columns().len() + ); + } + if output.column_at(0).data_type() != DataType::Int32 { + bail!( + "UDF returned {:?} at column 0, but expected {:?}", + output.column_at(0).data_type(), + DataType::Int32, + ); + } + if !output + .column_at(1) + .data_type() + .equals_datatype(&self.return_type) + { + bail!( + "UDF returned {:?} at column 1, but expected {:?}", + output.column_at(1).data_type(), + &self.return_type, + ); + } + Ok(()) + } } #[cfg(not(madsim))]