Skip to content

Commit

Permalink
support the coexistence of variadic and fixed arguments
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <wangrunji0408@163.com>
  • Loading branch information
wangrunji0408 committed Sep 13, 2023
1 parent 850e5f4 commit 01c9519
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 88 deletions.
126 changes: 55 additions & 71 deletions src/expr/macro/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,17 @@ impl FunctionAttr {
let name = self.name.clone();
let mut args = Vec::with_capacity(self.args.len());
for ty in &self.args {
if ty == "..." {
break;
}
args.push(data_type_name(ty));
}
let variadic = matches!(self.args.last(), Some(t) if t == "...");
let ret = data_type_name(&self.ret);

let pb_type = format_ident!("{}", utils::to_camel_case(&name));
let ctor_name = format_ident!("{}", self.ident_name());
let descriptor_type = quote! { crate::sig::func::FuncSign };
let variadic = self.args.len() == 1 && &self.args[0] == "...";
let build_fn = if build_fn {
let name = format_ident!("{}", user_fn.name);
quote! { #name }
Expand Down Expand Up @@ -101,8 +104,8 @@ impl FunctionAttr {
user_fn: &UserFunctionAttr,
optimize_const: bool,
) -> Result<TokenStream2> {
let num_args = self.args.len();
let variadic = self.args.len() == 1 && &self.args[0] == "...";
let variadic = matches!(self.args.last(), Some(t) if t == "...");
let num_args = self.args.len() - if variadic { 1 } else { 0 };
let fn_name = format_ident!("{}", user_fn.name);
let struct_name = match optimize_const {
true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())),
Expand Down Expand Up @@ -212,60 +215,40 @@ impl FunctionAttr {
quote! { () }
};

// ensure the number of children matches when arguments are fixed
// ensure the number of children matches the number of arguments
let check_children = match variadic {
true => quote! {},
true => quote! { crate::ensure!(children.len() >= #num_args); },
false => quote! { crate::ensure!(children.len() == #num_args); },
};

// evaluate child expressions and
// - build a chunk if arguments are variable
// - downcast arrays if arguments are fixed
let eval_children = if variadic {
// evaluate variadic arguments in `eval`
let eval_variadic = variadic.then(|| {
quote! {
let mut columns = Vec::with_capacity(self.children.len());
for child in &self.children {
let mut columns = Vec::with_capacity(self.children.len() - #num_args);
for child in &self.children[#num_args..] {
columns.push(child.eval_checked(input).await?);
}
let chunk = DataChunk::new(columns, input.vis().clone());
}
} else {
quote! {
#(
let #array_refs = self.children[#children_indices].eval_checked(input).await?;
let #arrays: &#arg_arrays = #array_refs.as_ref().into();
)*
let variadic_input = DataChunk::new(columns, input.vis().clone());
}
};
// evaluate child expressions and
// - build a row if arguments are variable
// - downcast scalars if arguments are fixed
let eval_row_children = if variadic {
});
// evaluate variadic arguments in `eval_row`
let eval_row_variadic = variadic.then(|| {
quote! {
let mut row = Vec::with_capacity(self.children.len());
for child in &self.children {
let mut row = Vec::with_capacity(self.children.len() - #num_args);
for child in &self.children[#num_args..] {
row.push(child.eval_row(input).await?);
}
let row = OwnedRow::new(row);
let variadic_row = OwnedRow::new(row);
}
} else {
quote! {
#(
let #datums = self.children[#children_indices].eval_row(input).await?;
let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
)*
}
};
});

let generic = if self.ret == "boolean" && user_fn.generic == 3 {
let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| {
// XXX: for generic compare functions, we need to specify the compatible type
let compatible_type = types::ref_type(types::min_compatible_type(&self.args))
.parse::<TokenStream2>()
.unwrap();
quote! { ::<_, _, #compatible_type> }
} else {
quote! {}
};
});
let prebuilt_arg = match (&self.prebuild, optimize_const) {
// use the prebuilt argument
(Some(_), true) => quote! { &self.prebuilt_arg, },
Expand All @@ -274,22 +257,18 @@ impl FunctionAttr {
// no prebuilt argument
(None, _) => quote! {},
};
let context = match user_fn.context {
true => quote! { &self.context, },
false => quote! {},
};
let writer = match user_fn.write {
true => quote! { &mut writer, },
false => quote! {},
};
let variadic_args = variadic.then(|| quote! { variadic_row, });
let context = user_fn.context.then(|| quote! { &self.context, });
let writer = user_fn.write.then(|| quote! { &mut writer, });
// call the user defined function
// inputs: [ Option<impl ScalarRef> ]
let mut output = match variadic {
true => quote! { #fn_name(row, #context #writer) },
false => {
quote! { #fn_name #generic(#(#non_prebuilt_inputs,)* #prebuilt_arg #context #writer) }
}
};
let mut output = quote! { #fn_name #generic(
#(#non_prebuilt_inputs,)*
#prebuilt_arg
#variadic_args
#context
#writer
) };
// handle error if the function returns `Result`
// wrap a `Some` if the function doesn't return `Option`
output = match user_fn.return_type_kind {
Expand All @@ -300,7 +279,7 @@ impl FunctionAttr {
};
// if user function accepts non-option arguments, we assume the function
// returns null on null input, so we need to unwrap the inputs before calling.
if !variadic && !user_fn.arg_option {
if !user_fn.arg_option {
output = quote! {
match (#(#inputs,)*) {
(#(Some(#inputs),)*) => #output,
Expand Down Expand Up @@ -343,25 +322,17 @@ impl FunctionAttr {
};
// the main body in `eval`
let eval = if let Some(batch_fn) = &self.batch_fn {
assert!(!variadic, "customized batch function is not supported for variadic functions");
// user defined batch function
let fn_name = format_ident!("{}", batch_fn);
quote! {
let c = #fn_name(#(#arrays),*);
Ok(Arc::new(c.into()))
}
} else if variadic {
quote! {
let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());
for row in chunk.rows_with_holes() {
if let Some(row) = row {
#append_output
} else {
builder.append_null();
}
}
Ok(Arc::new(builder.finish().into()))
}
} else if (types::is_primitive(&self.ret) || self.ret == "boolean") && user_fn.is_pure() {
} else if (types::is_primitive(&self.ret) || self.ret == "boolean")
&& user_fn.is_pure()
&& !variadic
{
// SIMD optimization for primitive types
match self.args.len() {
0 => quote! {
Expand Down Expand Up @@ -397,23 +368,28 @@ impl FunctionAttr {
0 => quote! { std::iter::repeat(()).take(input.capacity()) },
_ => quote! { multizip((#(#arrays.iter(),)*)) },
};
let let_variadic = variadic.then(|| quote! {
let variadic_row = variadic_input.row_at_unchecked_vis(i);
});
quote! {
let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());

match input.vis() {
Vis::Bitmap(vis) => {
// allow using `zip` for performance
#[allow(clippy::disallowed_methods)]
for ((#(#inputs,)*), visible) in #array_zip.zip(vis.iter()) {
for (i, ((#(#inputs,)*), visible)) in #array_zip.zip(vis.iter()).enumerate() {
if !visible {
builder.append_null();
continue;
}
#let_variadic
#append_output
}
}
Vis::Compact(_) => {
for (#(#inputs,)*) in #array_zip {
for (i, (#(#inputs,)*)) in #array_zip.enumerate() {
#let_variadic
#append_output
}
}
Expand Down Expand Up @@ -456,11 +432,19 @@ impl FunctionAttr {
self.context.return_type.clone()
}
async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
#eval_children
#(
let #array_refs = self.children[#children_indices].eval_checked(input).await?;
let #arrays: &#arg_arrays = #array_refs.as_ref().into();
)*
#eval_variadic
#eval
}
async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
#eval_row_children
#(
let #datums = self.children[#children_indices].eval_row(input).await?;
let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
)*
#eval_row_variadic
Ok(#row_output)
}
}
Expand Down
12 changes: 3 additions & 9 deletions src/expr/src/vector_op/concat_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,9 @@ use risingwave_expr_macro::function;

/// Concatenates all but the first argument, with separators. The first argument is used as the
/// separator string, and should not be NULL. Other NULL arguments are ignored.
#[function("concat_ws(...) -> varchar")]
fn concat_ws(row: impl Row, writer: &mut impl Write) -> Option<()> {
let sep = match row.datum_at(0) {
Some(sep) => sep.into_utf8(),
// return null if the separator is null
None => return None,
};

let mut string_iter = row.iter().skip(1).flatten();
#[function("concat_ws(varchar, ...) -> varchar")]
fn concat_ws(sep: &str, vals: impl Row, writer: &mut impl Write) -> Option<()> {
let mut string_iter = vals.iter().flatten();
if let Some(string) = string_iter.next() {
string.write(writer).unwrap();
}
Expand Down
11 changes: 3 additions & 8 deletions src/expr/src/vector_op/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,8 @@ use crate::{ExprError, Result};

/// Formats arguments according to a format string.
// TODO(wrj): prebuild the formatter.
#[function("format(...) -> varchar")]
fn format(row: impl Row, writer: &mut impl Write) -> Result<Option<()>> {
let format_str = match row.datum_at(0) {
Some(format) => format.into_utf8(),
// return null if the format is null
None => return Ok(None),
};
#[function("format(varchar, ...) -> varchar")]
fn format(format_str: &str, row: impl Row, writer: &mut impl Write) -> Result<()> {
let formatter =
Formatter::from_str(format_str).map_err(|e| ExprError::Parse(e.to_string().into()))?;

Expand Down Expand Up @@ -63,7 +58,7 @@ fn format(row: impl Row, writer: &mut impl Write) -> Result<Option<()>> {
}
}
}
Ok(Some(()))
Ok(())
}

/// The type of format conversion to use to produce the format specifier's output.
Expand Down

0 comments on commit 01c9519

Please sign in to comment.