From fb43250b9ec3b7cb7cd014330de479340c8e328c Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 3 Feb 2023 21:01:30 +0800 Subject: [PATCH 1/9] wrap query language enum Signed-off-by: Ruihang Xia --- src/datanode/src/instance/grpc.rs | 5 +++-- src/datanode/src/instance/sql.rs | 8 +++++--- src/datanode/src/sql.rs | 4 ++-- src/query/src/datafusion.rs | 10 +++++----- src/query/src/parser.rs | 18 +++++++++++++++--- src/query/tests/argmax_test.rs | 4 ++-- src/query/tests/argmin_test.rs | 4 ++-- src/query/tests/function.rs | 4 ++-- src/query/tests/mean_test.rs | 4 ++-- src/query/tests/my_sum_udaf_example.rs | 4 ++-- src/query/tests/percentile_test.rs | 6 +++--- src/query/tests/polyval_test.rs | 4 ++-- src/query/tests/query_engine_test.rs | 9 +++++---- src/query/tests/scipy_stats_norm_cdf_test.rs | 4 ++-- src/query/tests/scipy_stats_norm_pdf.rs | 4 ++-- src/query/tests/time_range_filter_test.rs | 4 +++- src/script/src/python/coprocessor.rs | 7 ++++--- src/script/src/python/engine.rs | 4 ++-- src/script/src/table.rs | 4 ++-- src/servers/tests/mod.rs | 4 ++-- 20 files changed, 67 insertions(+), 48 deletions(-) diff --git a/src/datanode/src/instance/grpc.rs b/src/datanode/src/instance/grpc.rs index 1c5ba0a40a4e..97815bf3b4b9 100644 --- a/src/datanode/src/instance/grpc.rs +++ b/src/datanode/src/instance/grpc.rs @@ -18,7 +18,7 @@ use api::v1::query_request::Query; use api::v1::{CreateDatabaseExpr, DdlRequest, InsertRequest}; use async_trait::async_trait; use common_query::Output; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::plan::LogicalPlan; use servers::query_handler::grpc::GrpcQueryHandler; use session::context::QueryContextRef; @@ -52,7 +52,8 @@ impl Instance { async fn handle_query(&self, query: Query, ctx: QueryContextRef) -> Result { Ok(match query { Query::Sql(sql) => { - let stmt = QueryLanguageParser::parse_sql(&sql).context(ExecuteSqlSnafu)?; + let stmt = + QueryLanguageParser::parse(QueryLanguage::Sql(sql)).context(ExecuteSqlSnafu)?; self.execute_stmt(stmt, ctx).await? 
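Patch 1 threads a new QueryLanguage enum through every parse call site, as in the grpc.rs hunk above, so callers hand the raw query text to a single QueryLanguageParser::parse instead of choosing parse_sql or parse_promql themselves. The following is a minimal, self-contained sketch of that wrapper-and-dispatch shape only; the statement and error types here are plain-String stand-ins, not the crate's sql::Statement, parsed PromQL expression, or snafu-based Result.

// Simplified sketch of the QueryLanguage wrapper pattern (stand-in types).
#[derive(Debug, Clone)]
pub enum QueryLanguage {
    Sql(String),
    Promql(String),
}

#[derive(Debug)]
pub enum QueryStatement {
    Sql(String),    // stand-in for sql::statements::statement::Statement
    Promql(String), // stand-in for the parsed PromQL expression
}

pub struct QueryLanguageParser;

impl QueryLanguageParser {
    /// Single entry point: dispatches on the language carried by the enum.
    pub fn parse(query: QueryLanguage) -> Result<QueryStatement, String> {
        match query {
            QueryLanguage::Sql(sql) => Self::parse_sql(&sql),
            QueryLanguage::Promql(promql) => Self::parse_promql(&promql),
        }
    }

    fn parse_sql(sql: &str) -> Result<QueryStatement, String> {
        if sql.trim().is_empty() {
            return Err("empty SQL".to_string());
        }
        Ok(QueryStatement::Sql(sql.to_string()))
    }

    fn parse_promql(promql: &str) -> Result<QueryStatement, String> {
        if promql.trim().is_empty() {
            return Err("empty PromQL".to_string());
        }
        Ok(QueryStatement::Promql(promql.to_string()))
    }
}

fn main() {
    let stmt = QueryLanguageParser::parse(QueryLanguage::Sql("SELECT 1".to_string()));
    println!("{stmt:?}");
}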
} Query::LogicalPlan(plan) => self.execute_logical(plan).await?, diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index d7862f876201..81c566f84cd3 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -18,7 +18,7 @@ use common_query::Output; use common_recordbatch::RecordBatches; use common_telemetry::logging::info; use common_telemetry::timer; -use query::parser::{QueryLanguageParser, QueryStatement}; +use query::parser::{QueryLanguage, QueryLanguageParser, QueryStatement}; use servers::error as server_error; use servers::promql::PromqlHandler; use servers::query_handler::sql::SqlQueryHandler; @@ -160,12 +160,14 @@ impl Instance { } pub async fn execute_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Result { - let stmt = QueryLanguageParser::parse_sql(sql).context(ExecuteSqlSnafu)?; + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql.to_owned())) + .context(ExecuteSqlSnafu)?; self.execute_stmt(stmt, query_ctx).await } pub async fn execute_promql(&self, sql: &str, query_ctx: QueryContextRef) -> Result { - let stmt = QueryLanguageParser::parse_promql(sql).context(ExecuteSqlSnafu)?; + let stmt = QueryLanguageParser::parse(QueryLanguage::Promql(sql.to_owned())) + .context(ExecuteSqlSnafu)?; self.execute_stmt(stmt, query_ctx).await } } diff --git a/src/datanode/src/sql.rs b/src/datanode/src/sql.rs index e5d8ad74d9b7..6a74ecec211b 100644 --- a/src/datanode/src/sql.rs +++ b/src/datanode/src/sql.rs @@ -143,7 +143,7 @@ mod tests { use mito::engine::MitoEngine; use object_store::services::fs::Builder; use object_store::ObjectStore; - use query::parser::{QueryLanguageParser, QueryStatement}; + use query::parser::{QueryLanguage, QueryLanguageParser, QueryStatement}; use query::QueryEngineFactory; use sql::statements::statement::Statement; use storage::config::EngineConfig as StorageEngineConfig; @@ -241,7 +241,7 @@ mod tests { let query_engine = factory.query_engine(); let sql_handler = SqlHandler::new(table_engine, catalog_list.clone(), query_engine.clone()); - let stmt = match QueryLanguageParser::parse_sql(sql).unwrap() { + let stmt = match QueryLanguageParser::parse(QueryLanguage::Sql(sql.to_owned())).unwrap() { QueryStatement::Sql(Statement::Insert(i)) => i, _ => { unreachable!() diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index f101a801dd9f..d92f8af213cf 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -277,7 +277,7 @@ mod tests { use session::context::QueryContext; use table::table::numbers::NumbersTable; - use crate::parser::QueryLanguageParser; + use crate::parser::{QueryLanguage, QueryLanguageParser}; use crate::query_engine::{QueryEngineFactory, QueryEngineRef}; fn create_test_engine() -> QueryEngineRef { @@ -301,9 +301,9 @@ mod tests { #[test] fn test_sql_to_plan() { let engine = create_test_engine(); - let sql = "select sum(number) from numbers limit 20"; + let sql = "select sum(number) from numbers limit 20".to_string(); - let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); @@ -321,9 +321,9 @@ mod tests { #[tokio::test] async fn test_execute() { let engine = create_test_engine(); - let sql = "select sum(number) from numbers limit 20"; + let sql = "select sum(number) from numbers limit 20".to_string(); - let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); + let stmt = 
QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); diff --git a/src/query/src/parser.rs b/src/query/src/parser.rs index 2e697aa9f6ab..47069bf2145b 100644 --- a/src/query/src/parser.rs +++ b/src/query/src/parser.rs @@ -27,6 +27,12 @@ use sql::statements::statement::Statement; use crate::error::{MultipleStatementsSnafu, QueryParseSnafu, Result}; use crate::metric::{METRIC_PARSE_PROMQL_ELAPSED, METRIC_PARSE_SQL_ELAPSED}; +#[derive(Debug, Clone)] +pub enum QueryLanguage { + Sql(String), + Promql(String), +} + #[derive(Debug, Clone)] pub enum QueryStatement { Sql(Statement), @@ -36,7 +42,14 @@ pub enum QueryStatement { pub struct QueryLanguageParser {} impl QueryLanguageParser { - pub fn parse_sql(sql: &str) -> Result { + pub fn parse(query: QueryLanguage) -> Result { + match query { + QueryLanguage::Sql(sql) => Self::parse_sql(&sql), + QueryLanguage::Promql(promql) => Self::parse_promql(&promql), + } + } + + fn parse_sql(sql: &str) -> Result { let _timer = timer!(METRIC_PARSE_SQL_ELAPSED); let mut statement = ParserContext::create_with_dialect(sql, &GenericDialect {}) .map_err(BoxedError::new) @@ -53,8 +66,7 @@ impl QueryLanguageParser { } } - // TODO(ruihang): implement this method when parser is ready. - pub fn parse_promql(promql: &str) -> Result { + fn parse_promql(promql: &str) -> Result { let _timer = timer!(METRIC_PARSE_PROMQL_ELAPSED); let prom_expr = promql_parser::parser::parse(promql) diff --git a/src/query/tests/argmax_test.rs b/src/query/tests/argmax_test.rs index a2b45cb49bf0..bf220b392eff 100644 --- a/src/query/tests/argmax_test.rs +++ b/src/query/tests/argmax_test.rs @@ -23,7 +23,7 @@ use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::types::WrapperType; use query::error::Result; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::QueryEngine; use session::context::QueryContext; @@ -84,7 +84,7 @@ async fn execute_argmax<'a>( engine: Arc, ) -> RecordResult> { let sql = format!("select ARGMAX({column_name}) as argmax from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); diff --git a/src/query/tests/argmin_test.rs b/src/query/tests/argmin_test.rs index 9ea9066cbca8..2e6628cc5f51 100644 --- a/src/query/tests/argmin_test.rs +++ b/src/query/tests/argmin_test.rs @@ -23,7 +23,7 @@ use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::types::WrapperType; use query::error::Result; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::QueryEngine; use session::context::QueryContext; @@ -84,7 +84,7 @@ async fn execute_argmin<'a>( engine: Arc, ) -> RecordResult> { let sql = format!("select argmin({column_name}) as argmin from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); diff --git a/src/query/tests/function.rs b/src/query/tests/function.rs index bebfad49ae69..cbb6fed2ae66 100644 --- a/src/query/tests/function.rs +++ b/src/query/tests/function.rs @@ -26,7 +26,7 @@ use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use 
datatypes::types::WrapperType; use datatypes::vectors::Helper; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::query_engine::QueryEngineFactory; use query::QueryEngine; use rand::Rng; @@ -83,7 +83,7 @@ where T: WrapperType, { let sql = format!("SELECT {column_name} FROM {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); diff --git a/src/query/tests/mean_test.rs b/src/query/tests/mean_test.rs index fdc682517a0f..7546833cbdc1 100644 --- a/src/query/tests/mean_test.rs +++ b/src/query/tests/mean_test.rs @@ -26,7 +26,7 @@ use datatypes::value::OrderedFloat; use format_num::NumberFormat; use num_traits::AsPrimitive; use query::error::Result; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::QueryEngine; use session::context::QueryContext; @@ -80,7 +80,7 @@ async fn execute_mean<'a>( engine: Arc, ) -> RecordResult> { let sql = format!("select MEAN({column_name}) as mean from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); diff --git a/src/query/tests/my_sum_udaf_example.rs b/src/query/tests/my_sum_udaf_example.rs index 5e76fd6f1d09..868a5e9e30c0 100644 --- a/src/query/tests/my_sum_udaf_example.rs +++ b/src/query/tests/my_sum_udaf_example.rs @@ -33,7 +33,7 @@ use datatypes::vectors::Helper; use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use query::error::Result; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::QueryEngineFactory; use session::context::QueryContext; use table::test_util::MemTable; @@ -219,7 +219,7 @@ where ))); let sql = format!("select MY_SUM({column_name}) as my_sum from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); diff --git a/src/query/tests/percentile_test.rs b/src/query/tests/percentile_test.rs index 801aa22dc24a..74198f269fd9 100644 --- a/src/query/tests/percentile_test.rs +++ b/src/query/tests/percentile_test.rs @@ -27,7 +27,7 @@ use datatypes::vectors::Int32Vector; use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::{QueryEngine, QueryEngineFactory}; use session::context::QueryContext; use table::test_util::MemTable; @@ -53,7 +53,7 @@ async fn test_percentile_aggregator() -> Result<()> { async fn test_percentile_correctness() -> Result<()> { let engine = create_correctness_engine(); let sql = String::from("select PERCENTILE(corr_number,88.0) as percentile from corr_numbers"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); @@ -98,7 +98,7 @@ async fn execute_percentile<'a>( engine: Arc, ) -> RecordResult> { let sql = format!("select PERCENTILE({column_name},50.0) as 
percentile from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); diff --git a/src/query/tests/polyval_test.rs b/src/query/tests/polyval_test.rs index acf2c4d23618..0277e480f257 100644 --- a/src/query/tests/polyval_test.rs +++ b/src/query/tests/polyval_test.rs @@ -23,7 +23,7 @@ use datatypes::prelude::*; use datatypes::types::WrapperType; use num_traits::AsPrimitive; use query::error::Result; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::QueryEngine; use session::context::QueryContext; @@ -80,7 +80,7 @@ async fn execute_polyval<'a>( engine: Arc, ) -> RecordResult> { let sql = format!("select POLYVAL({column_name}, 0) as polyval from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); diff --git a/src/query/tests/query_engine_test.rs b/src/query/tests/query_engine_test.rs index d4798056025f..667fa70c7d08 100644 --- a/src/query/tests/query_engine_test.rs +++ b/src/query/tests/query_engine_test.rs @@ -34,7 +34,7 @@ use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::UInt32Vector; use query::error::{QueryExecutionSnafu, Result}; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::plan::LogicalPlan; use query::query_engine::QueryEngineFactory; use session::context::QueryContext; @@ -145,9 +145,10 @@ async fn test_udf() -> Result<()> { engine.register_udf(udf); - let stmt = - QueryLanguageParser::parse_sql("select my_pow(number, number) as p from numbers limit 10") - .unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql( + "select my_pow(number, number) as p from numbers limit 10".to_string(), + )) + .unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); diff --git a/src/query/tests/scipy_stats_norm_cdf_test.rs b/src/query/tests/scipy_stats_norm_cdf_test.rs index 08e01b1a7130..6e03efe3b8b8 100644 --- a/src/query/tests/scipy_stats_norm_cdf_test.rs +++ b/src/query/tests/scipy_stats_norm_cdf_test.rs @@ -22,7 +22,7 @@ use datatypes::for_all_primitive_types; use datatypes::types::WrapperType; use num_traits::AsPrimitive; use query::error::Result; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::QueryEngine; use session::context::QueryContext; use statrs::distribution::{ContinuousCDF, Normal}; @@ -79,7 +79,7 @@ async fn execute_scipy_stats_norm_cdf<'a>( let sql = format!( "select SCIPYSTATSNORMCDF({column_name},2.0) as scipy_stats_norm_cdf from {table_name}", ); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); diff --git a/src/query/tests/scipy_stats_norm_pdf.rs b/src/query/tests/scipy_stats_norm_pdf.rs index 6e8994c4e6ea..bc0ccb1f1d7c 100644 --- a/src/query/tests/scipy_stats_norm_pdf.rs +++ b/src/query/tests/scipy_stats_norm_pdf.rs @@ -22,7 +22,7 @@ use datatypes::for_all_primitive_types; use datatypes::types::WrapperType; use num_traits::AsPrimitive; use query::error::Result; -use 
query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::QueryEngine; use session::context::QueryContext; use statrs::distribution::{Continuous, Normal}; @@ -79,7 +79,7 @@ async fn execute_scipy_stats_norm_pdf<'a>( let sql = format!( "select SCIPYSTATSNORMPDF({column_name},2.0) as scipy_stats_norm_pdf from {table_name}" ); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) .unwrap(); diff --git a/src/query/tests/time_range_filter_test.rs b/src/query/tests/time_range_filter_test.rs index aed523c3404f..aafa18c37753 100644 --- a/src/query/tests/time_range_filter_test.rs +++ b/src/query/tests/time_range_filter_test.rs @@ -26,6 +26,7 @@ use common_time::Timestamp; use datatypes::data_type::ConcreteDataType; use datatypes::schema::{ColumnSchema, Schema, SchemaRef}; use datatypes::vectors::{Int64Vector, TimestampMillisecondVector}; +use query::parser::QueryLanguage; use query::QueryEngineRef; use session::context::QueryContext; use table::metadata::{FilterPushDownType, TableInfoRef}; @@ -126,7 +127,8 @@ struct TimeRangeTester { impl TimeRangeTester { async fn check(&self, sql: &str, expect: TimestampRange) { - let stmt = query::parser::QueryLanguageParser::parse_sql(sql).unwrap(); + let stmt = + query::parser::QueryLanguageParser::parse(QueryLanguage::Sql(sql.to_owned())).unwrap(); let _ = self .engine .execute( diff --git a/src/script/src/python/coprocessor.rs b/src/script/src/python/coprocessor.rs index c10e5dc51385..78fa940fc7a8 100644 --- a/src/script/src/python/coprocessor.rs +++ b/src/script/src/python/coprocessor.rs @@ -27,7 +27,7 @@ use datatypes::arrow::compute; use datatypes::data_type::{ConcreteDataType, DataType}; use datatypes::schema::{ColumnSchema, Schema, SchemaRef}; use datatypes::vectors::{Helper, VectorRef}; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::QueryEngine; use rustpython_compiler_core::CodeObject; use rustpython_vm as vm; @@ -290,7 +290,7 @@ fn set_items_in_scope( } /// The coprocessor function accept a python script and a Record Batch: -/// ## What it does +/// # What it does /// 1. it take a python script and a [`RecordBatch`], extract columns and annotation info according to `args` given in decorator in python script /// 2. execute python code and return a vector or a tuple of vector, /// 3. 
the returning vector(s) is assembled into a new [`RecordBatch`] according to `returns` in python decorator and return to caller @@ -369,7 +369,8 @@ impl PyQueryEngine { let query = self.inner.0.upgrade(); let thread_handle = std::thread::spawn(move || -> std::result::Result<_, String> { if let Some(engine) = query { - let stmt = QueryLanguageParser::parse_sql(s.as_str()).map_err(|e| e.to_string())?; + let stmt = + QueryLanguageParser::parse(QueryLanguage::Sql(s)).map_err(|e| e.to_string())?; let plan = engine .statement_to_plan(stmt, Default::default()) .map_err(|e| e.to_string())?; diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs index 7a68a63e6bf8..f6c8f879f3d1 100644 --- a/src/script/src/python/engine.rs +++ b/src/script/src/python/engine.rs @@ -30,7 +30,7 @@ use datafusion_expr::Volatility; use datatypes::schema::{ColumnSchema, SchemaRef}; use datatypes::vectors::VectorRef; use futures::Stream; -use query::parser::{QueryLanguageParser, QueryStatement}; +use query::parser::{QueryLanguage, QueryLanguageParser, QueryStatement}; use query::QueryEngineRef; use session::context::QueryContext; use snafu::{ensure, ResultExt}; @@ -220,7 +220,7 @@ impl Script for PyScript { async fn execute(&self, _ctx: EvalContext) -> Result { if let Some(sql) = &self.copr.deco_args.sql { - let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql.clone())).unwrap(); ensure!( matches!(stmt, QueryStatement::Sql(Statement::Query { .. })), error::UnsupportedSqlSnafu { sql } diff --git a/src/script/src/table.rs b/src/script/src/table.rs index e0d88655719c..8bd03a2a530c 100644 --- a/src/script/src/table.rs +++ b/src/script/src/table.rs @@ -25,7 +25,7 @@ use common_time::util; use datatypes::prelude::{ConcreteDataType, ScalarVector}; use datatypes::schema::{ColumnSchema, Schema, SchemaBuilder}; use datatypes::vectors::{StringVector, TimestampMillisecondVector, Vector, VectorRef}; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::QueryEngineRef; use session::context::QueryContext; use snafu::{ensure, OptionExt, ResultExt}; @@ -145,7 +145,7 @@ impl ScriptsTable { // TODO(dennis): we use sql to find the script, the better way is use a function // such as `find_record_by_primary_key` in table_engine. 
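The coprocessor change earlier in this patch keeps PyQueryEngine's approach of running the blocking parse/plan/execute work on a spawned OS thread and joining it, with every failure flattened to a String so the error stays Send. A stand-alone sketch of that pattern, with stand-in work in place of the real query-engine calls, looks roughly like this:

// Illustrative only: run blocking work on a plain thread and join it,
// mirroring the std::thread::spawn + String-error pattern used above.
use std::thread;

fn evaluate(sql: String) -> Result<String, String> {
    let handle = thread::spawn(move || -> Result<String, String> {
        // stand-in for: parse -> statement_to_plan -> execute
        if sql.trim().is_empty() {
            return Err("empty query".to_string());
        }
        Ok(format!("result of `{sql}`"))
    });
    // Outer error = the worker panicked; inner error = the query itself failed.
    handle
        .join()
        .map_err(|_| "worker thread panicked".to_string())?
}

fn main() {
    println!("{:?}", evaluate("select 1".to_string()));
}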
let sql = format!("select script from {} where name='{}'", self.name(), name); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql)).unwrap(); let plan = self .query_engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 15af6e9706eb..566e8f578d5f 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -20,7 +20,7 @@ use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaPr use catalog::{CatalogList, CatalogProvider, SchemaProvider}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; -use query::parser::QueryLanguageParser; +use query::parser::{QueryLanguage, QueryLanguageParser}; use query::{QueryEngineFactory, QueryEngineRef}; use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; use script::python::{PyEngine, PyScript}; @@ -59,7 +59,7 @@ impl SqlQueryHandler for DummyInstance { type Error = Error; async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { - let stmt = QueryLanguageParser::parse_sql(query).unwrap(); + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(query.to_owned())).unwrap(); let plan = self .query_engine .statement_to_plan(stmt, query_ctx) From b9722ac0e561a4e0ed6ded5490af59e83ce70665 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sat, 4 Feb 2023 11:06:25 +0800 Subject: [PATCH 2/9] chore: use rustflag to express lint configs Signed-off-by: Ruihang Xia --- .cargo/config.toml | 10 ++++++++++ .github/workflows/develop.yml | 2 +- Makefile | 2 +- src/servers/src/query_handler/sql.rs | 5 +++++ 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index 3c32d251c5a5..1dbdc6413b48 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,12 @@ [target.aarch64-unknown-linux-gnu] linker = "aarch64-linux-gnu-gcc" + +[build] +rustflags = [ + # lints + # TODO: use lint configuration in cargo https://github.com/rust-lang/cargo/issues/5034 + "-Wclippy::print_stdout", + "-Wclippy::print_stderr", + # false positive: https://github.com/rust-lang/rust/issues/51443#issuecomment-1374847313 + "-Awhere_clauses_object_safety", +] diff --git a/.github/workflows/develop.yml b/.github/workflows/develop.yml index 3987b8835689..b78e33f3eea5 100644 --- a/.github/workflows/develop.yml +++ b/.github/workflows/develop.yml @@ -182,7 +182,7 @@ jobs: - name: Rust Cache uses: Swatinem/rust-cache@v2 - name: Run cargo clippy - run: cargo clippy --workspace --all-targets -- -D warnings -D clippy::print_stdout -D clippy::print_stderr + run: cargo clippy --workspace --all-targets -- -D warnings coverage: if: github.event.pull_request.draft == false diff --git a/Makefile b/Makefile index 28938acf3238..bd1df816ff36 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,7 @@ check: ## Cargo check all the targets. .PHONY: clippy clippy: ## Check clippy rules. - cargo clippy --workspace --all-targets -- -D warnings -D clippy::print_stdout -D clippy::print_stderr + cargo clippy --workspace --all-targets -- -D warnings .PHONY: fmt-check fmt-check: ## Check code format. 
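The rustflags added in the commit above make clippy::print_stdout and clippy::print_stderr workspace-wide warnings, while CI keeps running clippy with -D warnings, so stray prints still fail the build. Below is a small, hypothetical illustration of the kind of code those lints flag and of the local opt-out; the function names are made up, and the logging alternative mentioned in the comment refers to the common_telemetry facade already used elsewhere in this series.

// Hedged illustration of the newly enabled lints (not code from this repo).
fn report_progress(done: usize, total: usize) {
    // With -Wclippy::print_stdout in RUSTFLAGS, clippy warns on this line;
    // library code is expected to go through the logging facade instead.
    println!("progress: {done}/{total}");
}

#[allow(clippy::print_stdout)]
fn cli_entry_point() {
    // Binaries that intentionally write to stdout can opt out locally.
    println!("greptime datanode started");
}

fn main() {
    report_progress(1, 10);
    cli_entry_point();
}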
diff --git a/src/servers/src/query_handler/sql.rs b/src/servers/src/query_handler/sql.rs index d394e84d645f..a974d74fd7a0 100644 --- a/src/servers/src/query_handler/sql.rs +++ b/src/servers/src/query_handler/sql.rs @@ -52,6 +52,11 @@ pub trait SqlQueryHandler { catalog: &str, schema: &str, ) -> std::result::Result; + + async fn foo(&self) -> std::result::Result<(), Self::Error> { + println!(""); + unimplemented!() + } } pub struct ServerSqlQueryHandlerAdaptor(SqlQueryHandlerRef); From 486aa6797a0fbf55dc1a79052b38f438d5b64989 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Mon, 6 Feb 2023 14:27:34 +0800 Subject: [PATCH 3/9] rename do_statement_query to statement_query and change its parameter from Statement to QueryStatement Add parse_multiple method to QueryLanguageParser Signed-off-by: Ruihang Xia --- src/datanode/src/instance/sql.rs | 7 +- src/frontend/src/error.rs | 10 ++- src/frontend/src/instance.rs | 100 +++++++++++----------- src/frontend/src/instance/distributed.rs | 60 +++++++------ src/frontend/src/instance/standalone.rs | 8 +- src/query/src/parser.rs | 19 +++- src/servers/src/http.rs | 4 +- src/servers/src/interceptor.rs | 40 ++++----- src/servers/src/query_handler/sql.rs | 17 ++-- src/servers/tests/http/influxdb_test.rs | 4 +- src/servers/tests/http/opentsdb_test.rs | 4 +- src/servers/tests/http/prometheus_test.rs | 4 +- src/servers/tests/interceptor.rs | 25 ++++-- src/servers/tests/mod.rs | 4 +- 14 files changed, 170 insertions(+), 136 deletions(-) diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index 81c566f84cd3..91f4c82c216c 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -224,14 +224,13 @@ impl SqlQueryHandler for Instance { vec![result] } - async fn do_statement_query( + async fn statement_query( &self, - stmt: Statement, + stmt: QueryStatement, query_ctx: QueryContextRef, ) -> Result { let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED); - self.execute_stmt(QueryStatement::Sql(stmt), query_ctx) - .await + self.execute_stmt(stmt, query_ctx).await } fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index 7624021e8049..3fe8bb203198 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -56,6 +56,12 @@ pub enum Error { source: sql::error::Error, }, + #[snafu(display("Failed to parse query, source: {}", source))] + ParseQuery { + #[snafu(backtrace)] + source: query::error::Error, + }, + #[snafu(display("Column datatype error, source: {}", source))] ColumnDataType { #[snafu(backtrace)] @@ -407,7 +413,9 @@ impl ErrorExt for Error { | Error::FindNewColumnsOnInsertion { source } => source.status_code(), Error::PrimaryKeyNotFound { .. } => StatusCode::InvalidArguments, - Error::ExecuteStatement { source, .. } => source.status_code(), + Error::ExecuteStatement { source, .. } | Error::ParseQuery { source } => { + source.status_code() + } Error::MissingMetasrvOpts { .. } => StatusCode::InvalidArguments, Error::AlterExprToRequest { source, .. } => source.status_code(), Error::LeaderNotFound { .. 
} => StatusCode::StorageUnavailable, diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 6a6010d15326..85a157230c58 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -41,6 +41,7 @@ use meta_client::client::{MetaClient, MetaClientBuilder}; use meta_client::MetaClientOpts; use partition::manager::PartitionRuleManager; use partition::route::TableRoutes; +use query::parser::{QueryLanguage, QueryLanguageParser, QueryStatement}; use servers::error as server_error; use servers::interceptor::{SqlQueryInterceptor, SqlQueryInterceptorRef}; use servers::promql::{PromqlHandler, PromqlHandlerRef}; @@ -52,8 +53,6 @@ use servers::query_handler::{ }; use session::context::QueryContextRef; use snafu::prelude::*; -use sql::dialect::GenericDialect; -use sql::parser::ParserContext; use sql::statements::statement::Statement; use crate::catalog::FrontendCatalogManager; @@ -365,27 +364,15 @@ impl FrontendInstance for Instance { } } -fn parse_stmt(sql: &str) -> Result> { - ParserContext::create_with_dialect(sql, &GenericDialect {}).context(error::ParseSqlSnafu) -} - impl Instance { - async fn query_statement(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result { + async fn query_statement( + &self, + stmt: QueryStatement, + query_ctx: QueryContextRef, + ) -> Result { // TODO(sunng87): provide a better form to log or track statement - let query = &format!("{:?}", &stmt); match stmt.clone() { - Statement::CreateDatabase(_) - | Statement::ShowDatabases(_) - | Statement::CreateTable(_) - | Statement::ShowTables(_) - | Statement::DescribeTable(_) - | Statement::Explain(_) - | Statement::Query(_) - | Statement::Insert(_) - | Statement::Alter(_) => { - return self.sql_handler.do_statement_query(stmt, query_ctx).await; - } - Statement::DropTable(drop_stmt) => { + QueryStatement::Sql(Statement::DropTable(drop_stmt)) => { let (catalog_name, schema_name, table_name) = table_idents_to_full_name(drop_stmt.table_name(), query_ctx.clone()) .map_err(BoxedError::new) @@ -405,8 +392,12 @@ impl Instance { ) .await; } - Statement::ShowCreateTable(_) => error::NotSupportedSnafu { feat: query }.fail(), - Statement::Use(db) => self.handle_use(db, query_ctx), + QueryStatement::Sql(Statement::ShowCreateTable(_)) => error::NotSupportedSnafu { + feat: format!("{:?}", &stmt), + } + .fail(), + QueryStatement::Sql(Statement::Use(db)) => self.handle_use(db, query_ctx), + _ => self.sql_handler.statement_query(stmt, query_ctx).await, } } } @@ -417,14 +408,20 @@ impl SqlQueryHandler for Instance { async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { let query_interceptor = self.plugins.get::>(); + let query = QueryLanguage::Sql(query.to_string()); let query = match query_interceptor.pre_parsing(query, query_ctx.clone()) { Ok(q) => q, Err(e) => return vec![Err(e)], }; - match parse_stmt(query.as_ref()) - .and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone())) - { + match QueryLanguageParser::parse_multiple(query) + .context(error::ParseQuerySnafu) + .and_then(|stmts| { + stmts + .into_iter() + .map(|stmt| query_interceptor.post_parsing(stmt, query_ctx.clone())) + .collect::>>() + }) { Ok(stmts) => { let mut results = Vec::with_capacity(stmts.len()); for stmt in stmts { @@ -470,9 +467,9 @@ impl SqlQueryHandler for Instance { } } - async fn do_statement_query( + async fn statement_query( &self, - stmt: Statement, + stmt: QueryStatement, query_ctx: QueryContextRef, ) -> Result { let query_interceptor = self.plugins.get::>(); @@ -535,7 +532,6 
@@ impl PromqlHandler for Instance { #[cfg(test)] mod tests { - use std::borrow::Cow; use std::sync::atomic::AtomicU32; use session::context::QueryContext; @@ -651,27 +647,34 @@ mod tests { fn pre_parsing<'a>( &self, - query: &'a str, + query: QueryLanguage, _query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result { self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - assert!(query.starts_with("CREATE TABLE demo")); - Ok(Cow::Borrowed(query)) + if let QueryLanguage::Sql(sql) = &query { + assert!(sql.starts_with("CREATE TABLE demo")); + } else { + panic!("unexpected query language"); + } + Ok(query) } fn post_parsing( &self, - statements: Vec, + statement: QueryStatement, _query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result { self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - assert!(matches!(statements[0], Statement::CreateTable(_))); - Ok(statements) + assert!(matches!( + statement, + QueryStatement::Sql(Statement::CreateTable(_)) + )); + Ok(statement) } fn pre_execute( &self, - _statement: &Statement, + _statement: &QueryStatement, _plan: Option<&query::plan::LogicalPlan>, _query_ctx: QueryContextRef, ) -> Result<()> { @@ -737,21 +740,20 @@ mod tests { fn post_parsing( &self, - statements: Vec, + statement: QueryStatement, _query_ctx: QueryContextRef, - ) -> Result> { - for s in &statements { - match s { - Statement::CreateDatabase(_) | Statement::ShowDatabases(_) => { - return Err(Error::NotSupported { - feat: "Database operations".to_owned(), - }) - } - _ => {} + ) -> Result { + match statement { + QueryStatement::Sql(Statement::CreateDatabase(_)) + | QueryStatement::Sql(Statement::ShowDatabases(_)) => { + return Err(Error::NotSupported { + feat: "Database operations".to_owned(), + }) } + _ => {} } - Ok(statements) + Ok(statement) } } @@ -794,7 +796,7 @@ mod tests { unreachable!(); } - let sql = r#"SELECT 1; SHOW DATABASES"#; + let sql = r#"SHOW DATABASES"#; if let Err(e) = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) .await .remove(0) diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index c7d7f8c13b69..8e2fe525ac1b 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -37,7 +37,7 @@ use meta_client::rpc::{ TableName, TableRoute, }; use partition::partition::{PartitionBound, PartitionDef}; -use query::parser::QueryStatement; +use query::parser::{QueryLanguage, QueryLanguageParser, QueryStatement}; use query::sql::{describe_table, explain, show_databases, show_tables}; use query::{QueryEngineFactory, QueryEngineRef}; use servers::query_handler::sql::SqlQueryHandler; @@ -54,12 +54,11 @@ use crate::catalog::FrontendCatalogManager; use crate::datanode::DatanodeClients; use crate::error::{ self, AlterExprToRequestSnafu, CatalogEntrySerdeSnafu, CatalogNotFoundSnafu, CatalogSnafu, - ColumnDataTypeSnafu, DeserializePartitionSnafu, ParseSqlSnafu, PrimaryKeyNotFoundSnafu, - RequestDatanodeSnafu, RequestMetaSnafu, Result, SchemaNotFoundSnafu, StartMetaClientSnafu, - TableNotFoundSnafu, TableSnafu, ToTableInsertRequestSnafu, + ColumnDataTypeSnafu, DeserializePartitionSnafu, ParseQuerySnafu, ParseSqlSnafu, + PrimaryKeyNotFoundSnafu, RequestDatanodeSnafu, RequestMetaSnafu, Result, SchemaNotFoundSnafu, + StartMetaClientSnafu, TableNotFoundSnafu, TableSnafu, ToTableInsertRequestSnafu, }; use crate::expr_factory::{CreateExprFactory, DefaultCreateExprFactory}; -use crate::instance::parse_stmt; use crate::sql::insert_to_request; #[derive(Clone)] @@ -142,7 +141,7 
@@ impl DistInstance { Ok(Output::AffectedRows(0)) } - async fn handle_statement( + async fn handle_sql_statement( &self, stmt: Statement, query_ctx: QueryContextRef, @@ -206,25 +205,31 @@ impl DistInstance { .context(error::ExecuteStatementSnafu) } - async fn handle_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Vec> { - let stmts = parse_stmt(sql); - match stmts { - Ok(stmts) => { - let mut results = Vec::with_capacity(stmts.len()); - - for stmt in stmts { - let result = self.handle_statement(stmt, query_ctx.clone()).await; - let is_err = result.is_err(); - - results.push(result); - - if is_err { - break; - } - } + async fn handle_promql_statement( + &self, + stmt: QueryStatement, + query_ctx: QueryContextRef, + ) -> Result { + let plan = self + .query_engine + .statement_to_plan(stmt, query_ctx) + .context(error::ExecuteStatementSnafu {})?; + self.query_engine + .execute(&plan) + .await + .context(error::ExecuteStatementSnafu) + } - results + async fn handle_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Vec> { + let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql.to_string())) + .context(ParseQuerySnafu); + match stmt { + Ok(stmt) => { + let result = self.statement_query(stmt, query_ctx.clone()).await; + vec![result] } + + // results Err(e) => vec![Err(e)], } } @@ -403,12 +408,15 @@ impl SqlQueryHandler for DistInstance { unimplemented!() } - async fn do_statement_query( + async fn statement_query( &self, - stmt: Statement, + stmt: QueryStatement, query_ctx: QueryContextRef, ) -> Result { - self.handle_statement(stmt, query_ctx).await + match stmt { + QueryStatement::Sql(stmt) => self.handle_sql_statement(stmt, query_ctx).await, + QueryStatement::Promql(_) => self.handle_promql_statement(stmt, query_ctx).await, + } } fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { diff --git a/src/frontend/src/instance/standalone.rs b/src/frontend/src/instance/standalone.rs index 6138727e949d..e113440ad4a5 100644 --- a/src/frontend/src/instance/standalone.rs +++ b/src/frontend/src/instance/standalone.rs @@ -18,11 +18,11 @@ use api::v1::greptime_request::Request as GreptimeRequest; use async_trait::async_trait; use common_query::Output; use datanode::error::Error as DatanodeError; +use query::parser::QueryStatement; use servers::query_handler::grpc::{GrpcQueryHandler, GrpcQueryHandlerRef}; use servers::query_handler::sql::{SqlQueryHandler, SqlQueryHandlerRef}; use session::context::QueryContextRef; use snafu::ResultExt; -use sql::statements::statement::Statement; use crate::error::{self, Result}; @@ -55,13 +55,13 @@ impl SqlQueryHandler for StandaloneSqlQueryHandler { unimplemented!() } - async fn do_statement_query( + async fn statement_query( &self, - stmt: Statement, + stmt: QueryStatement, query_ctx: QueryContextRef, ) -> Result { self.0 - .do_statement_query(stmt, query_ctx) + .statement_query(stmt, query_ctx) .await .context(error::InvokeDatanodeSnafu) } diff --git a/src/query/src/parser.rs b/src/query/src/parser.rs index 47069bf2145b..05f9148384b9 100644 --- a/src/query/src/parser.rs +++ b/src/query/src/parser.rs @@ -27,7 +27,7 @@ use sql::statements::statement::Statement; use crate::error::{MultipleStatementsSnafu, QueryParseSnafu, Result}; use crate::metric::{METRIC_PARSE_PROMQL_ELAPSED, METRIC_PARSE_SQL_ELAPSED}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum QueryLanguage { Sql(String), Promql(String), @@ -49,6 +49,23 @@ impl QueryLanguageParser { } } + pub fn parse_multiple(query: QueryLanguage) -> Result> { + match query { + 
QueryLanguage::Sql(sql) => Self::parse_multiple_sql(&sql), + QueryLanguage::Promql(promql) => Self::parse_promql(&promql).map(|stmt| vec![stmt]), + } + } + + fn parse_multiple_sql(query: &str) -> Result> { + let _timer = timer!(METRIC_PARSE_SQL_ELAPSED); + let mut statements = ParserContext::create_with_dialect(query, &GenericDialect {}) + .map_err(BoxedError::new) + .context(QueryParseSnafu { + query: query.to_string(), + })?; + Ok(statements.drain(..).map(QueryStatement::Sql).collect()) + } + fn parse_sql(sql: &str) -> Result { let _timer = timer!(METRIC_PARSE_SQL_ELAPSED); let mut statement = ParserContext::create_with_dialect(sql, &GenericDialect {}) diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 9b0eb8b7c277..237f44d9c945 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -590,9 +590,9 @@ mod test { unimplemented!() } - async fn do_statement_query( + async fn statement_query( &self, - _stmt: sql::statements::statement::Statement, + _stmt: query::parser::QueryStatement, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() diff --git a/src/servers/src/interceptor.rs b/src/servers/src/interceptor.rs index fa1cf83862a1..5076ff0333a5 100644 --- a/src/servers/src/interceptor.rs +++ b/src/servers/src/interceptor.rs @@ -12,45 +12,43 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::sync::Arc; use common_error::prelude::ErrorExt; use common_query::Output; +use query::parser::{QueryLanguage, QueryStatement}; use query::plan::LogicalPlan; use session::context::QueryContextRef; -use sql::statements::statement::Statement; /// SqlQueryInterceptor can track life cycle of a sql query and customize or /// abort its execution at given point. pub trait SqlQueryInterceptor { type Error: ErrorExt; - /// Called before a query string is parsed into sql statements. - /// The implementation is allowed to change the sql string if needed. + /// Called before a query is parsed into statement. + /// The implementation is allowed to change the query if needed. fn pre_parsing<'a>( &self, - query: &'a str, + query: QueryLanguage, _query_ctx: QueryContextRef, - ) -> Result, Self::Error> { - Ok(Cow::Borrowed(query)) + ) -> Result { + Ok(query) } - /// Called after sql is parsed into statements. This interceptor is called - /// on each statement and the implementation can alter the statement or - /// abort execution by raising an error. + /// Called after query is parsed into statement. This interceptor can alter + /// the statement or abort execution by raising an error. fn post_parsing( &self, - statements: Vec, + statement: QueryStatement, _query_ctx: QueryContextRef, - ) -> Result, Self::Error> { - Ok(statements) + ) -> Result { + Ok(statement) } /// Called before sql is actually executed. This hook is not called at the moment. 
fn pre_execute( &self, - _statement: &Statement, + _statement: &QueryStatement, _plan: Option<&LogicalPlan>, _query_ctx: QueryContextRef, ) -> Result<(), Self::Error> { @@ -77,23 +75,23 @@ where { type Error = E; - fn pre_parsing<'a>( + fn pre_parsing( &self, - query: &'a str, + query: QueryLanguage, query_ctx: QueryContextRef, - ) -> Result, Self::Error> { + ) -> Result { if let Some(this) = self { this.pre_parsing(query, query_ctx) } else { - Ok(Cow::Borrowed(query)) + Ok(query) } } fn post_parsing( &self, - statements: Vec, + statements: QueryStatement, query_ctx: QueryContextRef, - ) -> Result, Self::Error> { + ) -> Result { if let Some(this) = self { this.post_parsing(statements, query_ctx) } else { @@ -103,7 +101,7 @@ where fn pre_execute( &self, - statement: &Statement, + statement: &QueryStatement, plan: Option<&LogicalPlan>, query_ctx: QueryContextRef, ) -> Result<(), Self::Error> { diff --git a/src/servers/src/query_handler/sql.rs b/src/servers/src/query_handler/sql.rs index a974d74fd7a0..b8a8fe34374a 100644 --- a/src/servers/src/query_handler/sql.rs +++ b/src/servers/src/query_handler/sql.rs @@ -17,8 +17,8 @@ use std::sync::Arc; use async_trait::async_trait; use common_error::prelude::*; use common_query::Output; +use query::parser::QueryStatement; use session::context::QueryContextRef; -use sql::statements::statement::Statement; use crate::error::{self, Result}; @@ -41,9 +41,9 @@ pub trait SqlQueryHandler { query_ctx: QueryContextRef, ) -> Vec>; - async fn do_statement_query( + async fn statement_query( &self, - stmt: Statement, + stmt: QueryStatement, query_ctx: QueryContextRef, ) -> std::result::Result; @@ -52,11 +52,6 @@ pub trait SqlQueryHandler { catalog: &str, schema: &str, ) -> std::result::Result; - - async fn foo(&self) -> std::result::Result<(), Self::Error> { - println!(""); - unimplemented!() - } } pub struct ServerSqlQueryHandlerAdaptor(SqlQueryHandlerRef); @@ -102,13 +97,13 @@ where .collect() } - async fn do_statement_query( + async fn statement_query( &self, - stmt: Statement, + stmt: QueryStatement, query_ctx: QueryContextRef, ) -> Result { self.0 - .do_statement_query(stmt, query_ctx) + .statement_query(stmt, query_ctx) .await .map_err(BoxedError::new) .context(error::ExecuteStatementSnafu) diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 7425ec03c3d3..819bf9e895b7 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -62,9 +62,9 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } - async fn do_statement_query( + async fn statement_query( &self, - _stmt: sql::statements::statement::Statement, + _stmt: query::parser::QueryStatement, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index 70f8c3e07046..986a26c1f950 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -60,9 +60,9 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } - async fn do_statement_query( + async fn statement_query( &self, - _stmt: sql::statements::statement::Statement, + _stmt: query::parser::QueryStatement, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() diff --git a/src/servers/tests/http/prometheus_test.rs b/src/servers/tests/http/prometheus_test.rs index 34157305585e..c924011a1f92 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -85,9 +85,9 @@ 
impl SqlQueryHandler for DummyInstance { unimplemented!() } - async fn do_statement_query( + async fn statement_query( &self, - _stmt: sql::statements::statement::Statement, + _stmt: query::parser::QueryStatement, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() diff --git a/src/servers/tests/interceptor.rs b/src/servers/tests/interceptor.rs index 593fa89207f1..d28558a37d1e 100644 --- a/src/servers/tests/interceptor.rs +++ b/src/servers/tests/interceptor.rs @@ -12,29 +12,36 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::sync::Arc; +use query::parser::QueryLanguage; use servers::error::{self, Result}; use servers::interceptor::SqlQueryInterceptor; use session::context::{QueryContext, QueryContextRef}; -pub struct NoopInterceptor; +pub struct RewriteInterceptor; -impl SqlQueryInterceptor for NoopInterceptor { +impl SqlQueryInterceptor for RewriteInterceptor { type Error = error::Error; - fn pre_parsing<'a>(&self, query: &'a str, _query_ctx: QueryContextRef) -> Result> { - let modified_query = format!("{query};"); - Ok(Cow::Owned(modified_query)) + fn pre_parsing( + &self, + _query: QueryLanguage, + _query_ctx: QueryContextRef, + ) -> Result { + let modified_query = QueryLanguage::Sql("SELECT 1;".to_string()); + Ok(modified_query) } } #[test] fn test_default_interceptor_behaviour() { - let di = NoopInterceptor; + let di = RewriteInterceptor; let ctx = Arc::new(QueryContext::new()); - let query = "SELECT 1"; - assert_eq!("SELECT 1;", di.pre_parsing(query, ctx).unwrap()); + let query = QueryLanguage::Promql("blabla[1m]".to_string()); + assert_eq!( + QueryLanguage::Sql("SELECT 1;".to_string()), + di.pre_parsing(query, ctx).unwrap() + ); } diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 566e8f578d5f..3b30ecaa437c 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -76,9 +76,9 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } - async fn do_statement_query( + async fn statement_query( &self, - _stmt: sql::statements::statement::Statement, + _stmt: query::parser::QueryStatement, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() From fad35f0acc25952f8b428115288c5233b8428f9e Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Mon, 6 Feb 2023 22:09:41 +0800 Subject: [PATCH 4/9] change implementations' error type to servers::Error new trait method query and default implementation Signed-off-by: Ruihang Xia --- src/datanode/src/error.rs | 7 ++ src/datanode/src/instance/sql.rs | 42 ++++++--- src/frontend/src/error.rs | 12 ++- src/frontend/src/instance.rs | 113 +++++++++++++---------- src/frontend/src/instance/distributed.rs | 43 ++++++--- src/frontend/src/instance/grpc.rs | 8 +- src/frontend/src/instance/standalone.rs | 7 +- src/servers/src/error.rs | 16 +++- src/servers/src/query_handler/sql.rs | 15 ++- 9 files changed, 177 insertions(+), 86 deletions(-) diff --git a/src/datanode/src/error.rs b/src/datanode/src/error.rs index 4d40f311c73b..f0d22ef6c63a 100644 --- a/src/datanode/src/error.rs +++ b/src/datanode/src/error.rs @@ -252,6 +252,12 @@ pub enum Error { source: catalog::error::Error, }, + #[snafu(display("Failed to find catalog, source: {}", source))] + FindCatalog { + #[snafu(backtrace)] + source: servers::error::Error, + }, + #[snafu(display("Failed to find table {} from catalog, source: {}", table_name, source))] FindTable { table_name: String, @@ -333,6 +339,7 @@ impl ErrorExt for Error { | Error::GetTable { source, .. 
} | Error::AlterTable { source, .. } => source.status_code(), Error::DropTable { source, .. } => source.status_code(), + Error::FindCatalog { source } => source.status_code(), Error::Insert { source, .. } => source.status_code(), diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index 91f4c82c216c..c72d42fadfdb 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -29,7 +29,9 @@ use sql::statements::statement::Statement; use table::engine::TableReference; use table::requests::{CreateDatabaseRequest, DropTableRequest}; -use crate::error::{self, BumpTableIdSnafu, ExecuteSqlSnafu, Result, TableIdProviderNotFoundSnafu}; +use crate::error::{ + self, BumpTableIdSnafu, ExecuteSqlSnafu, FindCatalogSnafu, Result, TableIdProviderNotFoundSnafu, +}; use crate::instance::Instance; use crate::metric; use crate::sql::SqlRequest; @@ -148,7 +150,8 @@ impl Instance { QueryStatement::Sql(Statement::Use(ref schema)) => { let catalog = &query_ctx.current_catalog(); ensure!( - self.is_valid_schema(catalog, schema)?, + self.is_valid_schema(catalog, schema) + .context(FindCatalogSnafu)?, error::DatabaseNotFoundSnafu { catalog, schema } ); @@ -205,12 +208,20 @@ pub fn table_idents_to_full_name( #[async_trait] impl SqlQueryHandler for Instance { - type Error = error::Error; + type Error = server_error::Error; - async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { + async fn do_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> Vec> { let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED); // we assume sql string has only 1 statement in datanode - let result = self.execute_sql(query, query_ctx).await; + let result = self + .execute_sql(query, query_ctx) + .await + .map_err(BoxedError::new) + .context(server_error::ExecuteQueryStatementSnafu); vec![result] } @@ -218,26 +229,31 @@ impl SqlQueryHandler for Instance { &self, query: &str, query_ctx: QueryContextRef, - ) -> Vec> { - let _timer = timer!(metric::METRIC_HANDLE_PROMQL_ELAPSED); - let result = self.execute_promql(query, query_ctx).await; - vec![result] + ) -> Vec> { + // let _timer = timer!(metric::METRIC_HANDLE_PROMQL_ELAPSED); + // let result = self.execute_promql(query, query_ctx).await; + // vec![result] + todo!() } async fn statement_query( &self, stmt: QueryStatement, query_ctx: QueryContextRef, - ) -> Result { + ) -> server_error::Result { let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED); - self.execute_stmt(stmt, query_ctx).await + self.execute_stmt(stmt, query_ctx) + .await + .map_err(BoxedError::new) + .context(server_error::ExecuteQueryStatementSnafu) } - fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { + fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result { self.catalog_manager .schema(catalog, schema) .map(|s| s.is_some()) - .context(error::CatalogSnafu) + .map_err(BoxedError::new) + .context(server_error::CheckDatabaseValiditySnafu) } } diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index 3fe8bb203198..a9d38c30060c 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -344,12 +344,18 @@ pub enum Error { }, // TODO(ruihang): merge all query execution error kinds - #[snafu(display("failed to execute PromQL query {}, source: {}", query, source))] + #[snafu(display("Failed to execute PromQL query {}, source: {}", query, source))] ExecutePromql { query: String, #[snafu(backtrace)] source: servers::error::Error, }, + + #[snafu(display("Failed to execute query 
statement, source: {}", source))] + ExecuteQueryStatement { + #[snafu(backtrace)] + source: BoxedError, + }, } pub type Result = std::result::Result; @@ -424,7 +430,9 @@ impl ErrorExt for Error { Error::InvokeDatanode { source } => source.status_code(), Error::ColumnDefaultValue { source, .. } => source.status_code(), Error::ColumnNoneDefaultValue { .. } => StatusCode::InvalidArguments, - Error::External { source } => source.status_code(), + Error::External { source } | Error::ExecuteQueryStatement { source } => { + source.status_code() + } Error::DeserializePartition { source, .. } | Error::FindTableRoute { source, .. } => { source.status_code() } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 85a157230c58..130c2f64af40 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -57,18 +57,16 @@ use sql::statements::statement::Statement; use crate::catalog::FrontendCatalogManager; use crate::datanode::DatanodeClients; -use crate::error::{ - self, Error, ExecutePromqlSnafu, MissingMetasrvOptsSnafu, NotSupportedSnafu, Result, -}; +use crate::error::{self, Error, MissingMetasrvOptsSnafu, Result}; use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory}; use crate::frontend::FrontendOptions; -use crate::instance::standalone::{StandaloneGrpcQueryHandler, StandaloneSqlQueryHandler}; +use crate::instance::standalone::StandaloneGrpcQueryHandler; use crate::Plugins; #[async_trait] pub trait FrontendInstance: GrpcQueryHandler - + SqlQueryHandler + + SqlQueryHandler + OpentsdbProtocolHandler + InfluxdbLineProtocolHandler + PrometheusProtocolHandler @@ -89,7 +87,7 @@ pub struct Instance { /// Script handler is None in distributed mode, only works on standalone mode. script_handler: Option, - sql_handler: SqlQueryHandlerRef, + sql_handler: SqlQueryHandlerRef, grpc_query_handler: GrpcQueryHandlerRef, promql_handler: Option, @@ -167,7 +165,7 @@ impl Instance { catalog_manager: dn_instance.catalog_manager().clone(), script_handler: None, create_expr_factory: Arc::new(DefaultCreateExprFactory), - sql_handler: StandaloneSqlQueryHandler::arc(dn_instance.clone()), + sql_handler: dn_instance.clone(), grpc_query_handler: StandaloneGrpcQueryHandler::arc(dn_instance.clone()), promql_handler: Some(dn_instance.clone()), plugins: Default::default(), @@ -397,19 +395,32 @@ impl Instance { } .fail(), QueryStatement::Sql(Statement::Use(db)) => self.handle_use(db, query_ctx), - _ => self.sql_handler.statement_query(stmt, query_ctx).await, + _ => self + .sql_handler + .statement_query(stmt, query_ctx) + .await + .map_err(BoxedError::new) + .context(error::ExecuteQueryStatementSnafu), } } } #[async_trait] impl SqlQueryHandler for Instance { - type Error = Error; + type Error = server_error::Error; - async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { + async fn do_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> Vec> { let query_interceptor = self.plugins.get::>(); let query = QueryLanguage::Sql(query.to_string()); - let query = match query_interceptor.pre_parsing(query, query_ctx.clone()) { + let query = match query_interceptor + .pre_parsing(query, query_ctx.clone()) + .map_err(BoxedError::new) + .context(server_error::ParseQuerySnafu) + { Ok(q) => q, Err(e) => return vec![Err(e)], }; @@ -425,69 +436,71 @@ impl SqlQueryHandler for Instance { Ok(stmts) => { let mut results = Vec::with_capacity(stmts.len()); for stmt in stmts { - // TODO(sunng87): figure out at which stage we can call - // this hook after 
ArrowFlight adoption. We need to provide - // LogicalPlan as to this hook. - if let Err(e) = query_interceptor.pre_execute(&stmt, None, query_ctx.clone()) { - results.push(Err(e)); - break; - } - match self.query_statement(stmt, query_ctx.clone()).await { - Ok(output) => { - let output_result = - query_interceptor.post_execute(output, query_ctx.clone()); - results.push(output_result); - } - Err(e) => { - results.push(Err(e)); - break; - } - } + let result = self.statement_query(stmt, query_ctx.clone()).await; + results.push(result); } results } Err(e) => { - vec![Err(e)] + vec![Err(e) + .map_err(BoxedError::new) + .context(server_error::ParseQuerySnafu)] } } } - async fn do_promql_query(&self, query: &str, _: QueryContextRef) -> Vec> { - if let Some(handler) = &self.promql_handler { - let result = handler - .do_query(query) - .await - .context(ExecutePromqlSnafu { query }); - vec![result] - } else { - vec![Err(NotSupportedSnafu { - feat: "PromQL Query", - } - .build())] - } + async fn do_promql_query( + &self, + query: &str, + _: QueryContextRef, + ) -> Vec> { + // if let Some(handler) = &self.promql_handler { + // let result = handler + // .do_query(query) + // .await + // .context(ExecutePromqlSnafu { query }); + // vec![result] + // } else { + // vec![Err(NotSupportedSnafu { + // feat: "PromQL Query", + // } + // .build())] + // } + todo!() + // let query = QueryLanguage::Promql(query.to_string()); + + // let statement = QueryLanguageParser::parse(QueryLanguage::Promql(query.to_string())) + // .map_err(BoxedError::new) + // .context(server_error::ParseQuerySnafu)?; } async fn statement_query( &self, stmt: QueryStatement, query_ctx: QueryContextRef, - ) -> Result { + ) -> server_error::Result { let query_interceptor = self.plugins.get::>(); // TODO(sunng87): figure out at which stage we can call // this hook after ArrowFlight adoption. We need to provide // LogicalPlan as to this hook. - query_interceptor.pre_execute(&stmt, None, query_ctx.clone())?; + query_interceptor + .pre_execute(&stmt, None, query_ctx.clone()) + .map_err(BoxedError::new) + .context(server_error::ExecuteQueryStatementSnafu)?; self.query_statement(stmt, query_ctx.clone()) .await .and_then(|output| query_interceptor.post_execute(output, query_ctx.clone())) + .map_err(BoxedError::new) + .context(server_error::ExecuteQueryStatementSnafu) } - fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { + fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result { self.catalog_manager .schema(catalog, schema) .map(|s| s.is_some()) - .context(error::CatalogSnafu) + .map_err(BoxedError::new) + .context(server_error::CheckDatabaseValiditySnafu) } } @@ -534,6 +547,8 @@ impl PromqlHandler for Instance { mod tests { use std::sync::atomic::AtomicU32; + use common_error::prelude::ErrorExt; + use common_error::status_code::StatusCode; use session::context::QueryContext; use super::*; @@ -791,7 +806,7 @@ mod tests { .await .remove(0) { - assert!(matches!(e, error::Error::NotSupported { .. })); + assert_eq!(e.status_code(), StatusCode::Unsupported); } else { unreachable!(); } @@ -801,7 +816,7 @@ mod tests { .await .remove(0) { - assert!(matches!(e, error::Error::NotSupported { .. 
})); + assert_eq!(e.status_code(), StatusCode::Unsupported); } else { unreachable!(); } diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index 8e2fe525ac1b..a7c12cb8f8d0 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -40,6 +40,7 @@ use partition::partition::{PartitionBound, PartitionDef}; use query::parser::{QueryLanguage, QueryLanguageParser, QueryStatement}; use query::sql::{describe_table, explain, show_databases, show_tables}; use query::{QueryEngineFactory, QueryEngineRef}; +use servers::error as server_error; use servers::query_handler::sql::SqlQueryHandler; use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; @@ -54,9 +55,9 @@ use crate::catalog::FrontendCatalogManager; use crate::datanode::DatanodeClients; use crate::error::{ self, AlterExprToRequestSnafu, CatalogEntrySerdeSnafu, CatalogNotFoundSnafu, CatalogSnafu, - ColumnDataTypeSnafu, DeserializePartitionSnafu, ParseQuerySnafu, ParseSqlSnafu, - PrimaryKeyNotFoundSnafu, RequestDatanodeSnafu, RequestMetaSnafu, Result, SchemaNotFoundSnafu, - StartMetaClientSnafu, TableNotFoundSnafu, TableSnafu, ToTableInsertRequestSnafu, + ColumnDataTypeSnafu, DeserializePartitionSnafu,ParseSqlSnafu, PrimaryKeyNotFoundSnafu, + RequestDatanodeSnafu, RequestMetaSnafu, Result, SchemaNotFoundSnafu, StartMetaClientSnafu, + TableNotFoundSnafu, TableSnafu, ToTableInsertRequestSnafu, }; use crate::expr_factory::{CreateExprFactory, DefaultCreateExprFactory}; use crate::sql::insert_to_request; @@ -220,9 +221,14 @@ impl DistInstance { .context(error::ExecuteStatementSnafu) } - async fn handle_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Vec> { + async fn handle_sql( + &self, + sql: &str, + query_ctx: QueryContextRef, + ) -> Vec> { let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql.to_string())) - .context(ParseQuerySnafu); + .map_err(BoxedError::new) + .context(server_error::ParseQuerySnafu); match stmt { Ok(stmt) => { let result = self.statement_query(stmt, query_ctx.clone()).await; @@ -394,9 +400,13 @@ impl DistInstance { #[async_trait] impl SqlQueryHandler for DistInstance { - type Error = error::Error; + type Error = server_error::Error; - async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { + async fn do_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> Vec> { self.handle_sql(query, query_ctx).await } @@ -412,18 +422,22 @@ impl SqlQueryHandler for DistInstance { &self, stmt: QueryStatement, query_ctx: QueryContextRef, - ) -> Result { - match stmt { + ) -> server_error::Result { + let result = match stmt { QueryStatement::Sql(stmt) => self.handle_sql_statement(stmt, query_ctx).await, QueryStatement::Promql(_) => self.handle_promql_statement(stmt, query_ctx).await, - } + }; + result + .map_err(BoxedError::new) + .context(server_error::ExecuteQueryStatementSnafu) } - fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { + fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result { self.catalog_manager .schema(catalog, schema) .map(|s| s.is_some()) - .context(CatalogSnafu) + .map_err(BoxedError::new) + .context(server_error::CheckDatabaseValiditySnafu) } } @@ -626,7 +640,6 @@ mod test { use super::*; use crate::expr_factory::{CreateExprFactory, DefaultCreateExprFactory}; - use crate::instance::standalone::StandaloneSqlQueryHandler; #[tokio::test] async fn test_parse_partitions() { @@ -755,7 +768,7 @@ ENGINE=mito", .remove(0) .unwrap(); 
- async fn assert_show_tables(instance: SqlQueryHandlerRef) { + async fn assert_show_tables(instance: SqlQueryHandlerRef) { let sql = "show tables in test_show_tables"; let output = instance .do_query(sql, QueryContext::arc()) @@ -779,7 +792,7 @@ ENGINE=mito", // Asserts that new table is created in Datanode as well. for x in datanode_instances.values() { - assert_show_tables(StandaloneSqlQueryHandler::arc(x.clone())).await + assert_show_tables(x.clone()).await } } } diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs index 462c2b5ce3b3..d277f0598e4e 100644 --- a/src/frontend/src/instance/grpc.rs +++ b/src/frontend/src/instance/grpc.rs @@ -15,11 +15,12 @@ use api::v1::greptime_request::Request; use api::v1::query_request::Query; use async_trait::async_trait; +use common_error::prelude::BoxedError; use common_query::Output; use servers::query_handler::grpc::GrpcQueryHandler; use servers::query_handler::sql::SqlQueryHandler; use session::context::QueryContextRef; -use snafu::{ensure, OptionExt}; +use snafu::{ensure, OptionExt, ResultExt}; use crate::error::{self, Result}; use crate::instance::Instance; @@ -46,7 +47,10 @@ impl GrpcQueryHandler for Instance { feat: "execute multiple statements in SQL query string through GRPC interface" } ); - result.remove(0)? + result + .remove(0) + .map_err(BoxedError::new) + .context(error::ExecuteQueryStatementSnafu)? } Query::LogicalPlan(_) => { return error::NotSupportedSnafu { diff --git a/src/frontend/src/instance/standalone.rs b/src/frontend/src/instance/standalone.rs index e113440ad4a5..b259f9edd51b 100644 --- a/src/frontend/src/instance/standalone.rs +++ b/src/frontend/src/instance/standalone.rs @@ -16,9 +16,11 @@ use std::sync::Arc; use api::v1::greptime_request::Request as GreptimeRequest; use async_trait::async_trait; +use common_error::prelude::BoxedError; use common_query::Output; use datanode::error::Error as DatanodeError; use query::parser::QueryStatement; +use servers::error as server_error; use servers::query_handler::grpc::{GrpcQueryHandler, GrpcQueryHandlerRef}; use servers::query_handler::sql::{SqlQueryHandler, SqlQueryHandlerRef}; use session::context::QueryContextRef; @@ -59,11 +61,12 @@ impl SqlQueryHandler for StandaloneSqlQueryHandler { &self, stmt: QueryStatement, query_ctx: QueryContextRef, - ) -> Result { + ) -> server_error::Result { self.0 .statement_query(stmt, query_ctx) .await - .context(error::InvokeDatanodeSnafu) + .map_err(BoxedError::new) + .context(server_error::ExecuteQueryStatementSnafu) } fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index caab78c6c116..2b61b93db61d 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -82,6 +82,12 @@ pub enum Error { source: BoxedError, }, + #[snafu(display("Failed to parse query, source: {}", source))] + ParseQuery { + #[snafu(backtrace)] + source: BoxedError, + }, + #[snafu(display("{source}"))] ExecuteGrpcQuery { #[snafu(backtrace)] @@ -94,6 +100,12 @@ pub enum Error { source: BoxedError, }, + #[snafu(display("Failed to execute query statement, source: {}", source))] + ExecuteQueryStatement { + #[snafu(backtrace)] + source: BoxedError, + }, + #[snafu(display("Failed to check database validity, source: {}", source))] CheckDatabaseValidity { #[snafu(backtrace)] @@ -285,7 +297,9 @@ impl ErrorExt for Error { | ExecuteStatement { source, .. } | CheckDatabaseValidity { source, .. } | ExecuteAlter { source, .. } - | PutOpentsdbDataPoint { source, .. 
} => source.status_code(), + | PutOpentsdbDataPoint { source, .. } + | ParseQuery { source } + | ExecuteQueryStatement { source } => source.status_code(), NotSupported { .. } | InvalidQuery { .. } diff --git a/src/servers/src/query_handler/sql.rs b/src/servers/src/query_handler/sql.rs index b8a8fe34374a..ed439cbd0a44 100644 --- a/src/servers/src/query_handler/sql.rs +++ b/src/servers/src/query_handler/sql.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use async_trait::async_trait; use common_error::prelude::*; use common_query::Output; -use query::parser::QueryStatement; +use query::parser::{QueryLanguage, QueryLanguageParser, QueryStatement}; use session::context::QueryContextRef; use crate::error::{self, Result}; @@ -41,11 +41,22 @@ pub trait SqlQueryHandler { query_ctx: QueryContextRef, ) -> Vec>; + /// Execute a query statement. async fn statement_query( &self, stmt: QueryStatement, query_ctx: QueryContextRef, - ) -> std::result::Result; + ) -> Result; + + async fn query(&self, query: QueryLanguage, query_ctx: QueryContextRef) -> Result { + let stmt = QueryLanguageParser::parse(query) + .map_err(BoxedError::new) + .context(error::ParseQuerySnafu)?; + self.statement_query(stmt, query_ctx) + .await + .map_err(BoxedError::new) + .context(error::ExecuteQueryStatementSnafu) + } fn is_valid_schema( &self, From 8f2d70d54f0a9aa9f56e6f33bd73789f00ffda68 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 7 Feb 2023 01:13:29 +0800 Subject: [PATCH 5/9] remove do_query and do_prom_query methods add query and query_multiple, with default implementations Signed-off-by: Ruihang Xia --- src/datanode/src/instance/sql.rs | 28 ---- src/frontend/src/instance.rs | 157 ++++++++++------------ src/frontend/src/instance/distributed.rs | 89 +++--------- src/frontend/src/instance/grpc.rs | 20 +-- src/frontend/src/instance/influxdb.rs | 13 +- src/frontend/src/instance/opentsdb.rs | 8 +- src/frontend/src/instance/prometheus.rs | 18 +-- src/frontend/src/instance/standalone.rs | 33 +---- src/servers/src/http.rs | 15 --- src/servers/src/http/handler.rs | 16 ++- src/servers/src/mysql/handler.rs | 6 +- src/servers/src/postgres/handler.rs | 6 +- src/servers/src/query_handler/sql.rs | 94 +++++-------- src/servers/tests/http/influxdb_test.rs | 16 +-- src/servers/tests/http/opentsdb_test.rs | 14 -- src/servers/tests/http/prometheus_test.rs | 16 +-- src/servers/tests/mod.rs | 32 ++--- src/servers/tests/py_script/mod.rs | 7 +- 18 files changed, 197 insertions(+), 391 deletions(-) diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index c72d42fadfdb..9353db8d5be7 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -208,34 +208,6 @@ pub fn table_idents_to_full_name( #[async_trait] impl SqlQueryHandler for Instance { - type Error = server_error::Error; - - async fn do_query( - &self, - query: &str, - query_ctx: QueryContextRef, - ) -> Vec> { - let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED); - // we assume sql string has only 1 statement in datanode - let result = self - .execute_sql(query, query_ctx) - .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQueryStatementSnafu); - vec![result] - } - - async fn do_promql_query( - &self, - query: &str, - query_ctx: QueryContextRef, - ) -> Vec> { - // let _timer = timer!(metric::METRIC_HANDLE_PROMQL_ELAPSED); - // let result = self.execute_promql(query, query_ctx).await; - // vec![result] - todo!() - } - async fn statement_query( &self, stmt: QueryStatement, diff --git a/src/frontend/src/instance.rs 
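Across these hunks, frontend and datanode failures are converted into the servers error enum with the same two-step move: box the concrete error, then attach a server-level context (ParseQuerySnafu, ExecuteQueryStatementSnafu, CheckDatabaseValiditySnafu). The following is a minimal sketch of that layering only; it uses a hand-rolled BoxedError alias and plain std error impls instead of the real common_error/snafu machinery, and every name in it is an illustrative stand-in rather than code from this patch.

    use std::fmt;

    // Stand-in for common_error's BoxedError: a type-erased, sendable source error.
    type BoxedError = Box<dyn std::error::Error + Send + Sync>;

    #[derive(Debug)]
    enum ServerError {
        // Mirrors the ParseQuery / ExecuteQueryStatement variants added above:
        // each wraps the boxed lower-level error it was converted from.
        ParseQuery { source: BoxedError },
        ExecuteQueryStatement { source: BoxedError },
    }

    impl fmt::Display for ServerError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                ServerError::ParseQuery { source } => {
                    write!(f, "Failed to parse query, source: {source}")
                }
                ServerError::ExecuteQueryStatement { source } => {
                    write!(f, "Failed to execute query statement, source: {source}")
                }
            }
        }
    }

    impl std::error::Error for ServerError {}

    fn parse_query(sql: &str) -> Result<String, ServerError> {
        // The same conversion that `.map_err(BoxedError::new).context(ParseQuerySnafu)`
        // performs in the patch, written out by hand: box the concrete error, then
        // wrap it in the server-level variant. The "parser" here is a toy that only
        // accepts integer literals, just to produce a concrete source error type.
        sql.trim()
            .parse::<i64>()
            .map(|n| format!("SELECT {n}"))
            .map_err(|e| ServerError::ParseQuery { source: Box::new(e) })
    }

    fn main() {
        // "SELECT now()" is not an integer literal, so the toy parser fails and the
        // boxed source error stays reachable inside the server-level variant.
        match parse_query("SELECT now()") {
            Ok(stmt) => println!("parsed: {stmt}"),
            Err(e) => println!("{e}"),
        }
    }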
b/src/frontend/src/instance.rs index 130c2f64af40..7a0faa211a13 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -66,7 +66,7 @@ use crate::Plugins; #[async_trait] pub trait FrontendInstance: GrpcQueryHandler - + SqlQueryHandler + + SqlQueryHandler + OpentsdbProtocolHandler + InfluxdbLineProtocolHandler + PrometheusProtocolHandler @@ -87,7 +87,7 @@ pub struct Instance { /// Script handler is None in distributed mode, only works on standalone mode. script_handler: Option, - sql_handler: SqlQueryHandlerRef, + sql_handler: SqlQueryHandlerRef, grpc_query_handler: GrpcQueryHandlerRef, promql_handler: Option, @@ -403,21 +403,15 @@ impl Instance { .context(error::ExecuteQueryStatementSnafu), } } -} -#[async_trait] -impl SqlQueryHandler for Instance { - type Error = server_error::Error; - - async fn do_query( + async fn do_query_multiple_sql( &self, - query: &str, + sql: QueryLanguage, query_ctx: QueryContextRef, ) -> Vec> { let query_interceptor = self.plugins.get::>(); - let query = QueryLanguage::Sql(query.to_string()); let query = match query_interceptor - .pre_parsing(query, query_ctx.clone()) + .pre_parsing(sql, query_ctx.clone()) .map_err(BoxedError::new) .context(server_error::ParseQuerySnafu) { @@ -448,30 +442,41 @@ impl SqlQueryHandler for Instance { } } } +} + +#[async_trait] +impl SqlQueryHandler for Instance { + async fn query( + &self, + query: QueryLanguage, + query_ctx: QueryContextRef, + ) -> server_error::Result { + let query_interceptor = self.plugins.get::>(); + let query = query_interceptor + .pre_parsing(query, query_ctx.clone()) + .map_err(BoxedError::new) + .context(server_error::ParseQuerySnafu)?; + let stmt = QueryLanguageParser::parse(query) + .map_err(BoxedError::new) + .context(server_error::ParseQuerySnafu)?; + let stmt = query_interceptor + .post_parsing(stmt, query_ctx.clone()) + .map_err(BoxedError::new) + .context(server_error::ParseQuerySnafu)?; + self.statement_query(stmt, query_ctx).await + } - async fn do_promql_query( + async fn query_multiple( &self, - query: &str, - _: QueryContextRef, + query: QueryLanguage, + query_ctx: QueryContextRef, ) -> Vec> { - // if let Some(handler) = &self.promql_handler { - // let result = handler - // .do_query(query) - // .await - // .context(ExecutePromqlSnafu { query }); - // vec![result] - // } else { - // vec![Err(NotSupportedSnafu { - // feat: "PromQL Query", - // } - // .build())] - // } - todo!() - // let query = QueryLanguage::Promql(query.to_string()); - - // let statement = QueryLanguageParser::parse(QueryLanguage::Promql(query.to_string())) - // .map_err(BoxedError::new) - // .context(server_error::ParseQuerySnafu)?; + match query { + QueryLanguage::Sql(_) => self.do_query_multiple_sql(query, query_ctx).await, + QueryLanguage::Promql(_) => { + vec![self.query(query, query_ctx).await] + } + } } async fn statement_query( @@ -561,7 +566,8 @@ mod tests { let standalone = tests::create_standalone_instance("test_execute_sql").await; let instance = standalone.instance; - let sql = r#"CREATE TABLE demo( + let sql = QueryLanguage::Sql( + r#"CREATE TABLE demo( host STRING, ts TIMESTAMP, cpu DOUBLE NULL, @@ -569,35 +575,31 @@ mod tests { disk_util DOUBLE DEFAULT 9.9, TIME INDEX (ts), PRIMARY KEY(host) - ) engine=mito with(regions=1);"#; - let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) - .await - .remove(0) - .unwrap(); + ) engine=mito with(regions=1);"# + .to_string(), + ); + let output = instance.query(sql, query_ctx.clone()).await.unwrap(); match output { 
Output::AffectedRows(rows) => assert_eq!(rows, 0), _ => unreachable!(), } - let sql = r#"insert into demo(host, cpu, memory, ts) values + let sql = QueryLanguage::Sql( + r#"insert into demo(host, cpu, memory, ts) values ('frontend.host1', 1.1, 100, 1000), ('frontend.host2', null, null, 2000), ('frontend.host3', 3.3, 300, 3000) - "#; - let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) - .await - .remove(0) - .unwrap(); + "# + .to_string(), + ); + let output = instance.query(sql, query_ctx.clone()).await.unwrap(); match output { Output::AffectedRows(rows) => assert_eq!(rows, 3), _ => unreachable!(), } - let sql = "select * from demo"; - let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) - .await - .remove(0) - .unwrap(); + let sql = QueryLanguage::Sql("select * from demo".to_string()); + let output = instance.query(sql, query_ctx.clone()).await.unwrap(); match output { Output::RecordBatches(_) => { unreachable!("Output::RecordBatches"); @@ -621,11 +623,10 @@ mod tests { } }; - let sql = "select * from demo where ts>cast(1000000000 as timestamp)"; // use nanoseconds as where condition - let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) - .await - .remove(0) - .unwrap(); + let sql = QueryLanguage::Sql( + "select * from demo where ts>cast(1000000000 as timestamp)".to_string(), + ); + let output = instance.query(sql, query_ctx.clone()).await.unwrap(); match output { Output::RecordBatches(_) => { unreachable!("Output::RecordBatches") @@ -723,7 +724,8 @@ mod tests { plugins.insert::>(counter_hook.clone()); Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins)); - let sql = r#"CREATE TABLE demo( + let sql = QueryLanguage::Sql( + r#"CREATE TABLE demo( host STRING, ts TIMESTAMP, cpu DOUBLE NULL, @@ -731,11 +733,10 @@ mod tests { disk_util DOUBLE DEFAULT 9.9, TIME INDEX (ts), PRIMARY KEY(host) - ) engine=mito with(regions=1);"#; - let output = SqlQueryHandler::do_query(&*instance, sql, QueryContext::arc()) - .await - .remove(0) - .unwrap(); + ) engine=mito with(regions=1);"# + .to_string(), + ); + let output = instance.query(sql, QueryContext::arc()).await.unwrap(); // assert that the hook is called 3 times assert_eq!(4, counter_hook.c.load(std::sync::atomic::Ordering::Relaxed)); @@ -782,7 +783,8 @@ mod tests { plugins.insert::>(hook.clone()); Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins)); - let sql = r#"CREATE TABLE demo( + let sql = QueryLanguage::Sql( + r#"CREATE TABLE demo( host STRING, ts TIMESTAMP, cpu DOUBLE NULL, @@ -790,35 +792,22 @@ mod tests { disk_util DOUBLE DEFAULT 9.9, TIME INDEX (ts), PRIMARY KEY(host) - ) engine=mito with(regions=1);"#; - let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) - .await - .remove(0) - .unwrap(); + ) engine=mito with(regions=1);"# + .to_string(), + ); + let output = instance.query(sql, query_ctx.clone()).await.unwrap(); match output { Output::AffectedRows(rows) => assert_eq!(rows, 0), _ => unreachable!(), } - let sql = r#"CREATE DATABASE tomcat"#; - if let Err(e) = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) - .await - .remove(0) - { - assert_eq!(e.status_code(), StatusCode::Unsupported); - } else { - unreachable!(); - } + let sql = QueryLanguage::Sql(r#"CREATE DATABASE tomcat"#.to_string()); + let e = instance.query(sql, query_ctx.clone()).await.unwrap_err(); + assert_eq!(e.status_code(), StatusCode::Unsupported); - let sql = r#"SHOW DATABASES"#; - if let Err(e) = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) 
- .await - .remove(0) - { - assert_eq!(e.status_code(), StatusCode::Unsupported); - } else { - unreachable!(); - } + let sql = QueryLanguage::Sql(r#"SHOW DATABASES"#.to_string()); + let e = instance.query(sql, query_ctx).await.unwrap_err(); + assert_eq!(e.status_code(), StatusCode::Unsupported); } } diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index a7c12cb8f8d0..2ce56e6b342e 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -37,7 +37,7 @@ use meta_client::rpc::{ TableName, TableRoute, }; use partition::partition::{PartitionBound, PartitionDef}; -use query::parser::{QueryLanguage, QueryLanguageParser, QueryStatement}; +use query::parser::QueryStatement; use query::sql::{describe_table, explain, show_databases, show_tables}; use query::{QueryEngineFactory, QueryEngineRef}; use servers::error as server_error; @@ -55,7 +55,7 @@ use crate::catalog::FrontendCatalogManager; use crate::datanode::DatanodeClients; use crate::error::{ self, AlterExprToRequestSnafu, CatalogEntrySerdeSnafu, CatalogNotFoundSnafu, CatalogSnafu, - ColumnDataTypeSnafu, DeserializePartitionSnafu,ParseSqlSnafu, PrimaryKeyNotFoundSnafu, + ColumnDataTypeSnafu, DeserializePartitionSnafu, ParseSqlSnafu, PrimaryKeyNotFoundSnafu, RequestDatanodeSnafu, RequestMetaSnafu, Result, SchemaNotFoundSnafu, StartMetaClientSnafu, TableNotFoundSnafu, TableSnafu, ToTableInsertRequestSnafu, }; @@ -221,25 +221,6 @@ impl DistInstance { .context(error::ExecuteStatementSnafu) } - async fn handle_sql( - &self, - sql: &str, - query_ctx: QueryContextRef, - ) -> Vec> { - let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(sql.to_string())) - .map_err(BoxedError::new) - .context(server_error::ParseQuerySnafu); - match stmt { - Ok(stmt) => { - let result = self.statement_query(stmt, query_ctx.clone()).await; - vec![result] - } - - // results - Err(e) => vec![Err(e)], - } - } - /// Handles distributed database creation async fn handle_create_database(&self, expr: CreateDatabaseExpr) -> Result { let key = SchemaKey { @@ -400,24 +381,6 @@ impl DistInstance { #[async_trait] impl SqlQueryHandler for DistInstance { - type Error = server_error::Error; - - async fn do_query( - &self, - query: &str, - query_ctx: QueryContextRef, - ) -> Vec> { - self.handle_sql(query, query_ctx).await - } - - async fn do_promql_query( - &self, - _: &str, - _: QueryContextRef, - ) -> Vec> { - unimplemented!() - } - async fn statement_query( &self, stmt: QueryStatement, @@ -632,6 +595,7 @@ fn find_partition_columns( #[cfg(test)] mod test { use itertools::Itertools; + use query::parser::QueryLanguage; use servers::query_handler::sql::SqlQueryHandlerRef; use session::context::QueryContext; use sql::dialect::GenericDialect; @@ -690,23 +654,15 @@ ENGINE=mito", let instance = crate::tests::create_distributed_instance("test_show_databases").await; let dist_instance = &instance.dist_instance; - let sql = "create database test_show_databases"; - let output = dist_instance - .handle_sql(sql, QueryContext::arc()) - .await - .remove(0) - .unwrap(); + let sql = QueryLanguage::Sql("create database test_show_databases".to_string()); + let output = dist_instance.query(sql, QueryContext::arc()).await.unwrap(); match output { Output::AffectedRows(rows) => assert_eq!(rows, 1), _ => unreachable!(), } - let sql = "show databases"; - let output = dist_instance - .handle_sql(sql, QueryContext::arc()) - .await - .remove(0) - .unwrap(); + let sql = QueryLanguage::Sql("show 
databases".to_string()); + let output = dist_instance.query(sql, QueryContext::arc()).await.unwrap(); match output { Output::RecordBatches(r) => { let expected1 = vec![ @@ -742,14 +698,11 @@ ENGINE=mito", let dist_instance = &instance.dist_instance; let datanode_instances = instance.datanodes; - let sql = "create database test_show_tables"; - dist_instance - .handle_sql(sql, QueryContext::arc()) - .await - .remove(0) - .unwrap(); + let sql = QueryLanguage::Sql("create database test_show_tables".to_string()); + dist_instance.query(sql, QueryContext::arc()).await.unwrap(); - let sql = " + let sql = QueryLanguage::Sql( + " CREATE TABLE greptime.test_show_tables.dist_numbers ( ts BIGINT, n INT, @@ -761,20 +714,14 @@ ENGINE=mito", PARTITION r2 VALUES LESS THAN (50), PARTITION r3 VALUES LESS THAN (MAXVALUE), ) - ENGINE=mito"; - dist_instance - .handle_sql(sql, QueryContext::arc()) - .await - .remove(0) - .unwrap(); + ENGINE=mito" + .to_string(), + ); + dist_instance.query(sql, QueryContext::arc()).await.unwrap(); - async fn assert_show_tables(instance: SqlQueryHandlerRef) { - let sql = "show tables in test_show_tables"; - let output = instance - .do_query(sql, QueryContext::arc()) - .await - .remove(0) - .unwrap(); + async fn assert_show_tables(instance: SqlQueryHandlerRef) { + let query = QueryLanguage::Sql("show tables in test_show_tables".to_string()); + let output = instance.query(query, QueryContext::arc()).await.unwrap(); match output { Output::RecordBatches(r) => { let expected = r#"+--------------+ diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs index d277f0598e4e..4656eaee848c 100644 --- a/src/frontend/src/instance/grpc.rs +++ b/src/frontend/src/instance/grpc.rs @@ -17,10 +17,11 @@ use api::v1::query_request::Query; use async_trait::async_trait; use common_error::prelude::BoxedError; use common_query::Output; +use query::parser::QueryLanguage; use servers::query_handler::grpc::GrpcQueryHandler; use servers::query_handler::sql::SqlQueryHandler; use session::context::QueryContextRef; -use snafu::{ensure, OptionExt, ResultExt}; +use snafu::{OptionExt, ResultExt}; use crate::error::{self, Result}; use crate::instance::Instance; @@ -39,19 +40,10 @@ impl GrpcQueryHandler for Instance { err_msg: "Missing field 'QueryRequest.query'", })?; match query { - Query::Sql(sql) => { - let mut result = SqlQueryHandler::do_query(self, &sql, ctx).await; - ensure!( - result.len() == 1, - error::NotSupportedSnafu { - feat: "execute multiple statements in SQL query string through GRPC interface" - } - ); - result - .remove(0) - .map_err(BoxedError::new) - .context(error::ExecuteQueryStatementSnafu)? 
- } + Query::Sql(sql) => SqlQueryHandler::query(self, QueryLanguage::Sql(sql), ctx) + .await + .map_err(BoxedError::new) + .context(error::ExecuteQueryStatementSnafu)?, Query::LogicalPlan(_) => { return error::NotSupportedSnafu { feat: "Execute LogicalPlan in Frontend", diff --git a/src/frontend/src/instance/influxdb.rs b/src/frontend/src/instance/influxdb.rs index 1da96e2143f6..8665d1826b83 100644 --- a/src/frontend/src/instance/influxdb.rs +++ b/src/frontend/src/instance/influxdb.rs @@ -43,6 +43,7 @@ mod test { use common_query::Output; use common_recordbatch::RecordBatches; + use query::parser::QueryLanguage; use servers::query_handler::sql::SqlQueryHandler; use session::context::QueryContext; @@ -77,13 +78,15 @@ monitor1,host=host2 memory=1027 1663840496400340001"; }; instance.exec(&request, QueryContext::arc()).await.unwrap(); - let mut output = instance - .do_query( - "SELECT ts, host, cpu, memory FROM monitor1 ORDER BY ts", + let output = instance + .query( + QueryLanguage::Sql( + "SELECT ts, host, cpu, memory FROM monitor1 ORDER BY ts".to_string(), + ), QueryContext::arc(), ) - .await; - let output = output.remove(0).unwrap(); + .await + .unwrap(); let Output::Stream(stream) = output else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); assert_eq!( diff --git a/src/frontend/src/instance/opentsdb.rs b/src/frontend/src/instance/opentsdb.rs index f72c7cbdcd0c..116e2723ab5c 100644 --- a/src/frontend/src/instance/opentsdb.rs +++ b/src/frontend/src/instance/opentsdb.rs @@ -43,6 +43,7 @@ mod tests { use common_query::Output; use common_recordbatch::RecordBatches; use itertools::Itertools; + use query::parser::QueryLanguage; use servers::query_handler::sql::SqlQueryHandler; use session::context::QueryContext; @@ -98,12 +99,13 @@ mod tests { assert!(result.is_ok()); let output = instance - .do_query( - "select * from my_metric_1 order by greptime_timestamp", + .query( + QueryLanguage::Sql( + "select * from my_metric_1 order by greptime_timestamp".to_string(), + ), Arc::new(QueryContext::new()), ) .await - .remove(0) .unwrap(); match output { Output::Stream(stream) => { diff --git a/src/frontend/src/instance/prometheus.rs b/src/frontend/src/instance/prometheus.rs index de7d38bd2c37..cbe2e91bc354 100644 --- a/src/frontend/src/instance/prometheus.rs +++ b/src/frontend/src/instance/prometheus.rs @@ -161,6 +161,7 @@ mod tests { use api::prometheus::remote::label_matcher::Type as MatcherType; use api::prometheus::remote::{Label, LabelMatcher, Sample}; use common_catalog::consts::DEFAULT_CATALOG_NAME; + use query::parser::QueryLanguage; use servers::query_handler::sql::SqlQueryHandler; use session::context::QueryContext; @@ -194,15 +195,14 @@ mod tests { let db = "prometheus"; let ctx = Arc::new(QueryContext::with(DEFAULT_CATALOG_NAME, db)); - assert!(SqlQueryHandler::do_query( - instance.as_ref(), - "CREATE DATABASE IF NOT EXISTS prometheus", - ctx.clone(), - ) - .await - .get(0) - .unwrap() - .is_ok()); + let _ = instance + .as_ref() + .query( + QueryLanguage::Sql("CREATE DATABASE IF NOT EXISTS prometheus".to_string()), + ctx.clone(), + ) + .await + .unwrap(); instance.write(write_request, ctx.clone()).await.unwrap(); diff --git a/src/frontend/src/instance/standalone.rs b/src/frontend/src/instance/standalone.rs index b259f9edd51b..2562c1f39473 100644 --- a/src/frontend/src/instance/standalone.rs +++ b/src/frontend/src/instance/standalone.rs @@ -28,35 +28,10 @@ use snafu::ResultExt; use crate::error::{self, Result}; -pub(crate) struct 
StandaloneSqlQueryHandler(SqlQueryHandlerRef); - -impl StandaloneSqlQueryHandler { - pub(crate) fn arc(handler: SqlQueryHandlerRef) -> Arc { - Arc::new(Self(handler)) - } -} +pub(crate) struct StandaloneSqlQueryHandler(SqlQueryHandlerRef); #[async_trait] impl SqlQueryHandler for StandaloneSqlQueryHandler { - type Error = error::Error; - - async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { - self.0 - .do_query(query, query_ctx) - .await - .into_iter() - .map(|x| x.context(error::InvokeDatanodeSnafu)) - .collect() - } - - async fn do_promql_query( - &self, - _: &str, - _: QueryContextRef, - ) -> Vec> { - unimplemented!() - } - async fn statement_query( &self, stmt: QueryStatement, @@ -69,10 +44,8 @@ impl SqlQueryHandler for StandaloneSqlQueryHandler { .context(server_error::ExecuteQueryStatementSnafu) } - fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { - self.0 - .is_valid_schema(catalog, schema) - .context(error::InvokeDatanodeSnafu) + fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result { + self.0.is_valid_schema(catalog, schema) } } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 237f44d9c945..62cd647dbd45 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -567,7 +567,6 @@ mod test { use tokio::sync::mpsc; use super::*; - use crate::error::Error; use crate::query_handler::sql::{ServerSqlQueryHandlerAdaptor, SqlQueryHandler}; struct DummyInstance { @@ -576,20 +575,6 @@ mod test { #[async_trait] impl SqlQueryHandler for DummyInstance { - type Error = Error; - - async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec> { - unimplemented!() - } - - async fn do_promql_query( - &self, - _: &str, - _: QueryContextRef, - ) -> Vec> { - unimplemented!() - } - async fn statement_query( &self, _stmt: query::parser::QueryStatement, diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index 0598fe03935e..651b33595301 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -20,6 +20,7 @@ use axum::extract::{Json, Query, State}; use axum::Extension; use common_error::status_code::StatusCode; use common_telemetry::metric; +use query::parser::QueryLanguage; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use session::context::UserInfo; @@ -45,7 +46,12 @@ pub async fn sql( let resp = if let Some(sql) = ¶ms.sql { match super::query_context_from_db(sql_handler.clone(), params.db) { Ok(query_ctx) => { - JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await + JsonResponse::from_output( + sql_handler + .query_multiple(QueryLanguage::Sql(sql.clone()), query_ctx) + .await, + ) + .await } Err(resp) => resp, } @@ -76,8 +82,12 @@ pub async fn promql( let start = Instant::now(); let resp = match super::query_context_from_db(sql_handler.clone(), None) { Ok(query_ctx) => { - JsonResponse::from_output(sql_handler.do_promql_query(¶ms.query, query_ctx).await) - .await + JsonResponse::from_output( + sql_handler + .query_multiple(QueryLanguage::Promql(params.query), query_ctx) + .await, + ) + .await } Err(resp) => resp, }; diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 284b801dbb75..3e0a317f389e 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -22,6 +22,7 @@ use common_telemetry::{error, trace}; use opensrv_mysql::{ AsyncMysqlShim, ErrorKind, InitWriter, ParamParser, QueryResultWriter, StatementMetaWriter, }; +use query::parser::QueryLanguage; 
use rand::RngCore; use session::context::Channel; use session::Session; @@ -80,7 +81,10 @@ impl MysqlInstanceShim { vec![Ok(output)] } else { self.query_handler - .do_query(query, self.session.context()) + .query_multiple( + QueryLanguage::Sql(query.to_string()), + self.session.context(), + ) .await }; diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index c147b241207f..89b4a56dcaa1 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -29,6 +29,7 @@ use pgwire::api::stmt::NoopQueryParser; use pgwire::api::store::MemPortalStore; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use query::parser::QueryLanguage; use super::PostgresServerHandler; use crate::error::{self, Error, Result}; @@ -41,7 +42,10 @@ impl SimpleQueryHandler for PostgresServerHandler { { let outputs = self .query_handler - .do_query(query, self.query_ctx.clone()) + .query_multiple( + QueryLanguage::Sql(query.to_string()), + self.query_ctx.clone(), + ) .await; let mut results = Vec::with_capacity(outputs.len()); diff --git a/src/servers/src/query_handler/sql.rs b/src/servers/src/query_handler/sql.rs index ed439cbd0a44..7f5312451c2c 100644 --- a/src/servers/src/query_handler/sql.rs +++ b/src/servers/src/query_handler/sql.rs @@ -22,25 +22,11 @@ use session::context::QueryContextRef; use crate::error::{self, Result}; -pub type SqlQueryHandlerRef = Arc + Send + Sync>; -pub type ServerSqlQueryHandlerRef = SqlQueryHandlerRef; +pub type SqlQueryHandlerRef = Arc; +pub type ServerSqlQueryHandlerRef = SqlQueryHandlerRef; #[async_trait] pub trait SqlQueryHandler { - type Error: ErrorExt; - - async fn do_query( - &self, - query: &str, - query_ctx: QueryContextRef, - ) -> Vec>; - - async fn do_promql_query( - &self, - query: &str, - query_ctx: QueryContextRef, - ) -> Vec>; - /// Execute a query statement. 
async fn statement_query( &self, @@ -58,56 +44,50 @@ pub trait SqlQueryHandler { .context(error::ExecuteQueryStatementSnafu) } - fn is_valid_schema( + async fn query_multiple( &self, - catalog: &str, - schema: &str, - ) -> std::result::Result; + query: QueryLanguage, + query_ctx: QueryContextRef, + ) -> Vec> { + match query { + QueryLanguage::Sql(_) => { + let stmts = QueryLanguageParser::parse_multiple(query) + .map_err(BoxedError::new) + .context(error::ParseQuerySnafu); + if let Err(e) = stmts { + return vec![Err(e)]; + } + + let stmts = stmts.unwrap(); + let mut outputs = Vec::with_capacity(stmts.len()); + for stmt in stmts { + let output = self + .statement_query(stmt, query_ctx.clone()) + .await + .map_err(BoxedError::new) + .context(error::ExecuteQueryStatementSnafu); + outputs.push(output); + } + + outputs + } + QueryLanguage::Promql(_) => vec![self.query(query, query_ctx).await], + } + } + + fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result; } -pub struct ServerSqlQueryHandlerAdaptor(SqlQueryHandlerRef); +pub struct ServerSqlQueryHandlerAdaptor(SqlQueryHandlerRef); -impl ServerSqlQueryHandlerAdaptor { - pub fn arc(handler: SqlQueryHandlerRef) -> Arc { +impl ServerSqlQueryHandlerAdaptor { + pub fn arc(handler: SqlQueryHandlerRef) -> Arc { Arc::new(Self(handler)) } } #[async_trait] -impl SqlQueryHandler for ServerSqlQueryHandlerAdaptor -where - E: ErrorExt + Send + Sync + 'static, -{ - type Error = error::Error; - - async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { - self.0 - .do_query(query, query_ctx) - .await - .into_iter() - .map(|x| { - x.map_err(BoxedError::new) - .context(error::ExecuteQuerySnafu { query }) - }) - .collect() - } - - async fn do_promql_query( - &self, - query: &str, - query_ctx: QueryContextRef, - ) -> Vec> { - self.0 - .do_promql_query(query, query_ctx) - .await - .into_iter() - .map(|x| { - x.map_err(BoxedError::new) - .context(error::ExecuteQuerySnafu { query }) - }) - .collect() - } - +impl SqlQueryHandler for ServerSqlQueryHandlerAdaptor { async fn statement_query( &self, stmt: QueryStatement, diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 819bf9e895b7..5ddfa757e976 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -19,7 +19,7 @@ use async_trait::async_trait; use axum::{http, Router}; use axum_test_helper::TestClient; use common_query::Output; -use servers::error::{Error, Result}; +use servers::error::Result; use servers::http::{HttpOptions, HttpServer}; use servers::influxdb::InfluxdbRequest; use servers::query_handler::sql::SqlQueryHandler; @@ -48,20 +48,6 @@ impl InfluxdbLineProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - type Error = Error; - - async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec> { - unimplemented!() - } - - async fn do_promql_query( - &self, - _: &str, - _: QueryContextRef, - ) -> Vec> { - unimplemented!() - } - async fn statement_query( &self, _stmt: query::parser::QueryStatement, diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index 986a26c1f950..cd18138cdf93 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -46,20 +46,6 @@ impl OpentsdbProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - type Error = error::Error; - - async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec> { - 
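The central design in patch 5, visible in the src/servers/src/query_handler/sql.rs hunks above, is to shrink the handler trait to one required execution method (statement_query) and supply query / query_multiple as default methods that parse a QueryLanguage first and then delegate. Here is a minimal sketch of that shape under simplified assumptions: the Statement, Output, error, and parser types below are toy stand-ins for the real QueryStatement, common_query::Output, and QueryLanguageParser, and only the async-trait crate (already used by the codebase) is assumed.

    use async_trait::async_trait;

    // Simplified stand-ins for types that live in the query / common_query crates.
    enum QueryLanguage {
        Sql(String),
        Promql(String),
    }
    struct Statement(String);
    struct Output(usize);
    type Result<T> = std::result::Result<T, String>;

    #[async_trait]
    trait QueryHandler: Send + Sync {
        // The single required method: execute an already-parsed statement.
        async fn statement_query(&self, stmt: Statement) -> Result<Output>;

        // Default method: parse, then delegate. Implementors get it for free.
        async fn query(&self, query: QueryLanguage) -> Result<Output> {
            let stmt = parse(query)?;
            self.statement_query(stmt).await
        }

        // Default method for inputs that may hold several statements, mirroring the
        // Vec<Result<Output>> shape used by the HTTP/MySQL/Postgres entry points.
        async fn query_multiple(&self, query: QueryLanguage) -> Vec<Result<Output>> {
            match parse_multiple(query) {
                Ok(stmts) => {
                    let mut outputs = Vec::with_capacity(stmts.len());
                    for stmt in stmts {
                        outputs.push(self.statement_query(stmt).await);
                    }
                    outputs
                }
                Err(e) => vec![Err(e)],
            }
        }
    }

    // Toy parsers standing in for QueryLanguageParser::parse / parse_multiple.
    fn parse(query: QueryLanguage) -> Result<Statement> {
        match query {
            QueryLanguage::Sql(s) | QueryLanguage::Promql(s) => Ok(Statement(s)),
        }
    }

    fn parse_multiple(query: QueryLanguage) -> Result<Vec<Statement>> {
        match query {
            QueryLanguage::Sql(s) => Ok(s
                .split(';')
                .map(|part| Statement(part.trim().to_string()))
                .collect()),
            QueryLanguage::Promql(s) => Ok(vec![Statement(s)]),
        }
    }

With this layout, a backend such as the frontend Instance or DistInstance only needs to implement statement_query (plus is_valid_schema in the real trait), and every protocol entry point converges on the same parse-then-execute path.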
unimplemented!() - } - - async fn do_promql_query( - &self, - _: &str, - _: QueryContextRef, - ) -> Vec> { - unimplemented!() - } - async fn statement_query( &self, _stmt: query::parser::QueryStatement, diff --git a/src/servers/tests/http/prometheus_test.rs b/src/servers/tests/http/prometheus_test.rs index c924011a1f92..a06425b60434 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -22,7 +22,7 @@ use axum::Router; use axum_test_helper::TestClient; use common_query::Output; use prost::Message; -use servers::error::{Error, Result}; +use servers::error::Result; use servers::http::{HttpOptions, HttpServer}; use servers::prometheus; use servers::prometheus::{snappy_compress, Metrics}; @@ -71,20 +71,6 @@ impl PrometheusProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - type Error = Error; - - async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec> { - unimplemented!() - } - - async fn do_promql_query( - &self, - _: &str, - _: QueryContextRef, - ) -> Vec> { - unimplemented!() - } - async fn statement_query( &self, _stmt: query::parser::QueryStatement, diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 3b30ecaa437c..b860ea1d9688 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -20,11 +20,11 @@ use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaPr use catalog::{CatalogList, CatalogProvider, SchemaProvider}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; -use query::parser::{QueryLanguage, QueryLanguageParser}; +use query::parser::QueryStatement; use query::{QueryEngineFactory, QueryEngineRef}; use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; use script::python::{PyEngine, PyScript}; -use servers::error::{Error, Result}; +use servers::error::Result; use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler}; use servers::query_handler::{ScriptHandler, ScriptHandlerRef}; use session::context::QueryContextRef; @@ -56,32 +56,16 @@ impl DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - type Error = Error; - - async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { - let stmt = QueryLanguageParser::parse(QueryLanguage::Sql(query.to_owned())).unwrap(); + async fn statement_query( + &self, + stmt: QueryStatement, + query_ctx: QueryContextRef, + ) -> Result { let plan = self .query_engine .statement_to_plan(stmt, query_ctx) .unwrap(); - let output = self.query_engine.execute(&plan).await.unwrap(); - vec![Ok(output)] - } - - async fn do_promql_query( - &self, - _: &str, - _: QueryContextRef, - ) -> Vec> { - unimplemented!() - } - - async fn statement_query( - &self, - _stmt: query::parser::QueryStatement, - _query_ctx: QueryContextRef, - ) -> Result { - unimplemented!() + Ok(self.query_engine.execute(&plan).await.unwrap()) } fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { diff --git a/src/servers/tests/py_script/mod.rs b/src/servers/tests/py_script/mod.rs index 3d9d9226b42b..47560c5203e5 100644 --- a/src/servers/tests/py_script/mod.rs +++ b/src/servers/tests/py_script/mod.rs @@ -14,6 +14,7 @@ use std::sync::Arc; +use query::parser::QueryLanguage; use servers::error::Result; use servers::query_handler::sql::SqlQueryHandler; use servers::query_handler::ScriptHandler; @@ -35,9 +36,11 @@ def double_that(col)->vector[u32]: "#; instance.insert_script("double_that", src).await?; let res = 
instance - .do_query("select double_that(uint32s) from numbers", query_ctx) + .query( + QueryLanguage::Sql("select double_that(uint32s) from numbers".to_string()), + query_ctx, + ) .await - .remove(0) .unwrap(); match res { common_query::Output::AffectedRows(_) => (), From 2de86ab022f6e79d32402771ccfa5676c0e1bfb8 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 7 Feb 2023 01:26:46 +0800 Subject: [PATCH 6/9] rename SqlHandler/SqlQueryHandler to QueryHandler Signed-off-by: Ruihang Xia --- src/datanode/src/instance/sql.rs | 4 ++-- src/datanode/src/server.rs | 4 ++-- src/frontend/src/instance.rs | 8 ++++---- src/frontend/src/instance/distributed.rs | 8 ++++---- src/frontend/src/instance/grpc.rs | 4 ++-- src/frontend/src/instance/influxdb.rs | 2 +- src/frontend/src/instance/opentsdb.rs | 2 +- src/frontend/src/instance/prometheus.rs | 2 +- src/frontend/src/instance/standalone.rs | 6 +++--- src/frontend/src/server.rs | 8 ++++---- src/servers/src/http.rs | 20 +++++++++---------- src/servers/src/http/handler.rs | 12 +++++------ src/servers/src/mysql/handler.rs | 6 +++--- src/servers/src/mysql/server.rs | 8 ++++---- src/servers/src/postgres.rs | 6 +++--- src/servers/src/postgres/auth_handler.rs | 4 ++-- src/servers/src/postgres/server.rs | 4 ++-- src/servers/src/query_handler/sql.rs | 22 +++++++++++---------- src/servers/tests/http/http_handler_test.rs | 14 ++++++------- src/servers/tests/http/influxdb_test.rs | 4 ++-- src/servers/tests/http/opentsdb_test.rs | 4 ++-- src/servers/tests/http/prometheus_test.rs | 4 ++-- src/servers/tests/mod.rs | 6 +++--- src/servers/tests/py_script/mod.rs | 2 +- tests-integration/src/test_util.rs | 6 +++--- 25 files changed, 86 insertions(+), 84 deletions(-) diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index 9353db8d5be7..86a4910a58af 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -21,7 +21,7 @@ use common_telemetry::timer; use query::parser::{QueryLanguage, QueryLanguageParser, QueryStatement}; use servers::error as server_error; use servers::promql::PromqlHandler; -use servers::query_handler::sql::SqlQueryHandler; +use servers::query_handler::sql::QueryHandler; use session::context::{QueryContext, QueryContextRef}; use snafu::prelude::*; use sql::ast::ObjectName; @@ -207,7 +207,7 @@ pub fn table_idents_to_full_name( } #[async_trait] -impl SqlQueryHandler for Instance { +impl QueryHandler for Instance { async fn statement_query( &self, stmt: QueryStatement, diff --git a/src/datanode/src/server.rs b/src/datanode/src/server.rs index 3827138fb34b..0c3dcba3e0ad 100644 --- a/src/datanode/src/server.rs +++ b/src/datanode/src/server.rs @@ -22,7 +22,7 @@ use servers::error::Error::InternalIo; use servers::grpc::GrpcServer; use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef}; use servers::query_handler::grpc::ServerGrpcQueryHandlerAdaptor; -use servers::query_handler::sql::ServerSqlQueryHandlerAdaptor; +use servers::query_handler::sql::ServerQueryHandlerAdaptor; use servers::server::Server; use servers::tls::TlsOption; use servers::Mode; @@ -70,7 +70,7 @@ impl Services { Some(MysqlServer::create_server( mysql_io_runtime, Arc::new(MysqlSpawnRef::new( - ServerSqlQueryHandlerAdaptor::arc(instance.clone()), + ServerQueryHandlerAdaptor::arc(instance.clone()), None, )), Arc::new(MysqlSpawnConfig::new( diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 7a0faa211a13..cdd14056ca88 100644 --- a/src/frontend/src/instance.rs +++ 
b/src/frontend/src/instance.rs @@ -46,7 +46,7 @@ use servers::error as server_error; use servers::interceptor::{SqlQueryInterceptor, SqlQueryInterceptorRef}; use servers::promql::{PromqlHandler, PromqlHandlerRef}; use servers::query_handler::grpc::{GrpcQueryHandler, GrpcQueryHandlerRef}; -use servers::query_handler::sql::{SqlQueryHandler, SqlQueryHandlerRef}; +use servers::query_handler::sql::{QueryHandler, QueryHandlerRef}; use servers::query_handler::{ InfluxdbLineProtocolHandler, OpentsdbProtocolHandler, PrometheusProtocolHandler, ScriptHandler, ScriptHandlerRef, @@ -66,7 +66,7 @@ use crate::Plugins; #[async_trait] pub trait FrontendInstance: GrpcQueryHandler - + SqlQueryHandler + + QueryHandler + OpentsdbProtocolHandler + InfluxdbLineProtocolHandler + PrometheusProtocolHandler @@ -87,7 +87,7 @@ pub struct Instance { /// Script handler is None in distributed mode, only works on standalone mode. script_handler: Option, - sql_handler: SqlQueryHandlerRef, + sql_handler: QueryHandlerRef, grpc_query_handler: GrpcQueryHandlerRef, promql_handler: Option, @@ -445,7 +445,7 @@ impl Instance { } #[async_trait] -impl SqlQueryHandler for Instance { +impl QueryHandler for Instance { async fn query( &self, query: QueryLanguage, diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index 2ce56e6b342e..41235b4c5a0c 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -41,7 +41,7 @@ use query::parser::QueryStatement; use query::sql::{describe_table, explain, show_databases, show_tables}; use query::{QueryEngineFactory, QueryEngineRef}; use servers::error as server_error; -use servers::query_handler::sql::SqlQueryHandler; +use servers::query_handler::sql::QueryHandler; use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; use sql::ast::Value as SqlValue; @@ -380,7 +380,7 @@ impl DistInstance { } #[async_trait] -impl SqlQueryHandler for DistInstance { +impl QueryHandler for DistInstance { async fn statement_query( &self, stmt: QueryStatement, @@ -596,7 +596,7 @@ fn find_partition_columns( mod test { use itertools::Itertools; use query::parser::QueryLanguage; - use servers::query_handler::sql::SqlQueryHandlerRef; + use servers::query_handler::sql::QueryHandlerRef; use session::context::QueryContext; use sql::dialect::GenericDialect; use sql::parser::ParserContext; @@ -719,7 +719,7 @@ ENGINE=mito", ); dist_instance.query(sql, QueryContext::arc()).await.unwrap(); - async fn assert_show_tables(instance: SqlQueryHandlerRef) { + async fn assert_show_tables(instance: QueryHandlerRef) { let query = QueryLanguage::Sql("show tables in test_show_tables".to_string()); let output = instance.query(query, QueryContext::arc()).await.unwrap(); match output { diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs index 4656eaee848c..d7bd13eead8a 100644 --- a/src/frontend/src/instance/grpc.rs +++ b/src/frontend/src/instance/grpc.rs @@ -19,7 +19,7 @@ use common_error::prelude::BoxedError; use common_query::Output; use query::parser::QueryLanguage; use servers::query_handler::grpc::GrpcQueryHandler; -use servers::query_handler::sql::SqlQueryHandler; +use servers::query_handler::sql::QueryHandler; use session::context::QueryContextRef; use snafu::{OptionExt, ResultExt}; @@ -40,7 +40,7 @@ impl GrpcQueryHandler for Instance { err_msg: "Missing field 'QueryRequest.query'", })?; match query { - Query::Sql(sql) => SqlQueryHandler::query(self, QueryLanguage::Sql(sql), ctx) + 
Query::Sql(sql) => QueryHandler::query(self, QueryLanguage::Sql(sql), ctx) .await .map_err(BoxedError::new) .context(error::ExecuteQueryStatementSnafu)?, diff --git a/src/frontend/src/instance/influxdb.rs b/src/frontend/src/instance/influxdb.rs index 8665d1826b83..c545eebb2c35 100644 --- a/src/frontend/src/instance/influxdb.rs +++ b/src/frontend/src/instance/influxdb.rs @@ -44,7 +44,7 @@ mod test { use common_query::Output; use common_recordbatch::RecordBatches; use query::parser::QueryLanguage; - use servers::query_handler::sql::SqlQueryHandler; + use servers::query_handler::sql::QueryHandler; use session::context::QueryContext; use super::*; diff --git a/src/frontend/src/instance/opentsdb.rs b/src/frontend/src/instance/opentsdb.rs index 116e2723ab5c..9512c6015532 100644 --- a/src/frontend/src/instance/opentsdb.rs +++ b/src/frontend/src/instance/opentsdb.rs @@ -44,7 +44,7 @@ mod tests { use common_recordbatch::RecordBatches; use itertools::Itertools; use query::parser::QueryLanguage; - use servers::query_handler::sql::SqlQueryHandler; + use servers::query_handler::sql::QueryHandler; use session::context::QueryContext; use super::*; diff --git a/src/frontend/src/instance/prometheus.rs b/src/frontend/src/instance/prometheus.rs index cbe2e91bc354..8611ac0acf89 100644 --- a/src/frontend/src/instance/prometheus.rs +++ b/src/frontend/src/instance/prometheus.rs @@ -162,7 +162,7 @@ mod tests { use api::prometheus::remote::{Label, LabelMatcher, Sample}; use common_catalog::consts::DEFAULT_CATALOG_NAME; use query::parser::QueryLanguage; - use servers::query_handler::sql::SqlQueryHandler; + use servers::query_handler::sql::QueryHandler; use session::context::QueryContext; use super::*; diff --git a/src/frontend/src/instance/standalone.rs b/src/frontend/src/instance/standalone.rs index 2562c1f39473..338a78db3b42 100644 --- a/src/frontend/src/instance/standalone.rs +++ b/src/frontend/src/instance/standalone.rs @@ -22,16 +22,16 @@ use datanode::error::Error as DatanodeError; use query::parser::QueryStatement; use servers::error as server_error; use servers::query_handler::grpc::{GrpcQueryHandler, GrpcQueryHandlerRef}; -use servers::query_handler::sql::{SqlQueryHandler, SqlQueryHandlerRef}; +use servers::query_handler::sql::{QueryHandler, QueryHandlerRef}; use session::context::QueryContextRef; use snafu::ResultExt; use crate::error::{self, Result}; -pub(crate) struct StandaloneSqlQueryHandler(SqlQueryHandlerRef); +pub(crate) struct StandaloneSqlQueryHandler(QueryHandlerRef); #[async_trait] -impl SqlQueryHandler for StandaloneSqlQueryHandler { +impl QueryHandler for StandaloneSqlQueryHandler { async fn statement_query( &self, stmt: QueryStatement, diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index daff47ed26fb..c0a81b52c6ec 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -26,7 +26,7 @@ use servers::opentsdb::OpentsdbServer; use servers::postgres::PostgresServer; use servers::promql::PromqlServer; use servers::query_handler::grpc::ServerGrpcQueryHandlerAdaptor; -use servers::query_handler::sql::ServerSqlQueryHandlerAdaptor; +use servers::query_handler::sql::ServerQueryHandlerAdaptor; use servers::server::Server; use snafu::ResultExt; use tokio::try_join; @@ -87,7 +87,7 @@ impl Services { let mysql_server = MysqlServer::create_server( mysql_io_runtime, Arc::new(MysqlSpawnRef::new( - ServerSqlQueryHandlerAdaptor::arc(instance.clone()), + ServerQueryHandlerAdaptor::arc(instance.clone()), user_provider.clone(), )), Arc::new(MysqlSpawnConfig::new( @@ -119,7 
+119,7 @@ impl Services { ); let pg_server = Box::new(PostgresServer::new( - ServerSqlQueryHandlerAdaptor::arc(instance.clone()), + ServerQueryHandlerAdaptor::arc(instance.clone()), opts.tls.clone(), pg_io_runtime, user_provider.clone(), @@ -152,7 +152,7 @@ impl Services { let http_addr = parse_addr(&http_options.addr)?; let mut http_server = HttpServer::new( - ServerSqlQueryHandlerAdaptor::arc(instance.clone()), + ServerQueryHandlerAdaptor::arc(instance.clone()), http_options.clone(), ); if let Some(user_provider) = user_provider.clone() { diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 62cd647dbd45..4ef69fad64ef 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -53,7 +53,7 @@ use self::authorize::HttpAuth; use self::influxdb::influxdb_write; use crate::auth::UserProviderRef; use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu}; -use crate::query_handler::sql::ServerSqlQueryHandlerRef; +use crate::query_handler::sql::ServerQueryHandlerRef; use crate::query_handler::{ InfluxdbLineProtocolHandlerRef, OpentsdbProtocolHandlerRef, PrometheusProtocolHandlerRef, ScriptHandlerRef, @@ -63,7 +63,7 @@ use crate::server::Server; /// create query context from database name information, catalog and schema are /// resolved from the name pub(crate) fn query_context_from_db( - query_handler: ServerSqlQueryHandlerRef, + query_handler: ServerQueryHandlerRef, db: Option, ) -> std::result::Result, JsonResponse> { if let Some(db) = &db { @@ -89,7 +89,7 @@ pub const HTTP_API_VERSION: &str = "v1"; pub const HTTP_API_PREFIX: &str = "/v1/"; pub struct HttpServer { - sql_handler: ServerSqlQueryHandlerRef, + query_handler: ServerQueryHandlerRef, options: HttpOptions, influxdb_handler: Option, opentsdb_handler: Option, @@ -338,14 +338,14 @@ async fn serve_docs() -> Html { #[derive(Clone)] pub struct ApiState { - pub sql_handler: ServerSqlQueryHandlerRef, + pub query_handler: ServerQueryHandlerRef, pub script_handler: Option, } impl HttpServer { - pub fn new(sql_handler: ServerSqlQueryHandlerRef, options: HttpOptions) -> Self { + pub fn new(query_handler: ServerQueryHandlerRef, options: HttpOptions) -> Self { Self { - sql_handler, + query_handler, options, opentsdb_handler: None, influxdb_handler: None, @@ -413,7 +413,7 @@ impl HttpServer { let sql_router = self .route_sql(ApiState { - sql_handler: self.sql_handler.clone(), + query_handler: self.query_handler.clone(), script_handler: self.script_handler.clone(), }) .finish_api(&mut api) @@ -567,14 +567,14 @@ mod test { use tokio::sync::mpsc; use super::*; - use crate::query_handler::sql::{ServerSqlQueryHandlerAdaptor, SqlQueryHandler}; + use crate::query_handler::sql::{QueryHandler, ServerQueryHandlerAdaptor}; struct DummyInstance { _tx: mpsc::Sender<(String, Vec)>, } #[async_trait] - impl SqlQueryHandler for DummyInstance { + impl QueryHandler for DummyInstance { async fn statement_query( &self, _stmt: query::parser::QueryStatement, @@ -598,7 +598,7 @@ mod test { fn make_test_app(tx: mpsc::Sender<(String, Vec)>) -> Router { let instance = Arc::new(DummyInstance { _tx: tx }); - let instance = ServerSqlQueryHandlerAdaptor::arc(instance); + let instance = ServerQueryHandlerAdaptor::arc(instance); let server = HttpServer::new(instance, HttpOptions::default()); server.make_app().route( "/test/timeout", diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index 651b33595301..6bc609703671 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -41,13 +41,13 @@ pub 
async fn sql( // TODO(fys): pass _user_info into query context _user_info: Extension, ) -> Json { - let sql_handler = &state.sql_handler; + let query_handler = &state.query_handler; let start = Instant::now(); let resp = if let Some(sql) = ¶ms.sql { - match super::query_context_from_db(sql_handler.clone(), params.db) { + match super::query_context_from_db(query_handler.clone(), params.db) { Ok(query_ctx) => { JsonResponse::from_output( - sql_handler + query_handler .query_multiple(QueryLanguage::Sql(sql.clone()), query_ctx) .await, ) @@ -78,12 +78,12 @@ pub async fn promql( // TODO(fys): pass _user_info into query context _user_info: Extension, ) -> Json { - let sql_handler = &state.sql_handler; + let query_handler = &state.query_handler; let start = Instant::now(); - let resp = match super::query_context_from_db(sql_handler.clone(), None) { + let resp = match super::query_context_from_db(query_handler.clone(), None) { Ok(query_ctx) => { JsonResponse::from_output( - sql_handler + query_handler .query_multiple(QueryLanguage::Promql(params.query), query_ctx) .await, ) diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 3e0a317f389e..a28c25341b25 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -32,11 +32,11 @@ use tokio::io::AsyncWrite; use crate::auth::{Identity, Password, UserProviderRef}; use crate::error::{self, Result}; use crate::mysql::writer::MysqlResultWriter; -use crate::query_handler::sql::ServerSqlQueryHandlerRef; +use crate::query_handler::sql::ServerQueryHandlerRef; // An intermediate shim for executing MySQL queries. pub struct MysqlInstanceShim { - query_handler: ServerSqlQueryHandlerRef, + query_handler: ServerQueryHandlerRef, salt: [u8; 20], session: Arc, user_provider: Option, @@ -44,7 +44,7 @@ pub struct MysqlInstanceShim { impl MysqlInstanceShim { pub fn create( - query_handler: ServerSqlQueryHandlerRef, + query_handler: ServerQueryHandlerRef, user_provider: Option, client_addr: SocketAddr, ) -> MysqlInstanceShim { diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index 4653e660234f..bc6124aa9f44 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -31,7 +31,7 @@ use tokio_rustls::rustls::ServerConfig; use crate::auth::UserProviderRef; use crate::error::{Error, Result}; use crate::mysql::handler::MysqlInstanceShim; -use crate::query_handler::sql::ServerSqlQueryHandlerRef; +use crate::query_handler::sql::ServerQueryHandlerRef; use crate::server::{AbortableStream, BaseTcpServer, Server}; // Default size of ResultSet write buffer: 100KB @@ -40,13 +40,13 @@ const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024; /// [`MysqlSpawnRef`] stores arc refs /// that should be passed to new [`MysqlInstanceShim`]s. 
pub struct MysqlSpawnRef { - query_handler: ServerSqlQueryHandlerRef, + query_handler: ServerQueryHandlerRef, user_provider: Option, } impl MysqlSpawnRef { pub fn new( - query_handler: ServerSqlQueryHandlerRef, + query_handler: ServerQueryHandlerRef, user_provider: Option, ) -> MysqlSpawnRef { MysqlSpawnRef { @@ -55,7 +55,7 @@ impl MysqlSpawnRef { } } - fn query_handler(&self) -> ServerSqlQueryHandlerRef { + fn query_handler(&self) -> ServerQueryHandlerRef { self.query_handler.clone() } fn user_provider(&self) -> Option { diff --git a/src/servers/src/postgres.rs b/src/servers/src/postgres.rs index b2160211a980..07eaa408a1f8 100644 --- a/src/servers/src/postgres.rs +++ b/src/servers/src/postgres.rs @@ -36,7 +36,7 @@ use session::context::{QueryContext, QueryContextRef}; use self::auth_handler::PgLoginVerifier; use crate::auth::UserProviderRef; -use crate::query_handler::sql::ServerSqlQueryHandlerRef; +use crate::query_handler::sql::ServerQueryHandlerRef; pub(crate) struct GreptimeDBStartupParameters { version: &'static str, @@ -66,7 +66,7 @@ impl ServerParameterProvider for GreptimeDBStartupParameters { } pub struct PostgresServerHandler { - query_handler: ServerSqlQueryHandlerRef, + query_handler: ServerQueryHandlerRef, login_verifier: PgLoginVerifier, force_tls: bool, param_provider: Arc, @@ -78,7 +78,7 @@ pub struct PostgresServerHandler { #[derive(Builder)] pub(crate) struct MakePostgresServerHandler { - query_handler: ServerSqlQueryHandlerRef, + query_handler: ServerQueryHandlerRef, user_provider: Option, #[builder(default = "Arc::new(GreptimeDBStartupParameters::new())")] param_provider: Arc, diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index b9488da844c1..806e545aaded 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -29,7 +29,7 @@ use super::PostgresServerHandler; use crate::auth::{Identity, Password, UserProviderRef}; use crate::error; use crate::error::Result; -use crate::query_handler::sql::ServerSqlQueryHandlerRef; +use crate::query_handler::sql::ServerQueryHandlerRef; pub(crate) struct PgLoginVerifier { user_provider: Option, @@ -238,7 +238,7 @@ enum DbResolution { /// A function extracted to resolve lifetime and readability issues: fn resolve_db_info( client: &mut C, - query_handler: ServerSqlQueryHandlerRef, + query_handler: ServerQueryHandlerRef, ) -> PgWireResult where C: ClientInfo + Unpin + Send, diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index 89018a10815d..1d1290165264 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -28,7 +28,7 @@ use tokio_rustls::TlsAcceptor; use super::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder}; use crate::auth::UserProviderRef; use crate::error::Result; -use crate::query_handler::sql::ServerSqlQueryHandlerRef; +use crate::query_handler::sql::ServerQueryHandlerRef; use crate::server::{AbortableStream, BaseTcpServer, Server}; use crate::tls::TlsOption; @@ -41,7 +41,7 @@ pub struct PostgresServer { impl PostgresServer { /// Creates a new Postgres server with provided query_handler and async runtime pub fn new( - query_handler: ServerSqlQueryHandlerRef, + query_handler: ServerQueryHandlerRef, tls: TlsOption, io_runtime: Arc, user_provider: Option, diff --git a/src/servers/src/query_handler/sql.rs b/src/servers/src/query_handler/sql.rs index 7f5312451c2c..f06c83070704 100644 --- a/src/servers/src/query_handler/sql.rs +++ 
b/src/servers/src/query_handler/sql.rs @@ -22,18 +22,21 @@ use session::context::QueryContextRef; use crate::error::{self, Result}; -pub type SqlQueryHandlerRef = Arc; -pub type ServerSqlQueryHandlerRef = SqlQueryHandlerRef; +pub type QueryHandlerRef = Arc; +pub type ServerQueryHandlerRef = QueryHandlerRef; #[async_trait] -pub trait SqlQueryHandler { - /// Execute a query statement. +pub trait QueryHandler { + /// Execute a [QueryStatement]. async fn statement_query( &self, stmt: QueryStatement, query_ctx: QueryContextRef, ) -> Result; + /// Check if the given catalog and schema are valid. + fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result; + async fn query(&self, query: QueryLanguage, query_ctx: QueryContextRef) -> Result { let stmt = QueryLanguageParser::parse(query) .map_err(BoxedError::new) @@ -44,6 +47,7 @@ pub trait SqlQueryHandler { .context(error::ExecuteQueryStatementSnafu) } + /// Execute a [QueryLanguage] that may return multiple [Output]s. async fn query_multiple( &self, query: QueryLanguage, @@ -74,20 +78,18 @@ pub trait SqlQueryHandler { QueryLanguage::Promql(_) => vec![self.query(query, query_ctx).await], } } - - fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result; } -pub struct ServerSqlQueryHandlerAdaptor(SqlQueryHandlerRef); +pub struct ServerQueryHandlerAdaptor(QueryHandlerRef); -impl ServerSqlQueryHandlerAdaptor { - pub fn arc(handler: SqlQueryHandlerRef) -> Arc { +impl ServerQueryHandlerAdaptor { + pub fn arc(handler: QueryHandlerRef) -> Arc { Arc::new(Self(handler)) } } #[async_trait] -impl SqlQueryHandler for ServerSqlQueryHandlerAdaptor { +impl QueryHandler for ServerQueryHandlerAdaptor { async fn statement_query( &self, stmt: QueryStatement, diff --git a/src/servers/tests/http/http_handler_test.rs b/src/servers/tests/http/http_handler_test.rs index 0ff270153372..c3d1ecdc7af1 100644 --- a/src/servers/tests/http/http_handler_test.rs +++ b/src/servers/tests/http/http_handler_test.rs @@ -26,10 +26,10 @@ use crate::{create_testing_script_handler, create_testing_sql_query_handler}; #[tokio::test] async fn test_sql_not_provided() { - let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); + let query_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); let Json(json) = http_handler::sql( State(ApiState { - sql_handler, + query_handler, script_handler: None, }), Query(http_handler::SqlQuery::default()), @@ -49,11 +49,11 @@ async fn test_sql_output_rows() { common_telemetry::init_default_ut_logging(); let query = create_query(); - let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); + let query_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); let Json(json) = http_handler::sql( State(ApiState { - sql_handler, + query_handler, script_handler: None, }), query, @@ -90,13 +90,13 @@ def test(n): return n; "# .to_string(); - let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); + let query_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); let script_handler = create_testing_script_handler(MemTable::default_numbers_table()); let body = RawBody(Body::from(script.clone())); let invalid_query = create_invalid_script_query(); let Json(json) = script_handler::scripts( State(ApiState { - sql_handler: sql_handler.clone(), + query_handler: query_handler.clone(), script_handler: Some(script_handler.clone()), }), invalid_query, @@ -110,7 +110,7 @@ def test(n): let exec = 
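The reshaped trait in this hunk keeps `statement_query` and `is_valid_schema` as required methods, while `query` and `query_multiple` stay provided: they run the incoming `QueryLanguage` through the parser and then delegate. A simplified synchronous sketch of that split (the real trait is async via #[async_trait] and uses the crate's Result/Output types):

// Simplified stand-ins for the parser types.
enum QueryLanguage {
    Sql(String),
    Promql(String),
}

#[derive(Debug)]
enum QueryStatement {
    Sql(String),
    Promql(String),
}

// Stand-in for QueryLanguageParser::parse.
fn parse(query: QueryLanguage) -> Result<QueryStatement, String> {
    match query {
        QueryLanguage::Sql(sql) => Ok(QueryStatement::Sql(sql)),
        QueryLanguage::Promql(expr) => Ok(QueryStatement::Promql(expr)),
    }
}

trait QueryHandler {
    // Required: execute an already-parsed statement.
    fn statement_query(&self, stmt: QueryStatement) -> Result<String, String>;
    // Required: validate a catalog/schema pair.
    fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result<bool, String>;
    // Provided: parse the raw query, then delegate to statement_query.
    fn query(&self, query: QueryLanguage) -> Result<String, String> {
        let stmt = parse(query)?;
        self.statement_query(stmt)
    }
}

struct Dummy;

impl QueryHandler for Dummy {
    fn statement_query(&self, stmt: QueryStatement) -> Result<String, String> {
        Ok(format!("ran {stmt:?}"))
    }
    fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result<bool, String> {
        // Illustrative literals; the real check consults the catalog manager.
        Ok(catalog == "greptime" && schema == "public")
    }
}

fn main() {
    let handler = Dummy;
    println!("{:?}", handler.query(QueryLanguage::Sql("select 1".into())));
    println!("{:?}", handler.query(QueryLanguage::Promql("up".into())));
    println!("{:?}", handler.is_valid_schema("greptime", "public"));
}

Implementors only supply the two required methods; the provided ones give every protocol frontend the same parse-then-execute flow.
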
create_script_query(); let Json(json) = script_handler::scripts( State(ApiState { - sql_handler, + query_handler, script_handler: Some(script_handler), }), exec, diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 5ddfa757e976..e42fdec5c30d 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -22,7 +22,7 @@ use common_query::Output; use servers::error::Result; use servers::http::{HttpOptions, HttpServer}; use servers::influxdb::InfluxdbRequest; -use servers::query_handler::sql::SqlQueryHandler; +use servers::query_handler::sql::QueryHandler; use servers::query_handler::InfluxdbLineProtocolHandler; use session::context::QueryContextRef; use tokio::sync::mpsc; @@ -47,7 +47,7 @@ impl InfluxdbLineProtocolHandler for DummyInstance { } #[async_trait] -impl SqlQueryHandler for DummyInstance { +impl QueryHandler for DummyInstance { async fn statement_query( &self, _stmt: query::parser::QueryStatement, diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index cd18138cdf93..b4948934f75b 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -21,7 +21,7 @@ use common_query::Output; use servers::error::{self, Result}; use servers::http::{HttpOptions, HttpServer}; use servers::opentsdb::codec::DataPoint; -use servers::query_handler::sql::SqlQueryHandler; +use servers::query_handler::sql::QueryHandler; use servers::query_handler::OpentsdbProtocolHandler; use session::context::QueryContextRef; use tokio::sync::mpsc; @@ -45,7 +45,7 @@ impl OpentsdbProtocolHandler for DummyInstance { } #[async_trait] -impl SqlQueryHandler for DummyInstance { +impl QueryHandler for DummyInstance { async fn statement_query( &self, _stmt: query::parser::QueryStatement, diff --git a/src/servers/tests/http/prometheus_test.rs b/src/servers/tests/http/prometheus_test.rs index a06425b60434..5edbdd1ba50f 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -26,7 +26,7 @@ use servers::error::Result; use servers::http::{HttpOptions, HttpServer}; use servers::prometheus; use servers::prometheus::{snappy_compress, Metrics}; -use servers::query_handler::sql::SqlQueryHandler; +use servers::query_handler::sql::QueryHandler; use servers::query_handler::{PrometheusProtocolHandler, PrometheusResponse}; use session::context::QueryContextRef; use tokio::sync::mpsc; @@ -70,7 +70,7 @@ impl PrometheusProtocolHandler for DummyInstance { } #[async_trait] -impl SqlQueryHandler for DummyInstance { +impl QueryHandler for DummyInstance { async fn statement_query( &self, _stmt: query::parser::QueryStatement, diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index b860ea1d9688..3021753b0e82 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -25,7 +25,7 @@ use query::{QueryEngineFactory, QueryEngineRef}; use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; use script::python::{PyEngine, PyScript}; use servers::error::Result; -use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler}; +use servers::query_handler::sql::{QueryHandler, ServerQueryHandlerRef}; use servers::query_handler::{ScriptHandler, ScriptHandlerRef}; use session::context::QueryContextRef; use table::test_util::MemTable; @@ -55,7 +55,7 @@ impl DummyInstance { } #[async_trait] -impl SqlQueryHandler for DummyInstance { +impl QueryHandler for DummyInstance { async fn 
statement_query( &self, stmt: QueryStatement, @@ -121,6 +121,6 @@ fn create_testing_script_handler(table: MemTable) -> ScriptHandlerRef { Arc::new(create_testing_instance(table)) as _ } -fn create_testing_sql_query_handler(table: MemTable) -> ServerSqlQueryHandlerRef { +fn create_testing_sql_query_handler(table: MemTable) -> ServerQueryHandlerRef { Arc::new(create_testing_instance(table)) as _ } diff --git a/src/servers/tests/py_script/mod.rs b/src/servers/tests/py_script/mod.rs index 47560c5203e5..6a0cc766f0d0 100644 --- a/src/servers/tests/py_script/mod.rs +++ b/src/servers/tests/py_script/mod.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use query::parser::QueryLanguage; use servers::error::Result; -use servers::query_handler::sql::SqlQueryHandler; +use servers::query_handler::sql::QueryHandler; use servers::query_handler::ScriptHandler; use session::context::QueryContext; use table::test_util::MemTable; diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index 3232b846289f..91cdb8964870 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -42,7 +42,7 @@ use servers::grpc::GrpcServer; use servers::http::{HttpOptions, HttpServer}; use servers::promql::PromqlServer; use servers::query_handler::grpc::ServerGrpcQueryHandlerAdaptor; -use servers::query_handler::sql::ServerSqlQueryHandlerAdaptor; +use servers::query_handler::sql::ServerQueryHandlerAdaptor; use servers::server::Server; use servers::Mode; use snafu::ResultExt; @@ -278,7 +278,7 @@ pub async fn setup_test_http_app(store_type: StorageType, name: &str) -> (Router .await .unwrap(); let http_server = HttpServer::new( - ServerSqlQueryHandlerAdaptor::arc(instance), + ServerQueryHandlerAdaptor::arc(instance), HttpOptions::default(), ); (http_server.make_app(), guard) @@ -300,7 +300,7 @@ pub async fn setup_test_http_app_with_frontend( .await .unwrap(); let mut http_server = HttpServer::new( - ServerSqlQueryHandlerAdaptor::arc(Arc::new(frontend)), + ServerQueryHandlerAdaptor::arc(Arc::new(frontend)), HttpOptions::default(), ); http_server.set_script_handler(instance.clone()); From 55ab5468e5b47ea9994bcb215ceb162de9331193 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 7 Feb 2023 01:28:21 +0800 Subject: [PATCH 7/9] fix clippy Signed-off-by: Ruihang Xia --- src/servers/src/interceptor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/servers/src/interceptor.rs b/src/servers/src/interceptor.rs index 5076ff0333a5..01c4049c8258 100644 --- a/src/servers/src/interceptor.rs +++ b/src/servers/src/interceptor.rs @@ -27,7 +27,7 @@ pub trait SqlQueryInterceptor { /// Called before a query is parsed into statement. /// The implementation is allowed to change the query if needed. 
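The `pre_parsing` hook documented here receives the query before it is parsed and may rewrite it; the change just below only drops the `'a` lifetime parameter that the by-value signature no longer needs. A small sketch of such an interceptor, with a plain `String` standing in for the `QueryLanguage` value and the `QueryContextRef` argument omitted:

trait SqlQueryInterceptor {
    fn pre_parsing(&self, query: String) -> String {
        // Default: pass the query through untouched.
        query
    }
}

// An interceptor that strips a trailing semicolon before parsing.
struct TrimSemicolon;

impl SqlQueryInterceptor for TrimSemicolon {
    fn pre_parsing(&self, query: String) -> String {
        query.trim_end().trim_end_matches(';').to_string()
    }
}

fn main() {
    println!("{}", TrimSemicolon.pre_parsing("select 1;".to_string()));
}
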
- fn pre_parsing<'a>( + fn pre_parsing( &self, query: QueryLanguage, _query_ctx: QueryContextRef, From 7a9b600d98de979f690c62cfd40ed9939e5a5464 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 7 Feb 2023 11:50:31 +0800 Subject: [PATCH 8/9] fix errors from merge commit Signed-off-by: Ruihang Xia --- src/datanode/src/instance/sql.rs | 15 ++++++++++----- src/frontend/src/instance.rs | 8 ++++++-- src/frontend/src/instance/distributed.rs | 13 +++++++++---- src/frontend/src/instance/standalone.rs | 10 ++++++---- src/query/src/datafusion.rs | 5 ++--- src/servers/src/http.rs | 7 ++++--- src/servers/src/postgres/handler.rs | 9 ++++----- src/servers/src/query_handler/sql.rs | 7 ++++--- src/servers/tests/http/influxdb_test.rs | 9 +++++---- src/servers/tests/http/opentsdb_test.rs | 7 ++++--- src/servers/tests/http/prometheus_test.rs | 7 ++++--- src/servers/tests/mod.rs | 11 ++++------- 12 files changed, 62 insertions(+), 46 deletions(-) diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index f4fcc6a1229f..f74a6f3ab3aa 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -19,7 +19,7 @@ use common_recordbatch::RecordBatches; use common_telemetry::logging::info; use common_telemetry::timer; use datatypes::schema::Schema; -use query::parser::{QueryLanguageParser, QueryStatement}; +use query::parser::{QueryLanguage, QueryLanguageParser, QueryStatement}; use servers::error as server_error; use servers::promql::PromqlHandler; use servers::query_handler::sql::QueryHandler; @@ -229,12 +229,17 @@ impl QueryHandler for Instance { .context(server_error::CheckDatabaseValiditySnafu) } - fn do_describe(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result> { - if let Statement::Query(_) = stmt { + fn describe( + &self, + stmt: QueryStatement, + query_ctx: QueryContextRef, + ) -> server_error::Result> { + if let QueryStatement::Sql(Statement::Query(_)) = stmt { self.query_engine - .describe(QueryStatement::Sql(stmt), query_ctx) + .describe(stmt, query_ctx) .map(Some) - .context(error::DescribeStatementSnafu) + .map_err(BoxedError::new) + .context(server_error::DescribeStatementSnafu) } else { Ok(None) } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 1b43770ccaea..31e93e415a4c 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -509,8 +509,12 @@ impl QueryHandler for Instance { .context(server_error::CheckDatabaseValiditySnafu) } - fn do_describe(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result> { - self.sql_handler.do_describe(stmt, query_ctx) + fn describe( + &self, + stmt: QueryStatement, + query_ctx: QueryContextRef, + ) -> server_error::Result> { + self.sql_handler.describe(stmt, query_ctx) } } diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index b13ff062c807..00d65f384ae5 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -403,12 +403,17 @@ impl QueryHandler for DistInstance { .context(server_error::CheckDatabaseValiditySnafu) } - fn do_describe(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result> { - if let Statement::Query(_) = stmt { + fn describe( + &self, + stmt: QueryStatement, + query_ctx: QueryContextRef, + ) -> server_error::Result> { + if let QueryStatement::Sql(Statement::Query(_)) = stmt { self.query_engine - .describe(QueryStatement::Sql(stmt), query_ctx) + .describe(stmt, query_ctx) .map(Some) - .context(error::DescribeStatementSnafu) + 
.map_err(BoxedError::new) + .context(server_error::DescribeStatementSnafu) } else { Ok(None) } diff --git a/src/frontend/src/instance/standalone.rs b/src/frontend/src/instance/standalone.rs index f09a5486342c..5e9e66675d30 100644 --- a/src/frontend/src/instance/standalone.rs +++ b/src/frontend/src/instance/standalone.rs @@ -49,10 +49,12 @@ impl QueryHandler for StandaloneSqlQueryHandler { self.0.is_valid_schema(catalog, schema) } - fn do_describe(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result> { - self.0 - .do_describe(stmt, query_ctx) - .context(error::InvokeDatanodeSnafu) + fn describe( + &self, + stmt: QueryStatement, + query_ctx: QueryContextRef, + ) -> server_error::Result> { + self.0.describe(stmt, query_ctx) } } diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index be89b6148203..74c41e38f795 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -369,9 +369,8 @@ mod tests { #[test] fn test_describe() { let engine = create_test_engine(); - let sql = "select sum(number) from numbers limit 20"; - - let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); + let sql = QueryLanguage::Sql("select sum(number) from numbers limit 20".to_string()); + let stmt = QueryLanguageParser::parse(sql).unwrap(); let schema = engine .describe(stmt, Arc::new(QueryContext::new())) diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 08d7b305d8c0..3f4ff77263cb 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -563,6 +563,7 @@ mod test { use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{StringVector, UInt32Vector}; + use query::parser::QueryStatement; use session::context::QueryContextRef; use tokio::sync::mpsc; @@ -577,15 +578,15 @@ mod test { impl QueryHandler for DummyInstance { async fn statement_query( &self, - _stmt: query::parser::QueryStatement, + _stmt: QueryStatement, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() } - fn do_describe( + fn describe( &self, - _stmt: sql::statements::statement::Statement, + _stmt: QueryStatement, _query_ctx: QueryContextRef, ) -> Result> { unimplemented!() diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 1671867d81cf..b49a3c5e7f12 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -34,7 +34,7 @@ use pgwire::api::stmt::{QueryParser, StoredStatement}; use pgwire::api::store::MemPortalStore; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; -use query::parser::QueryLanguage; +use query::parser::{QueryLanguage, QueryStatement}; use sql::dialect::GenericDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -376,9 +376,8 @@ impl ExtendedQueryHandler for PostgresServerHandler { let output = self .query_handler - .do_query(&sql, self.query_ctx.clone()) - .await - .remove(0); + .query(QueryLanguage::Sql(sql), self.query_ctx.clone()) + .await; output_to_query_response(output, false) } @@ -394,7 +393,7 @@ impl ExtendedQueryHandler for PostgresServerHandler { let (stmt, _) = statement.statement(); if let Some(schema) = self .query_handler - .do_describe(stmt.clone(), self.query_ctx.clone()) + .describe(QueryStatement::Sql(stmt.clone()), self.query_ctx.clone()) .map_err(|e| PgWireError::ApiError(Box::new(e)))? 
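Across these hunks `do_describe` becomes `describe`: it takes a `QueryStatement`, answers `Some(schema)` only for plain queries, and returns `None` for everything else. A self-contained sketch of that contract, with simplified stand-ins for `datatypes::schema::Schema` and the statement types; the column list is hard-coded here, whereas the real engine plans the statement and reads the schema off the logical plan:

#[derive(Debug)]
struct Schema {
    column_names: Vec<String>,
}

#[derive(Debug)]
enum Statement {
    Query(String),
    Insert(String),
}

enum QueryStatement {
    Sql(Statement),
}

fn describe(stmt: QueryStatement) -> Result<Option<Schema>, String> {
    match stmt {
        // Only SELECT-style queries expose a result schema, and planning is
        // enough to learn it; nothing is executed.
        QueryStatement::Sql(Statement::Query(_)) => Ok(Some(Schema {
            column_names: vec!["number".to_string()],
        })),
        // Anything else (INSERT here) has nothing to describe.
        _ => Ok(None),
    }
}

fn main() {
    let query = QueryStatement::Sql(Statement::Query("select number from numbers".into()));
    let insert = QueryStatement::Sql(Statement::Insert("insert into numbers values (1)".into()));
    println!("{:?}", describe(query));
    println!("{:?}", describe(insert));
}

This is the contract the Postgres handler above leans on when answering a Describe message for a prepared statement: obtain the schema, map it to pg field descriptions, or report that there is no result set.
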
{ schema_to_pg(&schema).map_err(|e| PgWireError::ApiError(Box::new(e))) diff --git a/src/servers/src/query_handler/sql.rs b/src/servers/src/query_handler/sql.rs index 8dcaf4764c59..83b6950964e9 100644 --- a/src/servers/src/query_handler/sql.rs +++ b/src/servers/src/query_handler/sql.rs @@ -49,7 +49,8 @@ pub trait QueryHandler { } // TODO(LFC): revisit this for mysql prepared statement - fn do_describe(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result>; + /// Retrieve query schema without execute it. + fn describe(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result>; /// Execute a [QueryLanguage] that may return multiple [Output]s. async fn query_multiple( @@ -106,9 +107,9 @@ impl QueryHandler for ServerQueryHandlerAdaptor { .context(error::ExecuteStatementSnafu) } - fn do_describe(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result> { + fn describe(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result> { self.0 - .do_describe(stmt, query_ctx) + .describe(stmt, query_ctx) .map_err(BoxedError::new) .context(error::DescribeStatementSnafu) } diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 397941776320..1f29c7dbb39f 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -20,7 +20,8 @@ use axum::{http, Router}; use axum_test_helper::TestClient; use common_query::Output; use datatypes::schema::Schema; -use servers::error::{Error, Result}; +use query::parser::QueryStatement; +use servers::error::Result; use servers::http::{HttpOptions, HttpServer}; use servers::influxdb::InfluxdbRequest; use servers::query_handler::sql::QueryHandler; @@ -51,15 +52,15 @@ impl InfluxdbLineProtocolHandler for DummyInstance { impl QueryHandler for DummyInstance { async fn statement_query( &self, - _stmt: query::parser::QueryStatement, + _stmt: QueryStatement, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() } - fn do_describe( + fn describe( &self, - _stmt: sql::statements::statement::Statement, + _stmt: QueryStatement, _query_ctx: QueryContextRef, ) -> Result> { unimplemented!() diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index 68faf60833ff..e576bed339ca 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -19,6 +19,7 @@ use axum::Router; use axum_test_helper::TestClient; use common_query::Output; use datatypes::schema::Schema; +use query::parser::QueryStatement; use servers::error::{self, Result}; use servers::http::{HttpOptions, HttpServer}; use servers::opentsdb::codec::DataPoint; @@ -49,15 +50,15 @@ impl OpentsdbProtocolHandler for DummyInstance { impl QueryHandler for DummyInstance { async fn statement_query( &self, - _stmt: query::parser::QueryStatement, + _stmt: QueryStatement, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() } - fn do_describe( + fn describe( &self, - _stmt: sql::statements::statement::Statement, + _stmt: QueryStatement, _query_ctx: QueryContextRef, ) -> Result> { unimplemented!() diff --git a/src/servers/tests/http/prometheus_test.rs b/src/servers/tests/http/prometheus_test.rs index 43b5ad4f8cae..68261e25054f 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -23,6 +23,7 @@ use axum_test_helper::TestClient; use common_query::Output; use datatypes::schema::Schema; use prost::Message; +use query::parser::QueryStatement; use servers::error::Result; use 
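The adaptor's `describe` above forwards to the wrapped handler and re-wraps its error with `BoxedError` plus a servers-side context. A rough sketch of that boundary pattern using only std types (the real code uses snafu's `context` and the crates' own error enums):

use std::error::Error;
use std::fmt;

#[derive(Debug)]
struct DescribeStatementError {
    source: Box<dyn Error>,
}

impl fmt::Display for DescribeStatementError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "failed to describe statement: {}", self.source)
    }
}

impl Error for DescribeStatementError {}

struct Inner;

impl Inner {
    fn describe(&self) -> Result<&'static str, String> {
        Err("table not found".to_string())
    }
}

struct Adaptor(Inner);

impl Adaptor {
    fn describe(&self) -> Result<&'static str, DescribeStatementError> {
        self.0
            .describe()
            // Erase the inner error type, then attach boundary context.
            .map_err(|e| DescribeStatementError { source: e.into() })
    }
}

fn main() {
    println!("{}", Adaptor(Inner).describe().unwrap_err());
}
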
servers::http::{HttpOptions, HttpServer}; use servers::prometheus; @@ -74,15 +75,15 @@ impl PrometheusProtocolHandler for DummyInstance { impl QueryHandler for DummyInstance { async fn statement_query( &self, - _stmt: query::parser::QueryStatement, + _stmt: QueryStatement, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() } - fn do_describe( + fn describe( &self, - _stmt: sql::statements::statement::Statement, + _stmt: QueryStatement, _query_ctx: QueryContextRef, ) -> Result> { unimplemented!() diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 0cede0c3e975..49c23c3baf59 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -21,7 +21,7 @@ use catalog::{CatalogList, CatalogProvider, SchemaProvider}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; use datatypes::schema::Schema; -use query::parser::{QueryLanguageParser, QueryStatement}; +use query::parser::QueryStatement; use query::{QueryEngineFactory, QueryEngineRef}; use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; use script::python::{PyEngine, PyScript}; @@ -70,12 +70,9 @@ impl QueryHandler for DummyInstance { Ok(self.query_engine.execute(&plan).await.unwrap()) } - fn do_describe(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result> { - if let Statement::Query(_) = stmt { - let schema = self - .query_engine - .describe(QueryStatement::Sql(stmt), query_ctx) - .unwrap(); + fn describe(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result> { + if let QueryStatement::Sql(Statement::Query(_)) = stmt { + let schema = self.query_engine.describe(stmt, query_ctx).unwrap(); Ok(Some(schema)) } else { Ok(None) From 648bba98de40f7ea9ce2560fd71c3f55b681a4b2 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 7 Feb 2023 15:52:11 +0800 Subject: [PATCH 9/9] fix sqlness test Signed-off-by: Ruihang Xia --- src/query/src/error.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/query/src/error.rs b/src/query/src/error.rs index 26b88107de38..9ca0fe4f6578 100644 --- a/src/query/src/error.rs +++ b/src/query/src/error.rs @@ -78,7 +78,7 @@ impl ErrorExt for Error { use Error::*; match self { - QueryParse { .. } | MultipleStatements { .. } => StatusCode::InvalidSyntax, + MultipleStatements { .. } => StatusCode::InvalidSyntax, UnsupportedExpr { .. } | CatalogNotFound { .. } | SchemaNotFound { .. } @@ -87,7 +87,9 @@ impl ErrorExt for Error { VectorComputation { source } => source.status_code(), CreateRecordBatch { source } => source.status_code(), Datatype { source } => source.status_code(), - QueryExecution { source } | QueryPlan { source } => source.status_code(), + QueryExecution { source } | QueryPlan { source } | QueryParse { source, .. } => { + source.status_code() + } } }
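The final hunk stops pinning `QueryParse` to `InvalidSyntax` and lets it report its source's status code, the same way `QueryExecution` and `QueryPlan` already do. A compact sketch of that delegation, with illustrative stand-ins for common_error's `StatusCode` and `ErrorExt`:

#[derive(Debug, Clone, Copy, PartialEq)]
enum StatusCode {
    InvalidSyntax,
    Internal,
}

trait ErrorExt {
    fn status_code(&self) -> StatusCode;
}

struct ParserError {
    code: StatusCode,
}

impl ErrorExt for ParserError {
    fn status_code(&self) -> StatusCode {
        self.code
    }
}

enum Error {
    MultipleStatements,
    QueryParse { source: ParserError },
}

impl ErrorExt for Error {
    fn status_code(&self) -> StatusCode {
        match self {
            Error::MultipleStatements => StatusCode::InvalidSyntax,
            // Delegate: the variant forwards whatever its source reported.
            Error::QueryParse { source } => source.status_code(),
        }
    }
}

fn main() {
    let parse_err = Error::QueryParse { source: ParserError { code: StatusCode::Internal } };
    assert_eq!(parse_err.status_code(), StatusCode::Internal);
    assert_eq!(Error::MultipleStatements.status_code(), StatusCode::InvalidSyntax);
    println!("QueryParse reports its source's code; MultipleStatements stays InvalidSyntax");
}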