diff --git a/examples/proxy_gluesql_example/Cargo.toml b/examples/proxy_gluesql_example/Cargo.toml index 1b30009e8..e6cf7fabf 100644 --- a/examples/proxy_gluesql_example/Cargo.toml +++ b/examples/proxy_gluesql_example/Cargo.toml @@ -15,16 +15,14 @@ futures = { version = "0.3" } async-stream = { version = "0.3" } futures-util = { version = "0.3" } +sqlparser = "0.40" sea-orm = { path = "../../", features = [ "sqlx-all", "proxy", "runtime-async-std-native-tls", "debug-print", ] } -# Since it's newer version (0.14.0) locked the chrono's version to 0.4.23, -# we need to lock it on older version too. -# Related to https://github.com/gluesql/gluesql/pull/1427 -gluesql = { version = "0.13", default-features = false, features = [ +gluesql = { version = "0.15", default-features = false, features = [ "memory-storage", ] } diff --git a/examples/proxy_gluesql_example/src/main.rs b/examples/proxy_gluesql_example/src/main.rs index 616b612fa..17c942d71 100644 --- a/examples/proxy_gluesql_example/src/main.rs +++ b/examples/proxy_gluesql_example/src/main.rs @@ -33,52 +33,95 @@ impl ProxyDatabaseTrait for ProxyDb { let sql = statement.sql.clone(); let mut ret: Vec = vec![]; - for payload in self.mem.lock().unwrap().execute(sql).unwrap().iter() { - match payload { - gluesql::prelude::Payload::Select { labels, rows } => { - for row in rows.iter() { - let mut map = BTreeMap::new(); - for (label, column) in labels.iter().zip(row.iter()) { - map.insert( - label.to_owned(), - match column { - gluesql::prelude::Value::I64(val) => { - sea_orm::Value::BigInt(Some(*val)) - } - gluesql::prelude::Value::Str(val) => { - sea_orm::Value::String(Some(Box::new(val.to_owned()))) - } - _ => unreachable!("Unsupported value: {:?}", column), - }, - ); + async_std::task::block_on(async { + for payload in self.mem.lock().unwrap().execute(sql).await.unwrap().iter() { + match payload { + gluesql::prelude::Payload::Select { labels, rows } => { + for row in rows.iter() { + let mut map = BTreeMap::new(); + for (label, column) in labels.iter().zip(row.iter()) { + map.insert( + label.to_owned(), + match column { + gluesql::prelude::Value::I64(val) => { + sea_orm::Value::BigInt(Some(*val)) + } + gluesql::prelude::Value::Str(val) => { + sea_orm::Value::String(Some(Box::new(val.to_owned()))) + } + _ => unreachable!("Unsupported value: {:?}", column), + }, + ); + } + ret.push(map.into()); } - ret.push(map.into()); } + _ => unreachable!("Unsupported payload: {:?}", payload), } - _ => unreachable!("Unsupported payload: {:?}", payload), } - } + }); Ok(ret) } fn execute(&self, statement: Statement) -> Result { - if let Some(values) = statement.values { + let sql = if let Some(values) = statement.values { // Replace all the '?' with the statement values - let mut new_sql = statement.sql.clone(); - let mark_count = new_sql.matches('?').count(); - for (i, v) in values.0.iter().enumerate() { - if i >= mark_count { - break; + use sqlparser::ast::{Expr, Value}; + use sqlparser::dialect::GenericDialect; + use sqlparser::parser::Parser; + + let dialect = GenericDialect {}; + let mut ast = Parser::parse_sql(&dialect, statement.sql.as_str()).unwrap(); + match &mut ast[0] { + sqlparser::ast::Statement::Insert { + columns, source, .. + } => { + for item in columns.iter_mut() { + item.quote_style = Some('"'); + } + + if let Some(obj) = source { + match &mut *obj.body { + sqlparser::ast::SetExpr::Values(obj) => { + for (mut item, val) in obj.rows[0].iter_mut().zip(values.0.iter()) { + match &mut item { + Expr::Value(item) => { + *item = match val { + sea_orm::Value::String(val) => { + Value::SingleQuotedString(match val { + Some(val) => val.to_string(), + None => "".to_string(), + }) + } + sea_orm::Value::BigInt(val) => Value::Number( + val.unwrap_or(0).to_string(), + false, + ), + _ => todo!(), + }; + } + _ => todo!(), + } + } + } + _ => todo!(), + } + } } - new_sql = new_sql.replacen('?', &v.to_string(), 1); + _ => todo!(), } - println!("SQL execute: {}", new_sql); - self.mem.lock().unwrap().execute(new_sql).unwrap(); + let statement = &ast[0]; + statement.to_string() } else { - self.mem.lock().unwrap().execute(statement.sql).unwrap(); - } + statement.sql + }; + + println!("SQL execute: {}", sql); + async_std::task::block_on(async { + self.mem.lock().unwrap().execute(sql).await.unwrap(); + }); Ok(ProxyExecResult { last_insert_id: 1, @@ -101,6 +144,7 @@ async fn main() { ) "#, ) + .await .unwrap(); let db = Database::connect_proxy( diff --git a/examples/proxy_surrealdb_example/Cargo.toml b/examples/proxy_surrealdb_example/Cargo.toml index 9dc5b5456..af4f75683 100644 --- a/examples/proxy_surrealdb_example/Cargo.toml +++ b/examples/proxy_surrealdb_example/Cargo.toml @@ -15,6 +15,7 @@ futures = { version = "0.3" } async-stream = { version = "0.3" } futures-util = { version = "0.3" } +sqlparser = "0.40" sea-orm = { path = "../../", features = [ "sqlx-all", "proxy", diff --git a/examples/proxy_surrealdb_example/src/main.rs b/examples/proxy_surrealdb_example/src/main.rs index 90610c18b..b23f33c5f 100644 --- a/examples/proxy_surrealdb_example/src/main.rs +++ b/examples/proxy_surrealdb_example/src/main.rs @@ -28,21 +28,58 @@ struct ProxyDb { impl ProxyDatabaseTrait for ProxyDb { fn query(&self, statement: Statement) -> Result, DbErr> { println!("SQL query: {:?}", statement); - let sql = statement.sql.clone(); let mut ret = async_std::task::block_on(async { // Surrealdb's grammar is not compatible with sea-orm's // so we need to remove the extra clauses // from "SELECT `from`.`col` FROM `from` WHERE `from`.`col` = xx" // to "SELECT `col` FROM `from` WHERE `col` = xx" - // Get the first index of "FROM" - let from_index = sql.find("FROM").unwrap(); - // Get the name after "FROM" - let from_name = sql[from_index + 5..].split(' ').next().unwrap(); - // Delete the name before all the columns - let new_sql = sql.replace(&format!("{}.", from_name), ""); + use sqlparser::ast::{Expr, SelectItem, SetExpr, TableFactor}; + use sqlparser::dialect::GenericDialect; + use sqlparser::parser::Parser; - self.mem.query(new_sql).await + let dialect = GenericDialect {}; + let mut ast = Parser::parse_sql(&dialect, statement.sql.as_str()).unwrap(); + match &mut ast[0] { + sqlparser::ast::Statement::Query(query) => match &mut *query.body { + SetExpr::Select(body) => { + body.projection.iter_mut().for_each(|item| { + match item { + SelectItem::UnnamedExpr(expr) => { + match expr { + Expr::CompoundIdentifier(idents) => { + // Remove the head of the identifier + // e.g. `from`.`col` -> `col` + let ident = idents.pop().unwrap(); + *expr = Expr::Identifier(ident); + } + _ => todo!(), + } + } + _ => todo!(), + } + }); + body.from.iter_mut().for_each(|item| { + match &mut item.relation { + TableFactor::Table { name, .. } => { + // Remove the head of the identifier + // e.g. `from`.`col` -> `col` + let ident = name.0.pop().unwrap(); + name.0 = vec![ident]; + } + _ => todo!(), + } + }); + } + _ => todo!(), + }, + _ => todo!(), + }; + + let statement = &ast[0]; + let sql = statement.to_string(); + println!("SQL: {}", sql); + self.mem.query(sql).await }) .unwrap(); @@ -116,17 +153,65 @@ impl ProxyDatabaseTrait for ProxyDb { async_std::task::block_on(async { if let Some(values) = statement.values { // Replace all the '?' with the statement values - let mut new_sql = statement.sql.clone(); - let mark_count = new_sql.matches('?').count(); - for (i, v) in values.0.iter().enumerate() { - if i >= mark_count { - break; + use sqlparser::ast::{Expr, Value}; + use sqlparser::dialect::GenericDialect; + use sqlparser::parser::Parser; + + let dialect = GenericDialect {}; + let mut ast = Parser::parse_sql(&dialect, statement.sql.as_str()).unwrap(); + match &mut ast[0] { + sqlparser::ast::Statement::Insert { + table_name, + columns, + source, + .. + } => { + // Replace the table name's quote style + table_name.0[0].quote_style = Some('`'); + + // Replace all the column names' quote style + for item in columns.iter_mut() { + item.quote_style = Some('`'); + } + + // Convert the values to sea-orm's format + if let Some(obj) = source { + match &mut *obj.body { + sqlparser::ast::SetExpr::Values(obj) => { + for (mut item, val) in + obj.rows[0].iter_mut().zip(values.0.iter()) + { + match &mut item { + Expr::Value(item) => { + *item = match val { + sea_orm::Value::String(val) => { + Value::SingleQuotedString(match val { + Some(val) => val.to_string(), + None => "".to_string(), + }) + } + sea_orm::Value::BigInt(val) => Value::Number( + val.unwrap_or(0).to_string(), + false, + ), + _ => todo!(), + }; + } + _ => todo!(), + } + } + } + _ => todo!(), + } + } } - new_sql = new_sql.replacen('?', &v.to_string(), 1); + _ => todo!(), } - println!("SQL execute: {}", new_sql); - self.mem.query(new_sql).await + let statement = &ast[0]; + let sql = statement.to_string(); + println!("SQL: {}", sql); + self.mem.query(sql).await } else { self.mem.query(statement.sql).await }