Skip to content

Commit

Permalink
chore: add some comments / move some code
Browse files Browse the repository at this point in the history
picked from #10910
  • Loading branch information
xxchan committed Jul 12, 2023
1 parent ed5b9e0 commit 3fc615d
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 190 deletions.
36 changes: 22 additions & 14 deletions src/common/src/system_param/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//! This module defines utilities to work with system parameters ([`PbSystemParams`] in
//! `meta.proto`).
//!
//! To add a new system parameter:
//! - Add a new field to [`PbSystemParams`] in `meta.proto`.
//! - Add a new entry to [`for_all_undeprecated_params`] in this file.
//! - Add a new method to [`reader::SystemParamsReader`].
pub mod local_manager;
pub mod reader;

use std::fmt::Debug;
use std::ops::RangeBounds;

use paste::paste;
use risingwave_pb::meta::SystemParams;
use risingwave_pb::meta::PbSystemParams;

pub type SystemParamsError = String;

Expand Down Expand Up @@ -70,7 +78,7 @@ macro_rules! key_of {
};
}

/// Define key constants for fields in `SystemParams` for use of other modules.
/// Define key constants for fields in `PbSystemParams` for use of other modules.
macro_rules! def_key {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
paste! {
Expand Down Expand Up @@ -101,7 +109,7 @@ for_all_undeprecated_params!(def_default);
macro_rules! impl_check_missing_fields {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
/// Check if any undeprecated fields are missing.
pub fn check_missing_params(params: &SystemParams) -> Result<()> {
pub fn check_missing_params(params: &PbSystemParams) -> Result<()> {
$(
if params.$field.is_none() {
return Err(format!("missing system param {:?}", key_of!($field)));
Expand All @@ -118,7 +126,7 @@ macro_rules! impl_system_params_to_kv {
/// The returned map only contains undeprecated fields.
/// Return error if there are missing fields.
#[allow(clippy::vec_init_then_push)]
pub fn system_params_to_kv(params: &SystemParams) -> Result<Vec<(String, String)>> {
pub fn system_params_to_kv(params: &PbSystemParams) -> Result<Vec<(String, String)>> {
check_missing_params(params)?;
let mut ret = Vec::new();
$(ret.push((
Expand All @@ -132,7 +140,7 @@ macro_rules! impl_system_params_to_kv {

macro_rules! impl_derive_missing_fields {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
fn derive_missing_fields(params: &mut SystemParams) {
fn derive_missing_fields(params: &mut PbSystemParams) {
$(
if params.$field.is_none() && let Some(v) = OverrideFromParams::$field(params) {
params.$field = Some(v);
Expand All @@ -147,12 +155,12 @@ macro_rules! impl_system_params_from_kv {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
/// Try to deserialize deprecated fields as well.
/// Return error if there are unrecognized fields.
pub fn system_params_from_kv<K, V>(mut kvs: Vec<(K, V)>) -> Result<SystemParams>
pub fn system_params_from_kv<K, V>(mut kvs: Vec<(K, V)>) -> Result<PbSystemParams>
where
K: AsRef<[u8]> + Debug,
V: AsRef<[u8]> + Debug,
{
let mut ret = SystemParams::default();
let mut ret = PbSystemParams::default();
kvs.retain(|(k,v)| {
let k = std::str::from_utf8(k.as_ref()).unwrap();
let v = std::str::from_utf8(v.as_ref()).unwrap();
Expand Down Expand Up @@ -219,7 +227,7 @@ macro_rules! impl_default_validation_on_set {
///
/// ```ignore
/// impl FromParams for OverrideFromParams {
/// fn interval_ms(params: &SystemParams) -> Option<u64> {
/// fn interval_ms(params: &PbSystemParams) -> Option<u64> {
/// if let Some(sec) = params.interval_sec {
/// Some(sec * 1000)
/// } else {
Expand All @@ -234,7 +242,7 @@ macro_rules! impl_default_from_other_params {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
trait FromParams {
$(
fn $field(_params: &SystemParams) -> Option<$type> {
fn $field(_params: &PbSystemParams) -> Option<$type> {
None
}
)*
Expand All @@ -244,7 +252,7 @@ macro_rules! impl_default_from_other_params {

macro_rules! impl_set_system_param {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
pub fn set_system_param(params: &mut SystemParams, key: &str, value: Option<String>) -> Result<()> {
pub fn set_system_param(params: &mut PbSystemParams, key: &str, value: Option<String>) -> Result<()> {
match key {
$(
key_of!($field) => {
Expand Down Expand Up @@ -285,8 +293,8 @@ macro_rules! impl_is_mutable {
macro_rules! impl_system_params_for_test {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
#[allow(clippy::needless_update)]
pub fn system_params_for_test() -> SystemParams {
let mut ret = SystemParams {
pub fn system_params_for_test() -> PbSystemParams {
let mut ret = PbSystemParams {
$(
$field: $default,
)*
Expand Down Expand Up @@ -351,7 +359,7 @@ mod tests {
];

// To kv - missing field.
let p = SystemParams::default();
let p = PbSystemParams::default();
assert!(system_params_to_kv(&p).is_err());

// From kv - unrecognized field.
Expand All @@ -367,7 +375,7 @@ mod tests {

#[test]
fn test_set() {
let mut p = SystemParams::default();
let mut p = PbSystemParams::default();
// Unrecognized param.
assert!(set_system_param(&mut p, "?", Some("?".to_string())).is_err());
// Value out of range.
Expand Down
5 changes: 4 additions & 1 deletion src/prost/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.type_attribute(".", "#[derive(prost_helpers::AnyPB)]")
.type_attribute("node_body", "#[derive(::enum_as_inner::EnumAsInner)]")
.type_attribute("rex_node", "#[derive(::enum_as_inner::EnumAsInner)]")
// Eq + Hash are for plan nodes to do common sub-plan detection.
// The requirement is from Source node -> SourceCatalog -> WatermarkDesc -> expr
.type_attribute("catalog.WatermarkDesc", "#[derive(Eq, Hash)]")
.type_attribute("catalog.StreamSourceInfo", "#[derive(Eq, Hash)]")
.type_attribute("expr.ExprNode", "#[derive(Eq, Hash)]")
.type_attribute("data.DataType", "#[derive(Eq, Hash)]")
.type_attribute("expr.ExprNode.rex_node", "#[derive(Eq, Hash)]")
Expand All @@ -66,13 +69,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.type_attribute("data.Datum", "#[derive(Eq, Hash)]")
.type_attribute("expr.FunctionCall", "#[derive(Eq, Hash)]")
.type_attribute("expr.UserDefinedFunction", "#[derive(Eq, Hash)]")
.type_attribute("catalog.StreamSourceInfo", "#[derive(Eq, Hash)]")
.type_attribute(
"plan_common.ColumnDesc.generated_or_default_column",
"#[derive(Eq, Hash)]",
)
.type_attribute("plan_common.GeneratedColumnDesc", "#[derive(Eq, Hash)]")
.type_attribute("plan_common.DefaultColumnDesc", "#[derive(Eq, Hash)]")
// ===================
.out_dir(out_dir.as_path())
.compile(&protos, &[proto_dir.to_string()])
.expect("Failed to compile grpc!");
Expand Down
189 changes: 189 additions & 0 deletions src/udf/src/external.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use arrow_array::RecordBatch;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::{FlightData, FlightDescriptor};
use arrow_schema::Schema;
use futures_util::{stream, Stream, StreamExt, TryStreamExt};
use tonic::transport::Channel;

/// Client for external function service based on Arrow Flight.
///
/// Wraps an Arrow Flight [`FlightServiceClient`] running over a tonic [`Channel`].
#[derive(Debug)]
pub struct ArrowFlightUdfClient {
    // Underlying Flight client. The request methods take `&self` and clone this
    // handle before issuing each call.
    client: FlightServiceClient<Channel>,
}

#[cfg(not(madsim))]
impl ArrowFlightUdfClient {
/// Connect to a UDF service.
pub async fn connect(addr: &str) -> Result<Self> {
let client = FlightServiceClient::connect(addr.to_string()).await?;
Ok(Self { client })
}

/// Check if the function is available and the schema is match.
pub async fn check(&self, id: &str, args: &Schema, returns: &Schema) -> Result<()> {
let descriptor = FlightDescriptor::new_path(vec![id.into()]);

let response = self.client.clone().get_flight_info(descriptor).await?;

// check schema
let info = response.into_inner();
let input_num = info.total_records as usize;
let full_schema = Schema::try_from(info)
.map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?;
let (input_fields, return_fields) = full_schema.fields.split_at(input_num);
let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect();
let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect();
let expect_input_types: Vec<_> = args.fields.iter().map(|f| f.data_type()).collect();
let expect_result_types: Vec<_> = returns.fields.iter().map(|f| f.data_type()).collect();
if !data_types_match(&expect_input_types, &actual_input_types) {
return Err(Error::ArgumentMismatch {
function_id: id.into(),
expected: format!("{:?}", expect_input_types),
actual: format!("{:?}", actual_input_types),
});
}
if !data_types_match(&expect_result_types, &actual_result_types) {
return Err(Error::ReturnTypeMismatch {
function_id: id.into(),
expected: format!("{:?}", expect_result_types),
actual: format!("{:?}", actual_result_types),
});
}
Ok(())
}

/// Call a function.
pub async fn call(&self, id: &str, input: RecordBatch) -> Result<RecordBatch> {
let mut output_stream = self.call_stream(id, stream::once(async { input })).await?;
// TODO: support no output
let head = output_stream.next().await.ok_or(Error::NoReturned)??;
let mut remaining = vec![];
while let Some(batch) = output_stream.next().await {
remaining.push(batch?);
}
if remaining.is_empty() {
Ok(head)
} else {
Ok(arrow_select::concat::concat_batches(
&head.schema(),
std::iter::once(&head).chain(remaining.iter()),
)?)
}
}

/// Call a function with streaming input and output.
pub async fn call_stream(
&self,
id: &str,
inputs: impl Stream<Item = RecordBatch> + Send + 'static,
) -> Result<impl Stream<Item = Result<RecordBatch>> + Send + 'static> {
let descriptor = FlightDescriptor::new_path(vec![id.into()]);
let flight_data_stream = FlightDataEncoderBuilder::new()
// XXX(wrj): unlimit the size of flight data to avoid splitting batch
// there's a bug in arrow-flight when splitting batch with list type array
// FIXME: remove this when the bug is fixed in arrow-flight
.with_max_flight_data_size(usize::MAX)
.build(inputs.map(Ok))
.map(move |res| FlightData {
// TODO: fill descriptor only for the first message
flight_descriptor: Some(descriptor.clone()),
..res.unwrap()
});

// call `do_exchange` on Flight server
let response = self.client.clone().do_exchange(flight_data_stream).await?;

// decode response
let stream = response.into_inner();
let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
// convert tonic::Status to FlightError
stream.map_err(|e| e.into()),
);
Ok(record_batch_stream.map_err(|e| e.into()))
}
}

// TODO: support UDF in simulation
// Stub implementation compiled only under `cfg(madsim)`: it mirrors the signatures of the
// non-simulation client, but every method panics when called.
#[cfg(madsim)]
impl ArrowFlightUdfClient {
    /// Connect to a UDF service.
    ///
    /// # Panics
    ///
    /// Always panics: UDF is not supported in simulation yet.
    pub async fn connect(_addr: &str) -> Result<Self> {
        panic!("UDF is not supported in simulation yet")
    }

    /// Check if the function is available.
    ///
    /// # Panics
    ///
    /// Always panics: UDF is not supported in simulation yet.
    pub async fn check(&self, _id: &str, _args: &Schema, _returns: &Schema) -> Result<()> {
        panic!("UDF is not supported in simulation yet")
    }

    /// Call a function.
    ///
    /// # Panics
    ///
    /// Always panics: UDF is not supported in simulation yet.
    pub async fn call(&self, _id: &str, _input: RecordBatch) -> Result<RecordBatch> {
        panic!("UDF is not supported in simulation yet")
    }

    /// Call a function with streaming input and output.
    ///
    /// # Panics
    ///
    /// Always panics: UDF is not supported in simulation yet.
    pub async fn call_stream(
        &self,
        _id: &str,
        _inputs: impl Stream<Item = RecordBatch> + Send + 'static,
    ) -> Result<impl Stream<Item = Result<RecordBatch>> + Send + 'static> {
        panic!("UDF is not supported in simulation yet");
        // Never reached. Presumably kept so the compiler has a concrete type to infer for
        // the `impl Stream` return value -- confirm before removing.
        Ok(stream::empty())
    }
}

/// Result type of the UDF client, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;

/// Errors returned by the UDF client.
///
/// The `#[from]` variants are what the `?` operator converts the corresponding library
/// errors into.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Converted from a tonic transport error, e.g. when connecting to the service fails.
    #[error("failed to connect to UDF service: {0}")]
    Connect(#[from] tonic::transport::Error),
    /// Converted from a gRPC status. Despite the message, this is produced by both the
    /// `get_flight_info` request in `check` and the `do_exchange` request in `call_stream`.
    #[error("failed to check UDF: {0}")]
    Tonic(#[from] tonic::Status),
    /// Converted from an Arrow Flight error. Besides failures while decoding a call's
    /// response, `check` also reports schema decoding failures through this variant.
    #[error("failed to call UDF: {0}")]
    Flight(#[from] FlightError),
    /// The argument types reported by the service differ from the expected ones.
    /// `expected` and `actual` hold the `Debug` rendering of the data type lists.
    #[error("argument mismatch: function {function_id:?}, expected {expected}, actual {actual}")]
    ArgumentMismatch {
        function_id: String,
        expected: String,
        actual: String,
    },
    /// The return types reported by the service differ from the expected ones.
    /// `expected` and `actual` hold the `Debug` rendering of the data type lists.
    #[error(
        "return type mismatch: function {function_id:?}, expected {expected}, actual {actual}"
    )]
    ReturnTypeMismatch {
        function_id: String,
        expected: String,
        actual: String,
    },
    /// Converted from an Arrow error, e.g. when concatenating returned batches fails.
    #[error("arrow error: {0}")]
    Arrow(#[from] arrow_schema::ArrowError),
    /// The response stream of a call ended without yielding any record batch.
    #[error("UDF service returned no data")]
    NoReturned,
}

/// Returns `true` when both lists have the same length and every pair of data types at the
/// same position is equal according to `equals_datatype`, which ignores field names.
fn data_types_match(a: &[&arrow_schema::DataType], b: &[&arrow_schema::DataType]) -> bool {
    if a.len() == b.len() {
        // The lengths are known to be equal here, so a plain `zip` cannot drop elements.
        #[allow(clippy::disallowed_methods)]
        a.iter()
            .zip(b.iter())
            .all(|(lhs, rhs)| lhs.equals_datatype(rhs))
    } else {
        false
    }
}
Loading

0 comments on commit 3fc615d

Please sign in to comment.