Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: scalar udf function argument handling #2304

Merged
merged 13 commits into from
Dec 28, 2023
35 changes: 34 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/sqlbuiltins/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ bson = "2.7.0"
tokio-util = "0.7.10"
bytes = "1.5.0"
kdl = "5.0.0-alpha.1"
memoize = { version = "0.4.2", features = ["full"] }
53 changes: 51 additions & 2 deletions crates/sqlbuiltins/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,56 @@
#[derive(Debug, thiserror::Error)]
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::error::ArrowError;
use datafusion::error::DataFusionError;

#[derive(Clone, Debug, thiserror::Error)]
pub enum BuiltinError {
#[error("parse error: {0}")]
ParseError(String),

#[error("fundamental parsing error")]
FundamentalError,

#[error("missing value at index {0}")]
MissingValueAtIndex(usize),

#[error("expected value missing")]
MissingValue,

#[error("invalid value: {0}")]
InvalidValue(String),

#[error("columnar values not support at index {0}")]
InvalidColumnarValue(usize),

#[error("value was type {0}, expected {1}")]
IncorrectType(DataType, DataType),

#[error(transparent)]
DatafusionExtError(#[from] datafusion_ext::errors::ExtensionError),
KdlError(#[from] kdl::KdlError),

#[error("DataFusionError: {0}")]
DataFusionError(String),

#[error("ArrowError: {0}")]
ArrowError(String),
}

pub type Result<T, E = BuiltinError> = std::result::Result<T, E>;

impl From<BuiltinError> for DataFusionError {
fn from(e: BuiltinError) -> Self {
DataFusionError::Execution(e.to_string())
}
}

impl From<DataFusionError> for BuiltinError {
fn from(e: DataFusionError) -> Self {
BuiltinError::DataFusionError(e.to_string())
}
}

impl From<ArrowError> for BuiltinError {
fn from(e: ArrowError) -> Self {
BuiltinError::ArrowError(e.to_string())
}
}
140 changes: 55 additions & 85 deletions crates/sqlbuiltins/src/functions/scalars/kdl.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::*;
use ::kdl::{KdlDocument, KdlNode, KdlQuery};
use memoize::memoize;

use super::*;

#[derive(Clone)]
pub struct KDLSelect;

impl ConstBuiltinFunction for KDLSelect {
Expand Down Expand Up @@ -31,27 +32,35 @@ impl BuiltinScalarUDF for KDLSelect {
signature: ConstBuiltinFunction::signature(self).unwrap(),
return_type: Arc::new(|_| Ok(Arc::new(DataType::Utf8))),
fun: Arc::new(move |input| {
let (sdoc, filter) = kdl_parse_udf_args(input)?;

let out: Vec<&KdlNode> = sdoc
.query_all(filter)
.map_err(|e| datafusion::common::DataFusionError::Execution(e.to_string()))
.map(|iter| iter.collect())?;

let mut doc = sdoc.clone();
let elems = doc.nodes_mut();
elems.clear();
for item in &out {
elems.push(item.to_owned().clone())
}

// TODO: consider if we should always return LargeUtf8?
// could end up with truncation (or an error) the document
// is too long and we write the data to a table that is
// established (and mostly) shorter values.
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
doc.to_string(),
))))
let filter = get_nth_string_fn_arg(input, 1)?;

get_nth_string_value(
input,
0,
&|value: String| -> Result<ScalarValue, BuiltinError> {
let sdoc: kdl::KdlDocument =
value.parse().map_err(BuiltinError::KdlError)?;

let out: Vec<&KdlNode> = sdoc
.query_all(compile_kdl_query(filter.clone())?)
.map_err(BuiltinError::KdlError)
.map(|iter| iter.collect())?;

let mut doc = sdoc.clone();
let elems = doc.nodes_mut();
elems.clear();
for item in &out {
elems.push(item.to_owned().clone())
}

// TODO: consider if we should always return LargeUtf8?
// could end up with truncation (or an error) the document
// is too long and we write the data to a table that is
// established (and mostly) shorter values.
Ok(ScalarValue::Utf8(Some(doc.to_string())))
},
)
.map_err(DataFusionError::from)
}),
};
Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new(
Expand All @@ -61,8 +70,8 @@ impl BuiltinScalarUDF for KDLSelect {
}
}

