diff --git a/src/query.rs b/src/query.rs index 3189b2f..926fdb5 100644 --- a/src/query.rs +++ b/src/query.rs @@ -45,7 +45,7 @@ impl Query { /// during query execution (`execute()`, `fetch()` etc). /// /// WARNING: This means that the query must not have any extra `?`, even if - /// they are in a string literal! + /// they are in a string literal! Use `??` to have plain `?` in query. /// /// [`Serialize`]: serde::Serialize /// [`Identifier`]: crate::sql::Identifier diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 6a8e07b..c2f8803 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -21,6 +21,7 @@ pub(crate) enum SqlBuilder { pub(crate) enum Part { Arg, Fields, + Str(&'static str), Text(String), } @@ -45,20 +46,28 @@ impl fmt::Display for SqlBuilder { impl SqlBuilder { pub(crate) fn new(template: &str) -> Self { - let mut iter = template.split('?'); - let prefix = String::from(iter.next().unwrap()); - let mut parts = vec![Part::Text(prefix)]; + let mut parts = Vec::new(); + let mut rest = template; + while let Some(idx) = rest.find('?') { + if rest[idx + 1..].starts_with('?') { + parts.push(Part::Text(rest[..idx + 1].to_string())); + rest = &rest[idx + 2..]; + continue; + } else if idx != 0 { + parts.push(Part::Text(rest[..idx].to_string())); + } - for s in iter { - let text = if let Some(text) = s.strip_prefix("fields") { + rest = &rest[idx + 1..]; + if let Some(restfields) = rest.strip_prefix("fields") { parts.push(Part::Fields); - text + rest = restfields; } else { parts.push(Part::Arg); - s - }; + } + } - parts.push(Part::Text(text.into())); + if !rest.is_empty() { + parts.push(Part::Text(rest.to_string())); } SqlBuilder::InProgress(parts) @@ -96,16 +105,12 @@ impl SqlBuilder { } } - pub(crate) fn append(&mut self, suffix: &str) { + pub(crate) fn append(&mut self, suffix: &'static str) { let Self::InProgress(parts) = self else { return; }; - if let Some(Part::Text(text)) = parts.last_mut() { - text.push_str(suffix); - } else { - // Do nothing, it will fail in `finish()`. - } + parts.push(Part::Str(suffix)); } pub(crate) fn finish(mut self) -> Result { @@ -114,6 +119,7 @@ impl SqlBuilder { if let Self::InProgress(parts) = &self { for part in parts { match part { + Part::Str(text) => sql.push_str(text), Part::Text(text) => sql.push_str(text), Part::Arg => { self.error("unbound query argument"); @@ -223,6 +229,15 @@ mod tests { ); } + #[test] + fn question_escape() { + let sql = SqlBuilder::new("SELECT 1 FROM test WHERE a IN 'a??b'"); + assert_eq!( + sql.finish().unwrap(), + r"SELECT 1 FROM test WHERE a IN 'a?b'" + ); + } + #[test] fn option_as_null() { let mut sql = SqlBuilder::new("SELECT 1 FROM test WHERE a = ?");