Skip to content

Commit

Permalink
feat: support IS [NOT] NULL in queries
Browse files Browse the repository at this point in the history
Signed-off-by: Jim Crossley <jim@crossleys.org>
  • Loading branch information
jcrossley3 authored and Bob McWhirter committed Jun 18, 2024
1 parent 9babf65 commit 5b658cb
Showing 1 changed file with 66 additions and 51 deletions.
117 changes: 66 additions & 51 deletions common/src/db/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use sea_orm::{
sea_query, ColumnTrait, ColumnType, Condition, EntityTrait, IntoIdentity, IntoSimpleExpr,
Iterable, Order, QueryFilter, QueryOrder, Select, Value,
};
use sea_query::{BinOper, ColumnRef, Expr, IntoColumnRef, SimpleExpr};
use sea_query::{BinOper, ColumnRef, Expr, IntoColumnRef, Keyword, SimpleExpr};
use std::fmt::Display;
use std::str::FromStr;
use std::sync::OnceLock;
Expand Down Expand Up @@ -260,7 +260,7 @@ impl Filter {
caps["value"]
.split('|')
.map(decode)
.map(|s| envalue(&s, col_def.get_column_type()).map(|v| (s, v)))
.map(|s| Arg::parse(&s, col_def.get_column_type()).map(|v| (s, v)))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flat_map(|(s, v)| match columns.translate(field, op, &s) {
Expand All @@ -286,9 +286,7 @@ impl Filter {
ColumnType::String(_) | ColumnType::Text => Some(Filter {
operands: Operand::Simple(
col_ref.clone(),
ValueOrSimpleExpr::Value(Value::String(Some(
decode(s).into(),
))),
Arg::Value(Value::String(Some(decode(s).into()))),
),
operator: Operator::Like,
}),
Expand Down Expand Up @@ -348,10 +346,14 @@ impl IntoCondition for Filter {
fn into_condition(self) -> Condition {
match self.operands {
Operand::Simple(col, v) => match self.operator {
Operator::Equal => Expr::col(col).binary(BinOper::Equal, v.into_simple_expr()),
Operator::NotEqual => {
Expr::col(col).binary(BinOper::NotEqual, v.into_simple_expr())
}
Operator::Equal => match v {
Arg::Null => Expr::col(col).is_null(),
v => Expr::col(col).binary(BinOper::Equal, v.into_simple_expr()),
},
Operator::NotEqual => match v {
Arg::Null => Expr::col(col).is_not_null(),
v => Expr::col(col).binary(BinOper::NotEqual, v.into_simple_expr()),
},
Operator::GreaterThan => {
Expr::col(col).binary(BinOper::GreaterThan, v.into_simple_expr())
}
Expand All @@ -365,7 +367,7 @@ impl IntoCondition for Filter {
Expr::col(col).binary(BinOper::SmallerThanOrEqual, v.into_simple_expr())
}
op @ (Operator::Like | Operator::NotLike) => {
if let ValueOrSimpleExpr::Value(v) = v {
if let Arg::Value(v) = v {
let v = format!(
"%{}%",
v.unwrap::<String>().replace('%', r"\%").replace('_', r"\_")
Expand Down Expand Up @@ -425,23 +427,64 @@ impl FromStr for Operator {
/////////////////////////////////////////////////////////////////////////

#[derive(Debug)]
enum ValueOrSimpleExpr {
enum Arg {
Value(Value),
SimpleExpr(SimpleExpr),
Null,
}

impl IntoSimpleExpr for ValueOrSimpleExpr {
impl IntoSimpleExpr for Arg {
fn into_simple_expr(self) -> SimpleExpr {
match self {
ValueOrSimpleExpr::Value(inner) => SimpleExpr::Value(inner),
ValueOrSimpleExpr::SimpleExpr(inner) => inner,
Arg::Value(inner) => SimpleExpr::Value(inner),
Arg::SimpleExpr(inner) => inner,
Arg::Null => SimpleExpr::Keyword(Keyword::Null),
}
}
}

impl Arg {
fn parse(s: &str, ct: &ColumnType) -> Result<Self, Error> {
fn err(e: impl Display) -> Error {
Error::SearchSyntax(format!(r#"conversion error: "{e}""#))
}
if s.eq_ignore_ascii_case("null") {
return Ok(Arg::Null);
}
Ok(match ct {
ColumnType::Integer => Arg::Value(Value::from(s.parse::<i32>().map_err(err)?)),
ColumnType::Decimal(_) | ColumnType::Float | ColumnType::Double => {
Arg::Value(Value::from(s.parse::<f64>().map_err(err)?))
}
ColumnType::Enum { name, .. } => Arg::SimpleExpr(SimpleExpr::AsEnum(
name.clone(),
Box::new(SimpleExpr::Value(Value::String(Some(Box::new(
s.to_owned(),
))))),
)),
ColumnType::TimestampWithTimeZone => {
if let Ok(odt) = OffsetDateTime::parse(s, &Rfc3339) {
Arg::Value(Value::from(odt))
} else if let Ok(d) = Date::parse(s, &format_description!("[year]-[month]-[day]")) {
Arg::Value(Value::from(d))
} else if let Ok(human) = from_human_time(s) {
match human {
ParseResult::DateTime(dt) => Arg::Value(Value::from(dt)),
ParseResult::Date(d) => Arg::Value(Value::from(d)),
ParseResult::Time(t) => Arg::Value(Value::from(t)),
}
} else {
Arg::Value(Value::from(s))
}
}
_ => Arg::Value(Value::from(s)),
})
}
}

#[derive(Debug)]
enum Operand {
Simple(ColumnRef, ValueOrSimpleExpr),
Simple(ColumnRef, Arg),
Composite(Vec<Filter>),
}

Expand All @@ -467,42 +510,6 @@ fn decode(s: &str) -> String {
s.replace('\x07', "&").replace('\x08', "|")
}

fn envalue(s: &str, ct: &ColumnType) -> Result<ValueOrSimpleExpr, Error> {
fn err(e: impl Display) -> Error {
Error::SearchSyntax(format!(r#"conversion error: "{e}""#))
}
Ok(match ct {
ColumnType::Integer => {
ValueOrSimpleExpr::Value(Value::from(s.parse::<i32>().map_err(err)?))
}
ColumnType::Decimal(_) | ColumnType::Float | ColumnType::Double => {
ValueOrSimpleExpr::Value(Value::from(s.parse::<f64>().map_err(err)?))
}
ColumnType::Enum { name, .. } => ValueOrSimpleExpr::SimpleExpr(SimpleExpr::AsEnum(
name.clone(),
Box::new(SimpleExpr::Value(Value::String(Some(Box::new(
s.to_owned(),
))))),
)),
ColumnType::TimestampWithTimeZone => {
if let Ok(odt) = OffsetDateTime::parse(s, &Rfc3339) {
ValueOrSimpleExpr::Value(Value::from(odt))
} else if let Ok(d) = Date::parse(s, &format_description!("[year]-[month]-[day]")) {
ValueOrSimpleExpr::Value(Value::from(d))
} else if let Ok(human) = from_human_time(s) {
match human {
ParseResult::DateTime(dt) => ValueOrSimpleExpr::Value(Value::from(dt)),
ParseResult::Date(d) => ValueOrSimpleExpr::Value(Value::from(d)),
ParseResult::Time(t) => ValueOrSimpleExpr::Value(Value::from(t)),
}
} else {
ValueOrSimpleExpr::Value(Value::from(s))
}
}
_ => ValueOrSimpleExpr::Value(Value::from(s)),
})
}

/////////////////////////////////////////////////////////////////////////
// Tests
/////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -730,6 +737,14 @@ mod tests {
where_clause("published>2023-11-03")?,
r#""advisory"."published" > '2023-11-03'"#
);
assert_eq!(
where_clause("published=null")?,
r#""advisory"."published" IS NULL"#
);
assert_eq!(
where_clause("published!=NULL")?,
r#""advisory"."published" IS NOT NULL"#
);

Ok(())
}
Expand Down

0 comments on commit 5b658cb

Please sign in to comment.