#[derive(Clone)]
pub struct KDLMatches;

impl ConstBuiltinFunction for KDLMatches {
const NAME: &'static str = "kdl_matches";
const DESCRIPTION: &'static str =
Expand All @@ -83,28 +92,31 @@ impl ConstBuiltinFunction for KDLMatches {
))
}
}

impl BuiltinScalarUDF for KDLMatches {
fn as_expr(&self, args: Vec<Expr>) -> Expr {
let udf = ScalarUDF {
name: "kdl_matches".to_string(),
signature: Signature::new(
TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
]),
Volatility::Immutable,
),
name: Self::NAME.to_string(),
signature: ConstBuiltinFunction::signature(self).unwrap(),
return_type: Arc::new(|_| Ok(Arc::new(DataType::Boolean))),
fun: Arc::new(move |input| {
let (doc, filter) = kdl_parse_udf_args(input)?;
let filter = get_nth_string_fn_arg(input, 1)?;

get_nth_string_value(
input,
0,
&|value: String| -> Result<ScalarValue, BuiltinError> {
let doc: kdl::KdlDocument =
value.parse().map_err(BuiltinError::KdlError)?;

Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
doc.query(filter)
.map_err(|e| datafusion::common::DataFusionError::Execution(e.to_string()))
.map(|val| val.is_some())?,
))))
Ok(ScalarValue::Boolean(Some(
doc.query(compile_kdl_query(filter.clone())?)
.map(|v| v.is_some())
.map_err(BuiltinError::KdlError)?,
)))
},
)
.map_err(DataFusionError::from)
}),
};
Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new(
Expand All @@ -114,49 +126,7 @@ impl BuiltinScalarUDF for KDLMatches {
}
}

fn kdl_parse_udf_args(
args: &[ColumnarValue],
) -> datafusion::error::Result<(KdlDocument, KdlQuery)> {
// parse the filter first, because it's probably shorter and
// erroring earlier would be preferable to parsing a large that we
// don't need/want.
let filter: kdl::KdlQuery = match get_nth_scalar_value(args, 1) {
Some(ScalarValue::Utf8(Some(val))) | Some(ScalarValue::LargeUtf8(Some(val))) => {
val.parse().map_err(|err: ::kdl::KdlError| {
datafusion::common::DataFusionError::Execution(err.to_string())
})?
}
Some(val) => {
return Err(datafusion::common::DataFusionError::Execution(format!(
"invalid type for KQL expression {}",
val.data_type(),
)))
}
None => {
return Err(datafusion::common::DataFusionError::Execution(
"missing KQL query".to_string(),
))
}
};

let doc: kdl::KdlDocument = match get_nth_scalar_value(args, 0) {
Some(ScalarValue::Utf8(Some(val))) | Some(ScalarValue::LargeUtf8(Some(val))) => {
val.parse().map_err(|err: ::kdl::KdlError| {
datafusion::common::DataFusionError::Execution(err.to_string())
})?
}
Some(val) => {
return Err(datafusion::common::DataFusionError::Execution(format!(
"invalid type for KDL value {}",
val.data_type(),
)))
}
None => {
return Err(datafusion::common::DataFusionError::Execution(
"invalid field for KDL".to_string(),
))
}
};

Ok((doc, filter))
#[memoize(Capacity: 256)]
fn compile_kdl_query(query: String) -> Result<KdlQuery, BuiltinError> {
query.parse().map_err(BuiltinError::KdlError)
tychoish marked this conversation as resolved.
Show resolved Hide resolved
}
Loading
Loading