Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add some comments / move some code #10912

Merged
merged 5 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions src/common/src/system_param/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//! This module defines utilities to work with system parameters ([`PbSystemParams`] in
//! `meta.proto`).
//!
//! To add a new system parameter:
//! - Add a new field to [`PbSystemParams`] in `meta.proto`.
//! - Add a new entry to [`for_all_undeprecated_params`] in this file.
//! - Add a new method to [`reader::SystemParamsReader`].

pub mod local_manager;
pub mod reader;

use std::fmt::Debug;
use std::ops::RangeBounds;

use paste::paste;
use risingwave_pb::meta::SystemParams;
use risingwave_pb::meta::PbSystemParams;

pub type SystemParamsError = String;

Expand Down Expand Up @@ -70,7 +78,7 @@ macro_rules! key_of {
};
}

/// Define key constants for fields in `SystemParams` for use of other modules.
/// Define key constants for fields in `PbSystemParams` for use of other modules.
macro_rules! def_key {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
paste! {
Expand Down Expand Up @@ -101,7 +109,7 @@ for_all_undeprecated_params!(def_default);
macro_rules! impl_check_missing_fields {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
/// Check if any undeprecated fields are missing.
pub fn check_missing_params(params: &SystemParams) -> Result<()> {
pub fn check_missing_params(params: &PbSystemParams) -> Result<()> {
$(
if params.$field.is_none() {
return Err(format!("missing system param {:?}", key_of!($field)));
Expand All @@ -118,7 +126,7 @@ macro_rules! impl_system_params_to_kv {
/// The returned map only contains undeprecated fields.
/// Return error if there are missing fields.
#[allow(clippy::vec_init_then_push)]
pub fn system_params_to_kv(params: &SystemParams) -> Result<Vec<(String, String)>> {
pub fn system_params_to_kv(params: &PbSystemParams) -> Result<Vec<(String, String)>> {
check_missing_params(params)?;
let mut ret = Vec::new();
$(ret.push((
Expand All @@ -132,7 +140,7 @@ macro_rules! impl_system_params_to_kv {

macro_rules! impl_derive_missing_fields {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
fn derive_missing_fields(params: &mut SystemParams) {
fn derive_missing_fields(params: &mut PbSystemParams) {
$(
if params.$field.is_none() && let Some(v) = OverrideFromParams::$field(params) {
params.$field = Some(v);
Expand All @@ -147,12 +155,12 @@ macro_rules! impl_system_params_from_kv {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
/// Try to deserialize deprecated fields as well.
/// Return error if there are unrecognized fields.
pub fn system_params_from_kv<K, V>(mut kvs: Vec<(K, V)>) -> Result<SystemParams>
pub fn system_params_from_kv<K, V>(mut kvs: Vec<(K, V)>) -> Result<PbSystemParams>
where
K: AsRef<[u8]> + Debug,
V: AsRef<[u8]> + Debug,
{
let mut ret = SystemParams::default();
let mut ret = PbSystemParams::default();
kvs.retain(|(k,v)| {
let k = std::str::from_utf8(k.as_ref()).unwrap();
let v = std::str::from_utf8(v.as_ref()).unwrap();
Expand Down Expand Up @@ -219,7 +227,7 @@ macro_rules! impl_default_validation_on_set {
///
/// ```ignore
/// impl FromParams for OverrideFromParams {
/// fn interval_ms(params: &SystemParams) -> Option<u64> {
/// fn interval_ms(params: &PbSystemParams) -> Option<u64> {
/// if let Some(sec) = params.interval_sec {
/// Some(sec * 1000)
/// } else {
Expand All @@ -234,7 +242,7 @@ macro_rules! impl_default_from_other_params {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
trait FromParams {
$(
fn $field(_params: &SystemParams) -> Option<$type> {
fn $field(_params: &PbSystemParams) -> Option<$type> {
None
}
)*
Expand All @@ -244,7 +252,7 @@ macro_rules! impl_default_from_other_params {

macro_rules! impl_set_system_param {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
pub fn set_system_param(params: &mut SystemParams, key: &str, value: Option<String>) -> Result<()> {
pub fn set_system_param(params: &mut PbSystemParams, key: &str, value: Option<String>) -> Result<()> {
match key {
$(
key_of!($field) => {
Expand Down Expand Up @@ -285,8 +293,8 @@ macro_rules! impl_is_mutable {
macro_rules! impl_system_params_for_test {
($({ $field:ident, $type:ty, $default:expr, $is_mutable:expr },)*) => {
#[allow(clippy::needless_update)]
pub fn system_params_for_test() -> SystemParams {
let mut ret = SystemParams {
pub fn system_params_for_test() -> PbSystemParams {
let mut ret = PbSystemParams {
$(
$field: $default,
)*
Expand Down Expand Up @@ -351,7 +359,7 @@ mod tests {
];

// To kv - missing field.
let p = SystemParams::default();
let p = PbSystemParams::default();
assert!(system_params_to_kv(&p).is_err());

// From kv - unrecognized field.
Expand All @@ -367,7 +375,7 @@ mod tests {

#[test]
fn test_set() {
let mut p = SystemParams::default();
let mut p = PbSystemParams::default();
// Unrecognized param.
assert!(set_system_param(&mut p, "?", Some("?".to_string())).is_err());
// Value out of range.
Expand Down
5 changes: 4 additions & 1 deletion src/prost/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.type_attribute(".", "#[derive(prost_helpers::AnyPB)]")
.type_attribute("node_body", "#[derive(::enum_as_inner::EnumAsInner)]")
.type_attribute("rex_node", "#[derive(::enum_as_inner::EnumAsInner)]")
// Eq + Hash are for plan nodes to do common sub-plan detection.
// The requirement is from Source node -> SourceCatalog -> WatermarkDesc -> expr
.type_attribute("catalog.WatermarkDesc", "#[derive(Eq, Hash)]")
.type_attribute("catalog.StreamSourceInfo", "#[derive(Eq, Hash)]")
.type_attribute("expr.ExprNode", "#[derive(Eq, Hash)]")
.type_attribute("data.DataType", "#[derive(Eq, Hash)]")
.type_attribute("expr.ExprNode.rex_node", "#[derive(Eq, Hash)]")
Expand All @@ -66,13 +69,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.type_attribute("data.Datum", "#[derive(Eq, Hash)]")
.type_attribute("expr.FunctionCall", "#[derive(Eq, Hash)]")
.type_attribute("expr.UserDefinedFunction", "#[derive(Eq, Hash)]")
.type_attribute("catalog.StreamSourceInfo", "#[derive(Eq, Hash)]")
.type_attribute(
"plan_common.ColumnDesc.generated_or_default_column",
"#[derive(Eq, Hash)]",
)
.type_attribute("plan_common.GeneratedColumnDesc", "#[derive(Eq, Hash)]")
.type_attribute("plan_common.DefaultColumnDesc", "#[derive(Eq, Hash)]")
// ===================
.out_dir(out_dir.as_path())
.compile(&protos, &[proto_dir.to_string()])
.expect("Failed to compile grpc!");
Expand Down
189 changes: 189 additions & 0 deletions src/udf/src/external.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use arrow_array::RecordBatch;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::{FlightData, FlightDescriptor};
use arrow_schema::Schema;
use futures_util::{stream, Stream, StreamExt, TryStreamExt};
use tonic::transport::Channel;

/// Client for external function service based on Arrow Flight.
#[derive(Debug)]
pub struct ArrowFlightUdfClient {
client: FlightServiceClient<Channel>,
}

#[cfg(not(madsim))]
impl ArrowFlightUdfClient {
/// Connect to a UDF service.
pub async fn connect(addr: &str) -> Result<Self> {
let client = FlightServiceClient::connect(addr.to_string()).await?;
Ok(Self { client })
}

/// Check if the function is available and the schema is match.
pub async fn check(&self, id: &str, args: &Schema, returns: &Schema) -> Result<()> {
let descriptor = FlightDescriptor::new_path(vec![id.into()]);

let response = self.client.clone().get_flight_info(descriptor).await?;

// check schema
let info = response.into_inner();
let input_num = info.total_records as usize;
let full_schema = Schema::try_from(info)
.map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?;
let (input_fields, return_fields) = full_schema.fields.split_at(input_num);
let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect();
let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect();
let expect_input_types: Vec<_> = args.fields.iter().map(|f| f.data_type()).collect();
let expect_result_types: Vec<_> = returns.fields.iter().map(|f| f.data_type()).collect();
if !data_types_match(&expect_input_types, &actual_input_types) {
return Err(Error::ArgumentMismatch {
function_id: id.into(),
expected: format!("{:?}", expect_input_types),
actual: format!("{:?}", actual_input_types),
});
}
if !data_types_match(&expect_result_types, &actual_result_types) {
return Err(Error::ReturnTypeMismatch {
function_id: id.into(),
expected: format!("{:?}", expect_result_types),
actual: format!("{:?}", actual_result_types),
});
}
Ok(())
}

/// Call a function.
pub async fn call(&self, id: &str, input: RecordBatch) -> Result<RecordBatch> {
let mut output_stream = self.call_stream(id, stream::once(async { input })).await?;
// TODO: support no output
let head = output_stream.next().await.ok_or(Error::NoReturned)??;
let mut remaining = vec![];
while let Some(batch) = output_stream.next().await {
remaining.push(batch?);
}
if remaining.is_empty() {
Ok(head)
} else {
Ok(arrow_select::concat::concat_batches(
&head.schema(),
std::iter::once(&head).chain(remaining.iter()),
)?)
}
}

/// Call a function with streaming input and output.
pub async fn call_stream(
&self,
id: &str,
inputs: impl Stream<Item = RecordBatch> + Send + 'static,
) -> Result<impl Stream<Item = Result<RecordBatch>> + Send + 'static> {
let descriptor = FlightDescriptor::new_path(vec![id.into()]);
let flight_data_stream = FlightDataEncoderBuilder::new()
// XXX(wrj): unlimit the size of flight data to avoid splitting batch
// there's a bug in arrow-flight when splitting batch with list type array
// FIXME: remove this when the bug is fixed in arrow-flight
.with_max_flight_data_size(usize::MAX)
.build(inputs.map(Ok))
.map(move |res| FlightData {
// TODO: fill descriptor only for the first message
flight_descriptor: Some(descriptor.clone()),
..res.unwrap()
});

// call `do_exchange` on Flight server
let response = self.client.clone().do_exchange(flight_data_stream).await?;

// decode response
let stream = response.into_inner();
let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
// convert tonic::Status to FlightError
stream.map_err(|e| e.into()),
);
Ok(record_batch_stream.map_err(|e| e.into()))
}
}

// TODO: support UDF in simulation
#[cfg(madsim)]
impl ArrowFlightUdfClient {
/// Connect to a UDF service.
pub async fn connect(_addr: &str) -> Result<Self> {
panic!("UDF is not supported in simulation yet")
}

/// Check if the function is available.
pub async fn check(&self, _id: &str, _args: &Schema, _returns: &Schema) -> Result<()> {
panic!("UDF is not supported in simulation yet")
}

/// Call a function.
pub async fn call(&self, _id: &str, _input: RecordBatch) -> Result<RecordBatch> {
panic!("UDF is not supported in simulation yet")
}

/// Call a function with streaming input and output.
pub async fn call_stream(
&self,
_id: &str,
_inputs: impl Stream<Item = RecordBatch> + Send + 'static,
) -> Result<impl Stream<Item = Result<RecordBatch>> + Send + 'static> {
panic!("UDF is not supported in simulation yet");
Ok(stream::empty())
}
}

pub type Result<T> = std::result::Result<T, Error>;

#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("failed to connect to UDF service: {0}")]
Connect(#[from] tonic::transport::Error),
#[error("failed to check UDF: {0}")]
Tonic(#[from] tonic::Status),
#[error("failed to call UDF: {0}")]
Flight(#[from] FlightError),
#[error("argument mismatch: function {function_id:?}, expected {expected}, actual {actual}")]
ArgumentMismatch {
function_id: String,
expected: String,
actual: String,
},
#[error(
"return type mismatch: function {function_id:?}, expected {expected}, actual {actual}"
)]
ReturnTypeMismatch {
function_id: String,
expected: String,
actual: String,
},
#[error("arrow error: {0}")]
Arrow(#[from] arrow_schema::ArrowError),
#[error("UDF service returned no data")]
NoReturned,
}

/// Check if two list of data types match, ignoring field names.
fn data_types_match(a: &[&arrow_schema::DataType], b: &[&arrow_schema::DataType]) -> bool {
if a.len() != b.len() {
return false;
}
#[allow(clippy::disallowed_methods)]
a.iter().zip(b.iter()).all(|(a, b)| a.equals_datatype(b))
}
Loading