diff --git a/src/common/src/system_param/mod.rs b/src/common/src/system_param/mod.rs index 273be6d9f808a..004cc426dbe56 100644 --- a/src/common/src/system_param/mod.rs +++ b/src/common/src/system_param/mod.rs @@ -12,6 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! This module defines utilities to work with system parameters ([`PbSystemParams`] in +//! `meta.proto`). +//! +//! To add a new system parameter: +//! - Add a new field to [`PbSystemParams`] in `meta.proto`. +//! - Add a new entry to [`for_all_undeprecated_params`] in this file. +//! - Add a new method to [`reader::SystemParamsReader`]. + pub mod local_manager; pub mod reader; @@ -19,7 +27,7 @@ use std::fmt::Debug; use std::ops::RangeBounds; use paste::paste; -use risingwave_pb::meta::SystemParams; +use risingwave_pb::meta::PbSystemParams; pub type SystemParamsError = String; @@ -70,7 +78,7 @@ macro_rules! key_of { }; } -/// Define key constants for fields in `SystemParams` for use of other modules. +/// Define key constants for fields in `PbSystemParams` for use of other modules. macro_rules! def_key { ($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => { paste! { @@ -101,7 +109,7 @@ for_all_undeprecated_params!(def_default); macro_rules! impl_check_missing_fields { ($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => { /// Check if any undeprecated fields are missing. - pub fn check_missing_params(params: &SystemParams) -> Result<()> { + pub fn check_missing_params(params: &PbSystemParams) -> Result<()> { $( if params.$field.is_none() { return Err(format!("missing system param {:?}", key_of!($field))); @@ -118,7 +126,7 @@ macro_rules! impl_system_params_to_kv { /// The returned map only contains undeprecated fields. /// Return error if there are missing fields. #[allow(clippy::vec_init_then_push)] - pub fn system_params_to_kv(params: &SystemParams) -> Result> { + pub fn system_params_to_kv(params: &PbSystemParams) -> Result> { check_missing_params(params)?; let mut ret = Vec::new(); $(ret.push(( @@ -132,7 +140,7 @@ macro_rules! impl_system_params_to_kv { macro_rules! impl_derive_missing_fields { ($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => { - fn derive_missing_fields(params: &mut SystemParams) { + fn derive_missing_fields(params: &mut PbSystemParams) { $( if params.$field.is_none() && let Some(v) = OverrideFromParams::$field(params) { params.$field = Some(v); @@ -147,12 +155,12 @@ macro_rules! impl_system_params_from_kv { ($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => { /// Try to deserialize deprecated fields as well. /// Return error if there are unrecognized fields. - pub fn system_params_from_kv(mut kvs: Vec<(K, V)>) -> Result + pub fn system_params_from_kv(mut kvs: Vec<(K, V)>) -> Result where K: AsRef<[u8]> + Debug, V: AsRef<[u8]> + Debug, { - let mut ret = SystemParams::default(); + let mut ret = PbSystemParams::default(); kvs.retain(|(k,v)| { let k = std::str::from_utf8(k.as_ref()).unwrap(); let v = std::str::from_utf8(v.as_ref()).unwrap(); @@ -219,7 +227,7 @@ macro_rules! impl_default_validation_on_set { /// /// ```ignore /// impl FromParams for OverrideFromParams { -/// fn interval_ms(params: &SystemParams) -> Option { +/// fn interval_ms(params: &PbSystemParams) -> Option { /// if let Some(sec) = params.interval_sec { /// Some(sec * 1000) /// } else { @@ -234,7 +242,7 @@ macro_rules! impl_default_from_other_params { ($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => { trait FromParams { $( - fn $field(_params: &SystemParams) -> Option<$type> { + fn $field(_params: &PbSystemParams) -> Option<$type> { None } )* @@ -244,7 +252,7 @@ macro_rules! impl_default_from_other_params { macro_rules! impl_set_system_param { ($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => { - pub fn set_system_param(params: &mut SystemParams, key: &str, value: Option) -> Result<()> { + pub fn set_system_param(params: &mut PbSystemParams, key: &str, value: Option) -> Result<()> { match key { $( key_of!($field) => { @@ -285,8 +293,8 @@ macro_rules! impl_is_mutable { macro_rules! impl_system_params_for_test { ($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => { #[allow(clippy::needless_update)] - pub fn system_params_for_test() -> SystemParams { - let mut ret = SystemParams { + pub fn system_params_for_test() -> PbSystemParams { + let mut ret = PbSystemParams { $( $field: $default, )* @@ -351,7 +359,7 @@ mod tests { ]; // To kv - missing field. - let p = SystemParams::default(); + let p = PbSystemParams::default(); assert!(system_params_to_kv(&p).is_err()); // From kv - unrecognized field. @@ -367,7 +375,7 @@ mod tests { #[test] fn test_set() { - let mut p = SystemParams::default(); + let mut p = PbSystemParams::default(); // Unrecognized param. assert!(set_system_param(&mut p, "?", Some("?".to_string())).is_err()); // Value out of range. diff --git a/src/prost/build.rs b/src/prost/build.rs index fe09fd5a0c5a0..6802205ec5e41 100644 --- a/src/prost/build.rs +++ b/src/prost/build.rs @@ -57,7 +57,10 @@ fn main() -> Result<(), Box> { .type_attribute(".", "#[derive(prost_helpers::AnyPB)]") .type_attribute("node_body", "#[derive(::enum_as_inner::EnumAsInner)]") .type_attribute("rex_node", "#[derive(::enum_as_inner::EnumAsInner)]") + // Eq + Hash are for plan nodes to do common sub-plan detection. + // The requirement is from Source node -> SourceCatalog -> WatermarkDesc -> expr .type_attribute("catalog.WatermarkDesc", "#[derive(Eq, Hash)]") + .type_attribute("catalog.StreamSourceInfo", "#[derive(Eq, Hash)]") .type_attribute("expr.ExprNode", "#[derive(Eq, Hash)]") .type_attribute("data.DataType", "#[derive(Eq, Hash)]") .type_attribute("expr.ExprNode.rex_node", "#[derive(Eq, Hash)]") @@ -66,13 +69,13 @@ fn main() -> Result<(), Box> { .type_attribute("data.Datum", "#[derive(Eq, Hash)]") .type_attribute("expr.FunctionCall", "#[derive(Eq, Hash)]") .type_attribute("expr.UserDefinedFunction", "#[derive(Eq, Hash)]") - .type_attribute("catalog.StreamSourceInfo", "#[derive(Eq, Hash)]") .type_attribute( "plan_common.ColumnDesc.generated_or_default_column", "#[derive(Eq, Hash)]", ) .type_attribute("plan_common.GeneratedColumnDesc", "#[derive(Eq, Hash)]") .type_attribute("plan_common.DefaultColumnDesc", "#[derive(Eq, Hash)]") + // =================== .out_dir(out_dir.as_path()) .compile(&protos, &[proto_dir.to_string()]) .expect("Failed to compile grpc!"); diff --git a/src/udf/src/external.rs b/src/udf/src/external.rs new file mode 100644 index 0000000000000..44d94b2b4e125 --- /dev/null +++ b/src/udf/src/external.rs @@ -0,0 +1,189 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use arrow_array::RecordBatch; +use arrow_flight::decode::FlightRecordBatchStream; +use arrow_flight::encode::FlightDataEncoderBuilder; +use arrow_flight::error::FlightError; +use arrow_flight::flight_service_client::FlightServiceClient; +use arrow_flight::{FlightData, FlightDescriptor}; +use arrow_schema::Schema; +use futures_util::{stream, Stream, StreamExt, TryStreamExt}; +use tonic::transport::Channel; + +/// Client for external function service based on Arrow Flight. +#[derive(Debug)] +pub struct ArrowFlightUdfClient { + client: FlightServiceClient, +} + +#[cfg(not(madsim))] +impl ArrowFlightUdfClient { + /// Connect to a UDF service. + pub async fn connect(addr: &str) -> Result { + let client = FlightServiceClient::connect(addr.to_string()).await?; + Ok(Self { client }) + } + + /// Check if the function is available and the schema is match. + pub async fn check(&self, id: &str, args: &Schema, returns: &Schema) -> Result<()> { + let descriptor = FlightDescriptor::new_path(vec![id.into()]); + + let response = self.client.clone().get_flight_info(descriptor).await?; + + // check schema + let info = response.into_inner(); + let input_num = info.total_records as usize; + let full_schema = Schema::try_from(info) + .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?; + let (input_fields, return_fields) = full_schema.fields.split_at(input_num); + let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect(); + let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect(); + let expect_input_types: Vec<_> = args.fields.iter().map(|f| f.data_type()).collect(); + let expect_result_types: Vec<_> = returns.fields.iter().map(|f| f.data_type()).collect(); + if !data_types_match(&expect_input_types, &actual_input_types) { + return Err(Error::ArgumentMismatch { + function_id: id.into(), + expected: format!("{:?}", expect_input_types), + actual: format!("{:?}", actual_input_types), + }); + } + if !data_types_match(&expect_result_types, &actual_result_types) { + return Err(Error::ReturnTypeMismatch { + function_id: id.into(), + expected: format!("{:?}", expect_result_types), + actual: format!("{:?}", actual_result_types), + }); + } + Ok(()) + } + + /// Call a function. + pub async fn call(&self, id: &str, input: RecordBatch) -> Result { + let mut output_stream = self.call_stream(id, stream::once(async { input })).await?; + // TODO: support no output + let head = output_stream.next().await.ok_or(Error::NoReturned)??; + let mut remaining = vec![]; + while let Some(batch) = output_stream.next().await { + remaining.push(batch?); + } + if remaining.is_empty() { + Ok(head) + } else { + Ok(arrow_select::concat::concat_batches( + &head.schema(), + std::iter::once(&head).chain(remaining.iter()), + )?) + } + } + + /// Call a function with streaming input and output. + pub async fn call_stream( + &self, + id: &str, + inputs: impl Stream + Send + 'static, + ) -> Result> + Send + 'static> { + let descriptor = FlightDescriptor::new_path(vec![id.into()]); + let flight_data_stream = FlightDataEncoderBuilder::new() + // XXX(wrj): unlimit the size of flight data to avoid splitting batch + // there's a bug in arrow-flight when splitting batch with list type array + // FIXME: remove this when the bug is fixed in arrow-flight + .with_max_flight_data_size(usize::MAX) + .build(inputs.map(Ok)) + .map(move |res| FlightData { + // TODO: fill descriptor only for the first message + flight_descriptor: Some(descriptor.clone()), + ..res.unwrap() + }); + + // call `do_exchange` on Flight server + let response = self.client.clone().do_exchange(flight_data_stream).await?; + + // decode response + let stream = response.into_inner(); + let record_batch_stream = FlightRecordBatchStream::new_from_flight_data( + // convert tonic::Status to FlightError + stream.map_err(|e| e.into()), + ); + Ok(record_batch_stream.map_err(|e| e.into())) + } +} + +// TODO: support UDF in simulation +#[cfg(madsim)] +impl ArrowFlightUdfClient { + /// Connect to a UDF service. + pub async fn connect(_addr: &str) -> Result { + panic!("UDF is not supported in simulation yet") + } + + /// Check if the function is available. + pub async fn check(&self, _id: &str, _args: &Schema, _returns: &Schema) -> Result<()> { + panic!("UDF is not supported in simulation yet") + } + + /// Call a function. + pub async fn call(&self, _id: &str, _input: RecordBatch) -> Result { + panic!("UDF is not supported in simulation yet") + } + + /// Call a function with streaming input and output. + pub async fn call_stream( + &self, + _id: &str, + _inputs: impl Stream + Send + 'static, + ) -> Result> + Send + 'static> { + panic!("UDF is not supported in simulation yet"); + Ok(stream::empty()) + } +} + +pub type Result = std::result::Result; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("failed to connect to UDF service: {0}")] + Connect(#[from] tonic::transport::Error), + #[error("failed to check UDF: {0}")] + Tonic(#[from] tonic::Status), + #[error("failed to call UDF: {0}")] + Flight(#[from] FlightError), + #[error("argument mismatch: function {function_id:?}, expected {expected}, actual {actual}")] + ArgumentMismatch { + function_id: String, + expected: String, + actual: String, + }, + #[error( + "return type mismatch: function {function_id:?}, expected {expected}, actual {actual}" + )] + ReturnTypeMismatch { + function_id: String, + expected: String, + actual: String, + }, + #[error("arrow error: {0}")] + Arrow(#[from] arrow_schema::ArrowError), + #[error("UDF service returned no data")] + NoReturned, +} + +/// Check if two list of data types match, ignoring field names. +fn data_types_match(a: &[&arrow_schema::DataType], b: &[&arrow_schema::DataType]) -> bool { + if a.len() != b.len() { + return false; + } + #[allow(clippy::disallowed_methods)] + a.iter().zip(b.iter()).all(|(a, b)| a.equals_datatype(b)) +} diff --git a/src/udf/src/lib.rs b/src/udf/src/lib.rs index 44d94b2b4e125..507aff3f093c3 100644 --- a/src/udf/src/lib.rs +++ b/src/udf/src/lib.rs @@ -12,178 +12,5 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow_array::RecordBatch; -use arrow_flight::decode::FlightRecordBatchStream; -use arrow_flight::encode::FlightDataEncoderBuilder; -use arrow_flight::error::FlightError; -use arrow_flight::flight_service_client::FlightServiceClient; -use arrow_flight::{FlightData, FlightDescriptor}; -use arrow_schema::Schema; -use futures_util::{stream, Stream, StreamExt, TryStreamExt}; -use tonic::transport::Channel; - -/// Client for external function service based on Arrow Flight. -#[derive(Debug)] -pub struct ArrowFlightUdfClient { - client: FlightServiceClient, -} - -#[cfg(not(madsim))] -impl ArrowFlightUdfClient { - /// Connect to a UDF service. - pub async fn connect(addr: &str) -> Result { - let client = FlightServiceClient::connect(addr.to_string()).await?; - Ok(Self { client }) - } - - /// Check if the function is available and the schema is match. - pub async fn check(&self, id: &str, args: &Schema, returns: &Schema) -> Result<()> { - let descriptor = FlightDescriptor::new_path(vec![id.into()]); - - let response = self.client.clone().get_flight_info(descriptor).await?; - - // check schema - let info = response.into_inner(); - let input_num = info.total_records as usize; - let full_schema = Schema::try_from(info) - .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?; - let (input_fields, return_fields) = full_schema.fields.split_at(input_num); - let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect(); - let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect(); - let expect_input_types: Vec<_> = args.fields.iter().map(|f| f.data_type()).collect(); - let expect_result_types: Vec<_> = returns.fields.iter().map(|f| f.data_type()).collect(); - if !data_types_match(&expect_input_types, &actual_input_types) { - return Err(Error::ArgumentMismatch { - function_id: id.into(), - expected: format!("{:?}", expect_input_types), - actual: format!("{:?}", actual_input_types), - }); - } - if !data_types_match(&expect_result_types, &actual_result_types) { - return Err(Error::ReturnTypeMismatch { - function_id: id.into(), - expected: format!("{:?}", expect_result_types), - actual: format!("{:?}", actual_result_types), - }); - } - Ok(()) - } - - /// Call a function. - pub async fn call(&self, id: &str, input: RecordBatch) -> Result { - let mut output_stream = self.call_stream(id, stream::once(async { input })).await?; - // TODO: support no output - let head = output_stream.next().await.ok_or(Error::NoReturned)??; - let mut remaining = vec![]; - while let Some(batch) = output_stream.next().await { - remaining.push(batch?); - } - if remaining.is_empty() { - Ok(head) - } else { - Ok(arrow_select::concat::concat_batches( - &head.schema(), - std::iter::once(&head).chain(remaining.iter()), - )?) - } - } - - /// Call a function with streaming input and output. - pub async fn call_stream( - &self, - id: &str, - inputs: impl Stream + Send + 'static, - ) -> Result> + Send + 'static> { - let descriptor = FlightDescriptor::new_path(vec![id.into()]); - let flight_data_stream = FlightDataEncoderBuilder::new() - // XXX(wrj): unlimit the size of flight data to avoid splitting batch - // there's a bug in arrow-flight when splitting batch with list type array - // FIXME: remove this when the bug is fixed in arrow-flight - .with_max_flight_data_size(usize::MAX) - .build(inputs.map(Ok)) - .map(move |res| FlightData { - // TODO: fill descriptor only for the first message - flight_descriptor: Some(descriptor.clone()), - ..res.unwrap() - }); - - // call `do_exchange` on Flight server - let response = self.client.clone().do_exchange(flight_data_stream).await?; - - // decode response - let stream = response.into_inner(); - let record_batch_stream = FlightRecordBatchStream::new_from_flight_data( - // convert tonic::Status to FlightError - stream.map_err(|e| e.into()), - ); - Ok(record_batch_stream.map_err(|e| e.into())) - } -} - -// TODO: support UDF in simulation -#[cfg(madsim)] -impl ArrowFlightUdfClient { - /// Connect to a UDF service. - pub async fn connect(_addr: &str) -> Result { - panic!("UDF is not supported in simulation yet") - } - - /// Check if the function is available. - pub async fn check(&self, _id: &str, _args: &Schema, _returns: &Schema) -> Result<()> { - panic!("UDF is not supported in simulation yet") - } - - /// Call a function. - pub async fn call(&self, _id: &str, _input: RecordBatch) -> Result { - panic!("UDF is not supported in simulation yet") - } - - /// Call a function with streaming input and output. - pub async fn call_stream( - &self, - _id: &str, - _inputs: impl Stream + Send + 'static, - ) -> Result> + Send + 'static> { - panic!("UDF is not supported in simulation yet"); - Ok(stream::empty()) - } -} - -pub type Result = std::result::Result; - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("failed to connect to UDF service: {0}")] - Connect(#[from] tonic::transport::Error), - #[error("failed to check UDF: {0}")] - Tonic(#[from] tonic::Status), - #[error("failed to call UDF: {0}")] - Flight(#[from] FlightError), - #[error("argument mismatch: function {function_id:?}, expected {expected}, actual {actual}")] - ArgumentMismatch { - function_id: String, - expected: String, - actual: String, - }, - #[error( - "return type mismatch: function {function_id:?}, expected {expected}, actual {actual}" - )] - ReturnTypeMismatch { - function_id: String, - expected: String, - actual: String, - }, - #[error("arrow error: {0}")] - Arrow(#[from] arrow_schema::ArrowError), - #[error("UDF service returned no data")] - NoReturned, -} - -/// Check if two list of data types match, ignoring field names. -fn data_types_match(a: &[&arrow_schema::DataType], b: &[&arrow_schema::DataType]) -> bool { - if a.len() != b.len() { - return false; - } - #[allow(clippy::disallowed_methods)] - a.iter().zip(b.iter()).all(|(a, b)| a.equals_datatype(b)) -} +mod external; +pub use external::{ArrowFlightUdfClient, Error};