From 3a226272fc11bb79ca33379cf15c27d9b00a528c Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 5 Jul 2023 14:32:09 +0800 Subject: [PATCH 01/10] generic make connection Signed-off-by: Bugen Zhao --- examples/basic/examples/basic.rs | 2 +- examples/condition/examples/condition.rs | 2 +- examples/custom_type/examples/custom_type.rs | 2 +- .../examples/file_level_sort_mode.rs | 2 +- examples/include/examples/include.rs | 2 +- examples/rowsort/examples/rowsort.rs | 2 +- .../examples/test_dir_escape.rs | 2 +- examples/validator/examples/validator.rs | 2 +- sqllogictest-bin/src/main.rs | 22 +-- sqllogictest/src/harness.rs | 4 +- sqllogictest/src/parser.rs | 37 ++++ sqllogictest/src/runner.rs | 174 +++++++++++++++--- 12 files changed, 203 insertions(+), 50 deletions(-) diff --git a/examples/basic/examples/basic.rs b/examples/basic/examples/basic.rs index b28713b..775ed74 100644 --- a/examples/basic/examples/basic.rs +++ b/examples/basic/examples/basic.rs @@ -49,7 +49,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(FakeDB); + let mut tester = sqllogictest::Runner::new_once(FakeDB); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/condition/examples/condition.rs b/examples/condition/examples/condition.rs index 3595ea9..e62cd0a 100644 --- a/examples/condition/examples/condition.rs +++ b/examples/condition/examples/condition.rs @@ -43,7 +43,7 @@ impl sqllogictest::DB for FakeDB { fn main() { for engine_name in ["risinglight", "otherdb"] { - let mut tester = sqllogictest::Runner::new(FakeDB { engine_name }); + let mut tester = sqllogictest::Runner::new_once(FakeDB { engine_name }); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/custom_type/examples/custom_type.rs b/examples/custom_type/examples/custom_type.rs index 347d882..c8c31ab 100644 --- a/examples/custom_type/examples/custom_type.rs +++ b/examples/custom_type/examples/custom_type.rs @@ -67,7 +67,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(FakeDB); + let mut tester = sqllogictest::Runner::new_once(FakeDB); tester.with_column_validator(strict_column_validator); let mut filename = PathBuf::from(file!()); diff --git a/examples/file_level_sort_mode/examples/file_level_sort_mode.rs b/examples/file_level_sort_mode/examples/file_level_sort_mode.rs index 9bd5793..d08e65a 100644 --- a/examples/file_level_sort_mode/examples/file_level_sort_mode.rs +++ b/examples/file_level_sort_mode/examples/file_level_sort_mode.rs @@ -41,7 +41,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(FakeDB); + let mut tester = sqllogictest::Runner::new_once(FakeDB); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/include/examples/include.rs b/examples/include/examples/include.rs index 6a4167e..b10c583 100644 --- a/examples/include/examples/include.rs +++ b/examples/include/examples/include.rs @@ -44,7 +44,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(FakeDB); + let mut tester = sqllogictest::Runner::new_once(FakeDB); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/rowsort/examples/rowsort.rs b/examples/rowsort/examples/rowsort.rs index b47dfc7..b2b3302 100644 --- a/examples/rowsort/examples/rowsort.rs +++ b/examples/rowsort/examples/rowsort.rs @@ -41,7 +41,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(FakeDB); + let mut tester = sqllogictest::Runner::new_once(FakeDB); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/test_dir_escape/examples/test_dir_escape.rs b/examples/test_dir_escape/examples/test_dir_escape.rs index 9d3a936..6982554 100644 --- a/examples/test_dir_escape/examples/test_dir_escape.rs +++ b/examples/test_dir_escape/examples/test_dir_escape.rs @@ -28,7 +28,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(FakeDB); + let mut tester = sqllogictest::Runner::new_once(FakeDB); // enable `__TEST_DIR__` override tester.enable_testdir(); diff --git a/examples/validator/examples/validator.rs b/examples/validator/examples/validator.rs index 05e4a67..6fd0d67 100644 --- a/examples/validator/examples/validator.rs +++ b/examples/validator/examples/validator.rs @@ -28,7 +28,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(FakeDB); + let mut tester = sqllogictest::Runner::new_once(FakeDB); // Validator will always return true. tester.with_validator(|_, _| true); diff --git a/sqllogictest-bin/src/main.rs b/sqllogictest-bin/src/main.rs index 04eca15..5997eaa 100644 --- a/sqllogictest-bin/src/main.rs +++ b/sqllogictest-bin/src/main.rs @@ -17,7 +17,7 @@ use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite}; use rand::seq::SliceRandom; use sqllogictest::{ default_validator, strict_column_validator, update_record_with_output, AsyncDB, Injected, - Record, Runner, + MakeConnection, Record, Runner, }; #[derive(Default, Copy, Clone, Debug, PartialEq, Eq, ArgEnum)] @@ -370,7 +370,7 @@ async fn run_serial( for file in files { let engine = engines::connect(engine, &config).await?; - let mut runner = Runner::new(engine); + let mut runner = Runner::new_once(engine); for label in labels { runner.add_label(label); } @@ -419,7 +419,7 @@ async fn update_test_files( ) -> Result<()> { for file in files { let engine = engines::connect(engine, &config).await?; - let runner = Runner::new(engine); + let runner = Runner::new_once(engine); if let Err(e) = update_test_file(&mut std::io::stdout(), runner, &file, format).await { { @@ -444,7 +444,7 @@ async fn connect_and_run_test_file( labels: &[String], ) -> Result { let engine = engines::connect(engine, &config).await?; - let mut runner = Runner::new(engine); + let mut runner = Runner::new_once(engine); for label in labels { runner.add_label(label); } @@ -455,9 +455,9 @@ async fn connect_and_run_test_file( /// Different from [`Runner::run_file_async`], we re-implement it here to print some progress /// information. -async fn run_test_file( +async fn run_test_file( out: &mut T, - mut runner: Runner, + mut runner: Runner, filename: impl AsRef, ) -> Result { let filename = filename.as_ref(); @@ -558,9 +558,9 @@ fn finish_test_file( /// Different from [`sqllogictest::update_test_file`], we re-implement it here to print some /// progress information. -async fn update_test_file( +async fn update_test_file( out: &mut T, - mut runner: Runner, + mut runner: Runner, filename: impl AsRef, format: bool, ) -> Result<()> { @@ -713,10 +713,10 @@ async fn update_test_file( Ok(()) } -async fn update_record( +async fn update_record( outfile: &mut File, - runner: &mut Runner, - record: Record, + runner: &mut Runner, + record: Record<::ColumnType>, format: bool, ) -> Result<()> { assert!(!matches!(record, Record::Injected(_))); diff --git a/sqllogictest/src/harness.rs b/sqllogictest/src/harness.rs index 6e1147f..e52aa05 100644 --- a/sqllogictest/src/harness.rs +++ b/sqllogictest/src/harness.rs @@ -3,7 +3,7 @@ use std::path::Path; pub use glob::glob; pub use libtest_mimic::{run, Arguments, Failed, Trial}; -use crate::{AsyncDB, Runner}; +use crate::{AsyncDB, MakeOnce, Runner}; /// * `db_fn`: `fn() -> sqllogictest::AsyncDB` /// * `pattern`: The glob used to match against and select each file to be tested. It is relative to @@ -33,7 +33,7 @@ macro_rules! harness { } pub fn test(filename: impl AsRef, db: impl AsyncDB) -> Result<(), Failed> { - let mut tester = Runner::new(db); + let mut tester = Runner::new(MakeOnce::new(db)); tester.run_file(filename)?; Ok(()) } diff --git a/sqllogictest/src/parser.rs b/sqllogictest/src/parser.rs index f65eec0..2b91448 100644 --- a/sqllogictest/src/parser.rs +++ b/sqllogictest/src/parser.rs @@ -82,6 +82,7 @@ pub enum Record { Statement { loc: Location, conditions: Vec, + connection: Connection, /// The SQL command is expected to fail with an error messages that matches the given /// regex. If the regex is an empty string, any error message is accepted. #[educe(PartialEq(method = "cmp_regex"))] @@ -96,6 +97,7 @@ pub enum Record { Query { loc: Location, conditions: Vec, + connection: Connection, expected_types: Vec, sort_mode: Option, label: Option, @@ -136,6 +138,7 @@ pub enum Record { threshold: u64, }, Condition(Condition), + Connection(Connection), Comment(Vec), Newline, /// Internally injected record which should not occur in the test file. @@ -175,6 +178,7 @@ impl std::fmt::Display for Record { Record::Statement { loc: _, conditions: _, + connection: _, expected_error, sql, expected_count, @@ -199,6 +203,7 @@ impl std::fmt::Display for Record { Record::Query { loc: _, conditions: _, + connection: _, expected_types, sort_mode, label, @@ -249,6 +254,12 @@ impl std::fmt::Display for Record { Condition::OnlyIf { label } => write!(f, "onlyif {label}"), Condition::SkipIf { label } => write!(f, "skipif {label}"), }, + Record::Connection(conn) => { + if let Connection::Named(conn) = conn { + write!(f, "connection {}", conn)?; + } + Ok(()) + } Record::HashThreshold { loc: _, threshold } => { write!(f, "hash-threshold {threshold}") } @@ -301,6 +312,22 @@ impl Condition { } } +#[derive(Default, Debug, PartialEq, Eq, Hash, Clone)] +pub enum Connection { + #[default] + Default, + Named(String), +} + +impl Connection { + fn new(name: impl AsRef) -> Self { + match name.as_ref() { + "default" => Self::Default, + name => Self::Named(name.to_owned()), + } + } +} + /// Whether to apply sorting before checking the results of a query. #[derive(Debug, PartialEq, Eq, Clone)] pub enum SortMode { @@ -404,6 +431,7 @@ fn parse_inner(loc: &Location, script: &str) -> Result(loc: &Location, script: &str) -> Result { + let conn = Connection::new(label); + connection = conn.clone(); + records.push(Record::Connection(conn)); + } ["statement", res @ ..] => { let mut expected_count = None; let mut expected_error = None; @@ -499,6 +532,7 @@ fn parse_inner(loc: &Location, script: &str) -> Result(loc: &Location, script: &str) -> Result", 1), conditions: vec![], + connection: Connection::Default, expected_types: vec![], sort_mode: None, label: None, @@ -799,6 +835,7 @@ select * from foo; // so if new variants are added, this match // statement must be too. Record::Condition(_) + | Record::Connection(_) | Record::Comment(_) | Record::Control(_) | Record::Newline diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index 9100b84..63a51bc 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -1,6 +1,6 @@ //! Sqllogictest runner. -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt::{Debug, Display}; use std::path::Path; use std::sync::Arc; @@ -9,7 +9,8 @@ use std::vec; use async_trait::async_trait; use futures::executor::block_on; -use futures::{stream, Future, StreamExt}; +use futures::future::BoxFuture; +use futures::{stream, Future, FutureExt, StreamExt}; use itertools::Itertools; use md5::Digest; use owo_colors::OwoColorize; @@ -50,7 +51,7 @@ pub enum DBOutput { /// The async database to be tested. #[async_trait] -pub trait AsyncDB: Send { +pub trait AsyncDB: Send + 'static { /// The error type of SQL execution. type Error: std::error::Error + Send + Sync + 'static; /// The type of result columns @@ -75,7 +76,7 @@ pub trait AsyncDB: Send { } /// The database to be tested. -pub trait DB: Send { +pub trait DB: Send + 'static { /// The error type of SQL execution. type Error: std::error::Error + Send + Sync + 'static; /// The type of result columns @@ -450,12 +451,75 @@ pub fn strict_column_validator(actual: &Vec, expected: &Vec .any(|(actual_column, expected_column)| actual_column != expected_column) } +pub trait MakeConnection: 'static { + type D: AsyncDB; + + fn make(&mut self) -> BoxFuture::Error>>; +} + +impl MakeConnection for F +where + F: FnMut() -> Fut + 'static, + Fut: Future> + Send + 'static, +{ + type D = D; + + fn make(&mut self) -> BoxFuture::Error>> { + self().boxed() + } +} + +pub struct MakeOnce(Option); + +impl MakeOnce { + pub fn new(db: D) -> Self { + MakeOnce(Some(db)) + } +} + +impl MakeConnection for MakeOnce { + type D = D; + + fn make(&mut self) -> BoxFuture::Error>> { + async { Ok(self.0.take().expect("MakeOnce can only be used once")) }.boxed() + } +} + +struct Connections { + make_conn: M, + conns: HashMap, +} + +impl Connections { + fn new(make_conn: M) -> Self { + Connections { + make_conn, + conns: HashMap::new(), + } + } + + async fn get(&mut self, name: Connection) -> Result<&mut M::D, ::Error> { + if !self.conns.contains_key(&name) { + let conn = self.make_conn.make().await?; + self.conns.insert(name.clone(), conn); + } + Ok(self.conns.get_mut(&name).unwrap()) + } + + async fn run_default( + &mut self, + sql: &str, + ) -> Result::ColumnType>, ::Error> { + self.get(Connection::Default).await?.run(sql).await + } +} + /// Sqllogictest runner. -pub struct Runner { - db: D, +pub struct Runner { + conn: Connections, // validator is used for validate if the result of query equals to expected. validator: Validator, - column_type_validator: ColumnTypeValidator, + column_type_validator: ColumnTypeValidator<::ColumnType>, testdir: Option, sort_mode: Option, /// 0 means never hashing @@ -464,17 +528,23 @@ pub struct Runner { labels: HashSet, } -impl Runner { +impl Runner> { + pub fn new_once(db: D) -> Self { + Self::new(MakeOnce::new(db)) + } +} + +impl Runner { /// Create a new test runner on the database. - pub fn new(db: D) -> Self { + pub fn new(make_conn: M) -> Self { Runner { validator: default_validator, column_type_validator: default_column_validator, testdir: None, sort_mode: None, hash_threshold: 0, - labels: [db.engine_name().to_string()].into_iter().collect(), - db, + labels: HashSet::new(), // TODO + conn: Connections::new(make_conn), } } @@ -495,7 +565,10 @@ impl Runner { self.validator = validator; } - pub fn with_column_validator(&mut self, validator: ColumnTypeValidator) { + pub fn with_column_validator( + &mut self, + validator: ColumnTypeValidator<::ColumnType>, + ) { self.column_type_validator = validator; } @@ -505,14 +578,15 @@ impl Runner { pub async fn apply_record( &mut self, - record: Record, - ) -> RecordOutput { + record: Record<::ColumnType>, + ) -> RecordOutput<::ColumnType> { match record { Record::Statement { conditions, .. } if self.should_skip(&conditions) => { RecordOutput::Nothing } Record::Statement { conditions: _, + connection, sql, // compare result in run_async @@ -521,7 +595,18 @@ impl Runner { loc: _, } => { let sql = self.replace_keywords(sql); - let ret = self.db.run(&sql).await; + + let conn = match self.conn.get(connection).await { + Ok(conn) => conn, + Err(e) => { + return RecordOutput::Statement { + count: 0, + error: Some(Arc::new(e)), + } + } + }; + + let ret = conn.run(&sql).await; match ret { Ok(out) => match out { DBOutput::Rows { types, rows } => RecordOutput::Query { @@ -544,6 +629,7 @@ impl Runner { } Record::Query { conditions: _, + connection, sql, sort_mode, @@ -557,7 +643,19 @@ impl Runner { label: _, } => { let sql = self.replace_keywords(sql); - let (types, mut rows) = match self.db.run(&sql).await { + + let conn = match self.conn.get(connection).await { + Ok(conn) => conn, + Err(e) => { + return RecordOutput::Query { + error: Some(Arc::new(e)), + types: vec![], + rows: vec![], + } + } + }; + + let (types, mut rows) = match conn.run(&sql).await { Ok(out) => match out { DBOutput::Rows { types, rows } => (types, rows), DBOutput::StatementComplete(count) => { @@ -604,7 +702,7 @@ impl Runner { } } Record::Sleep { duration, .. } => { - D::sleep(duration).await; + ::sleep(duration).await; RecordOutput::Nothing } Record::Control(control) => match control { @@ -623,12 +721,16 @@ impl Runner { | Record::Subtest { .. } | Record::Halt { .. } | Record::Injected(_) - | Record::Condition(_) => RecordOutput::Nothing, + | Record::Condition(_) + | Record::Connection(_) => RecordOutput::Nothing, } } /// Run a single record. - pub async fn run_async(&mut self, record: Record) -> Result<(), TestError> { + pub async fn run_async( + &mut self, + record: Record<::ColumnType>, + ) -> Result<(), TestError> { tracing::debug!(?record, "testing"); match (record.clone(), self.apply_record(record).await) { @@ -656,6 +758,7 @@ impl Runner { ( Record::Statement { loc, + connection: _, conditions: _, expected_error, sql, @@ -706,6 +809,7 @@ impl Runner { Record::Query { loc, conditions: _, + connection: _, expected_types, sort_mode: _, label: _, @@ -775,7 +879,7 @@ impl Runner { } /// Run a single record. - pub fn run(&mut self, record: Record) -> Result<(), TestError> { + pub fn run(&mut self, record: Record<::ColumnType>) -> Result<(), TestError> { futures::executor::block_on(self.run_async(record)) } @@ -784,7 +888,7 @@ impl Runner { /// The runner will stop early once a halt record is seen. pub async fn run_multi_async( &mut self, - records: impl IntoIterator>, + records: impl IntoIterator::ColumnType>>, ) -> Result<(), TestError> { for record in records.into_iter() { if let Record::Halt { .. } = record { @@ -800,7 +904,7 @@ impl Runner { /// The runner will stop early once a halt record is seen. pub fn run_multi( &mut self, - records: impl IntoIterator>, + records: impl IntoIterator::ColumnType>>, ) -> Result<(), TestError> { block_on(self.run_multi_async(records)) } @@ -856,7 +960,7 @@ impl Runner { jobs: usize, ) -> Result<(), ParallelTestError> where - Fut: Future, + Fut: Future + Send + 'static, { let files = glob::glob(glob).expect("failed to read glob pattern"); let mut tasks = vec![]; @@ -872,14 +976,14 @@ impl Runner { .expect("not a UTF-8 filename"); let db_name = db_name.replace([' ', '.', '-'], "_"); - self.db - .run(&format!("CREATE DATABASE {db_name};")) + self.conn + .run_default(&format!("CREATE DATABASE {db_name};")) .await .expect("create db failed"); let target = hosts[idx % hosts.len()].clone(); tasks.push(async move { - let db = conn_builder(target, db_name).await; - let mut tester = Runner::new(db); + let mut tester = + Runner::new(move || conn_builder(target.clone(), db_name.clone()).map(Ok)); let filename = file.to_string_lossy().to_string(); tester.run_file_async(filename).await }) @@ -906,7 +1010,7 @@ impl Runner { jobs: usize, ) -> Result<(), ParallelTestError> where - Fut: Future, + Fut: Future + Send + 'static, { block_on(self.run_parallel_async(glob, hosts, conn_builder, jobs)) } @@ -939,7 +1043,7 @@ impl Runner { filename: impl AsRef, col_separator: &str, validator: Validator, - column_type_validator: ColumnTypeValidator, + column_type_validator: ColumnTypeValidator<::ColumnType>, ) -> Result<(), Box> { use std::io::{Read, Seek, SeekFrom, Write}; use std::path::PathBuf; @@ -1090,6 +1194,7 @@ pub fn update_record_with_output( sql, loc, conditions, + connection, expected_error: None, expected_count, }, @@ -1107,6 +1212,7 @@ pub fn update_record_with_output( expected_error: None, loc, conditions, + connection, expected_count, }) } @@ -1116,6 +1222,7 @@ pub fn update_record_with_output( sql, loc, conditions, + connection, .. }, RecordOutput::Statement { error: None, .. }, @@ -1124,6 +1231,7 @@ pub fn update_record_with_output( expected_error: None, loc, conditions, + connection, expected_count: None, }), // statement, statement @@ -1131,6 +1239,7 @@ pub fn update_record_with_output( Record::Statement { loc, conditions, + connection, expected_error, sql, expected_count, @@ -1143,6 +1252,7 @@ pub fn update_record_with_output( expected_error: None, loc, conditions, + connection, expected_count: expected_count.map(|_| *count), }), // Error match @@ -1152,6 +1262,7 @@ pub fn update_record_with_output( expected_error: Some(expected_error), loc, conditions, + connection, expected_count: None, }) } @@ -1161,6 +1272,7 @@ pub fn update_record_with_output( expected_error: Some(Regex::new(®ex::escape(&e.to_string())).unwrap()), loc, conditions, + connection, expected_count: None, }), }, @@ -1169,6 +1281,7 @@ pub fn update_record_with_output( Record::Query { loc, conditions, + connection, expected_types, sort_mode, label, @@ -1187,6 +1300,7 @@ pub fn update_record_with_output( expected_error: Some(expected_error), loc, conditions, + connection, expected_types: vec![], sort_mode, label, @@ -1200,6 +1314,7 @@ pub fn update_record_with_output( expected_error: Some(Regex::new(®ex::escape(&e.to_string())).unwrap()), loc, conditions, + connection, expected_types: vec![], sort_mode, label, @@ -1227,6 +1342,7 @@ pub fn update_record_with_output( expected_error: None, loc, conditions, + connection, expected_types: types, sort_mode, label, From 96beaa29f3e5a33352a7eb197ebc8ce694f65a7e Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 5 Jul 2023 14:47:18 +0800 Subject: [PATCH 02/10] loose bound Signed-off-by: Bugen Zhao --- sqllogictest-bin/src/main.rs | 2 +- sqllogictest/src/runner.rs | 65 +++++++++++++++++++----------------- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/sqllogictest-bin/src/main.rs b/sqllogictest-bin/src/main.rs index 5997eaa..4efb5e1 100644 --- a/sqllogictest-bin/src/main.rs +++ b/sqllogictest-bin/src/main.rs @@ -716,7 +716,7 @@ async fn update_test_file( async fn update_record( outfile: &mut File, runner: &mut Runner, - record: Record<::ColumnType>, + record: Record<::ColumnType>, format: bool, ) -> Result<()> { assert!(!matches!(record, Record::Injected(_))); diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index 63a51bc..0c3f287 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -9,7 +9,6 @@ use std::vec; use async_trait::async_trait; use futures::executor::block_on; -use futures::future::BoxFuture; use futures::{stream, Future, FutureExt, StreamExt}; use itertools::Itertools; use md5::Digest; @@ -51,7 +50,7 @@ pub enum DBOutput { /// The async database to be tested. #[async_trait] -pub trait AsyncDB: Send + 'static { +pub trait AsyncDB { /// The error type of SQL execution. type Error: std::error::Error + Send + Sync + 'static; /// The type of result columns @@ -76,7 +75,7 @@ pub trait AsyncDB: Send + 'static { } /// The database to be tested. -pub trait DB: Send + 'static { +pub trait DB { /// The error type of SQL execution. type Error: std::error::Error + Send + Sync + 'static; /// The type of result columns @@ -95,7 +94,7 @@ pub trait DB: Send + 'static { #[async_trait] impl AsyncDB for D where - D: DB, + D: DB + Send, { type Error = D::Error; type ColumnType = D::ColumnType; @@ -451,21 +450,23 @@ pub fn strict_column_validator(actual: &Vec, expected: &Vec .any(|(actual_column, expected_column)| actual_column != expected_column) } -pub trait MakeConnection: 'static { - type D: AsyncDB; +pub trait MakeConnection { + type Conn: AsyncDB; + type MakeFuture: Future::Error>>; - fn make(&mut self) -> BoxFuture::Error>>; + fn make(&mut self) -> Self::MakeFuture; } impl MakeConnection for F where - F: FnMut() -> Fut + 'static, - Fut: Future> + Send + 'static, + F: FnMut() -> Fut, + Fut: Future>, { - type D = D; + type Conn = D; + type MakeFuture = Fut; - fn make(&mut self) -> BoxFuture::Error>> { - self().boxed() + fn make(&mut self) -> Self::MakeFuture { + self() } } @@ -478,16 +479,17 @@ impl MakeOnce { } impl MakeConnection for MakeOnce { - type D = D; + type Conn = D; + type MakeFuture = futures::future::Ready>; - fn make(&mut self) -> BoxFuture::Error>> { - async { Ok(self.0.take().expect("MakeOnce can only be used once")) }.boxed() + fn make(&mut self) -> Self::MakeFuture { + futures::future::ready(Ok(self.0.take().expect("MakeOnce can only be used once"))) } } struct Connections { make_conn: M, - conns: HashMap, + conns: HashMap, } impl Connections { @@ -498,7 +500,7 @@ impl Connections { } } - async fn get(&mut self, name: Connection) -> Result<&mut M::D, ::Error> { + async fn get(&mut self, name: Connection) -> Result<&mut M::Conn, ::Error> { if !self.conns.contains_key(&name) { let conn = self.make_conn.make().await?; self.conns.insert(name.clone(), conn); @@ -509,7 +511,7 @@ impl Connections { async fn run_default( &mut self, sql: &str, - ) -> Result::ColumnType>, ::Error> { + ) -> Result::ColumnType>, ::Error> { self.get(Connection::Default).await?.run(sql).await } } @@ -519,7 +521,7 @@ pub struct Runner { conn: Connections, // validator is used for validate if the result of query equals to expected. validator: Validator, - column_type_validator: ColumnTypeValidator<::ColumnType>, + column_type_validator: ColumnTypeValidator<::ColumnType>, testdir: Option, sort_mode: Option, /// 0 means never hashing @@ -567,7 +569,7 @@ impl Runner { pub fn with_column_validator( &mut self, - validator: ColumnTypeValidator<::ColumnType>, + validator: ColumnTypeValidator<::ColumnType>, ) { self.column_type_validator = validator; } @@ -578,8 +580,8 @@ impl Runner { pub async fn apply_record( &mut self, - record: Record<::ColumnType>, - ) -> RecordOutput<::ColumnType> { + record: Record<::ColumnType>, + ) -> RecordOutput<::ColumnType> { match record { Record::Statement { conditions, .. } if self.should_skip(&conditions) => { RecordOutput::Nothing @@ -702,7 +704,7 @@ impl Runner { } } Record::Sleep { duration, .. } => { - ::sleep(duration).await; + ::sleep(duration).await; RecordOutput::Nothing } Record::Control(control) => match control { @@ -729,7 +731,7 @@ impl Runner { /// Run a single record. pub async fn run_async( &mut self, - record: Record<::ColumnType>, + record: Record<::ColumnType>, ) -> Result<(), TestError> { tracing::debug!(?record, "testing"); @@ -879,7 +881,10 @@ impl Runner { } /// Run a single record. - pub fn run(&mut self, record: Record<::ColumnType>) -> Result<(), TestError> { + pub fn run( + &mut self, + record: Record<::ColumnType>, + ) -> Result<(), TestError> { futures::executor::block_on(self.run_async(record)) } @@ -888,7 +893,7 @@ impl Runner { /// The runner will stop early once a halt record is seen. pub async fn run_multi_async( &mut self, - records: impl IntoIterator::ColumnType>>, + records: impl IntoIterator::ColumnType>>, ) -> Result<(), TestError> { for record in records.into_iter() { if let Record::Halt { .. } = record { @@ -904,7 +909,7 @@ impl Runner { /// The runner will stop early once a halt record is seen. pub fn run_multi( &mut self, - records: impl IntoIterator::ColumnType>>, + records: impl IntoIterator::ColumnType>>, ) -> Result<(), TestError> { block_on(self.run_multi_async(records)) } @@ -960,7 +965,7 @@ impl Runner { jobs: usize, ) -> Result<(), ParallelTestError> where - Fut: Future + Send + 'static, + Fut: Future, { let files = glob::glob(glob).expect("failed to read glob pattern"); let mut tasks = vec![]; @@ -1010,7 +1015,7 @@ impl Runner { jobs: usize, ) -> Result<(), ParallelTestError> where - Fut: Future + Send + 'static, + Fut: Future, { block_on(self.run_parallel_async(glob, hosts, conn_builder, jobs)) } @@ -1043,7 +1048,7 @@ impl Runner { filename: impl AsRef, col_separator: &str, validator: Validator, - column_type_validator: ColumnTypeValidator<::ColumnType>, + column_type_validator: ColumnTypeValidator<::ColumnType>, ) -> Result<(), Box> { use std::io::{Read, Seek, SeekFrom, Write}; use std::path::PathBuf; From dab5c377b4435e8e5cd53bcb632814b1dfe8ed86 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 5 Jul 2023 16:35:15 +0800 Subject: [PATCH 03/10] better impl and docs Signed-off-by: Bugen Zhao --- examples/basic/examples/basic.rs | 4 +- examples/condition/examples/condition.rs | 4 +- examples/custom_type/examples/custom_type.rs | 4 +- .../examples/file_level_sort_mode.rs | 4 +- examples/include/examples/include.rs | 4 +- examples/rowsort/examples/rowsort.rs | 4 +- .../examples/test_dir_escape.rs | 4 +- examples/validator/examples/validator.rs | 4 +- sqllogictest-bin/src/engines.rs | 37 ++++-- sqllogictest-bin/src/main.rs | 9 +- sqllogictest/src/connection.rs | 90 +++++++++++++ sqllogictest/src/harness.rs | 8 +- sqllogictest/src/lib.rs | 2 + sqllogictest/src/parser.rs | 16 ++- sqllogictest/src/runner.rs | 120 +++++------------- 15 files changed, 181 insertions(+), 133 deletions(-) create mode 100644 sqllogictest/src/connection.rs diff --git a/examples/basic/examples/basic.rs b/examples/basic/examples/basic.rs index 775ed74..812eb76 100644 --- a/examples/basic/examples/basic.rs +++ b/examples/basic/examples/basic.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType}; +use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; pub struct FakeDB; @@ -49,7 +49,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new_once(FakeDB); + let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/condition/examples/condition.rs b/examples/condition/examples/condition.rs index e62cd0a..ead96cb 100644 --- a/examples/condition/examples/condition.rs +++ b/examples/condition/examples/condition.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType}; +use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; pub struct FakeDB { engine_name: &'static str, @@ -43,7 +43,7 @@ impl sqllogictest::DB for FakeDB { fn main() { for engine_name in ["risinglight", "otherdb"] { - let mut tester = sqllogictest::Runner::new_once(FakeDB { engine_name }); + let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB { engine_name })); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/custom_type/examples/custom_type.rs b/examples/custom_type/examples/custom_type.rs index c8c31ab..13123fb 100644 --- a/examples/custom_type/examples/custom_type.rs +++ b/examples/custom_type/examples/custom_type.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{strict_column_validator, ColumnType, DBOutput}; +use sqllogictest::{strict_column_validator, ColumnType, DBOutput, MakeWith}; pub struct FakeDB; @@ -67,7 +67,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new_once(FakeDB); + let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); tester.with_column_validator(strict_column_validator); let mut filename = PathBuf::from(file!()); diff --git a/examples/file_level_sort_mode/examples/file_level_sort_mode.rs b/examples/file_level_sort_mode/examples/file_level_sort_mode.rs index d08e65a..ab89a33 100644 --- a/examples/file_level_sort_mode/examples/file_level_sort_mode.rs +++ b/examples/file_level_sort_mode/examples/file_level_sort_mode.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType}; +use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; pub struct FakeDB; @@ -41,7 +41,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new_once(FakeDB); + let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/include/examples/include.rs b/examples/include/examples/include.rs index b10c583..eba7393 100644 --- a/examples/include/examples/include.rs +++ b/examples/include/examples/include.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType}; +use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; pub struct FakeDB; @@ -44,7 +44,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new_once(FakeDB); + let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/rowsort/examples/rowsort.rs b/examples/rowsort/examples/rowsort.rs index b2b3302..c8c16dd 100644 --- a/examples/rowsort/examples/rowsort.rs +++ b/examples/rowsort/examples/rowsort.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType}; +use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; pub struct FakeDB; @@ -41,7 +41,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new_once(FakeDB); + let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/test_dir_escape/examples/test_dir_escape.rs b/examples/test_dir_escape/examples/test_dir_escape.rs index 6982554..0156e2e 100644 --- a/examples/test_dir_escape/examples/test_dir_escape.rs +++ b/examples/test_dir_escape/examples/test_dir_escape.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType}; +use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; pub struct FakeDB; @@ -28,7 +28,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new_once(FakeDB); + let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); // enable `__TEST_DIR__` override tester.enable_testdir(); diff --git a/examples/validator/examples/validator.rs b/examples/validator/examples/validator.rs index 6fd0d67..0347099 100644 --- a/examples/validator/examples/validator.rs +++ b/examples/validator/examples/validator.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType}; +use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; pub struct FakeDB; @@ -28,7 +28,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new_once(FakeDB); + let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); // Validator will always return true. tester.with_validator(|_, _| true); diff --git a/sqllogictest-bin/src/engines.rs b/sqllogictest-bin/src/engines.rs index 3223bbe..f14add6 100644 --- a/sqllogictest-bin/src/engines.rs +++ b/sqllogictest-bin/src/engines.rs @@ -23,7 +23,7 @@ pub enum EngineConfig { External(String), } -enum Engines { +pub(crate) enum Engines { Postgres(PostgresSimple), PostgresExtended(PostgresExtended), External(ExternalDriver), @@ -48,12 +48,21 @@ impl From<&DBConfig> for PostgresConfig { } } -pub(super) async fn connect(engine: &EngineConfig, config: &DBConfig) -> Result { +pub(crate) async fn connect( + engine: &EngineConfig, + config: &DBConfig, +) -> Result { Ok(match engine { - EngineConfig::Postgres => Engines::Postgres(PostgresSimple::connect(config.into()).await?), - EngineConfig::PostgresExtended => { - Engines::PostgresExtended(PostgresExtended::connect(config.into()).await?) - } + EngineConfig::Postgres => Engines::Postgres( + PostgresSimple::connect(config.into()) + .await + .map_err(|e| EnginesError(e.into()))?, + ), + EngineConfig::PostgresExtended => Engines::PostgresExtended( + PostgresExtended::connect(config.into()) + .await + .map_err(|e| EnginesError(e.into()))?, + ), EngineConfig::External(cmd_tmpl) => { let (host, port) = config.random_addr(); let cmd_str = cmd_tmpl @@ -64,21 +73,25 @@ pub(super) async fn connect(engine: &EngineConfig, config: &DBConfig) -> Result< .replace("{pass}", &config.pass); let mut cmd = Command::new("bash"); cmd.args(["-c", &cmd_str]); - Engines::External(ExternalDriver::connect(cmd).await?) + Engines::External( + ExternalDriver::connect(cmd) + .await + .map_err(|e| EnginesError(e.into()))?, + ) } }) } #[derive(Debug)] -struct AnyhowError(anyhow::Error); +pub(crate) struct EnginesError(anyhow::Error); -impl Display for AnyhowError { +impl Display for EnginesError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } -impl std::error::Error for AnyhowError { +impl std::error::Error for EnginesError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { self.0.source() } @@ -96,10 +109,10 @@ impl Engines { #[async_trait] impl AsyncDB for Engines { - type Error = AnyhowError; + type Error = EnginesError; type ColumnType = DefaultColumnType; async fn run(&mut self, sql: &str) -> Result, Self::Error> { - self.run(sql).await.map_err(AnyhowError) + self.run(sql).await.map_err(EnginesError) } } diff --git a/sqllogictest-bin/src/main.rs b/sqllogictest-bin/src/main.rs index 4efb5e1..aacbd64 100644 --- a/sqllogictest-bin/src/main.rs +++ b/sqllogictest-bin/src/main.rs @@ -369,8 +369,7 @@ async fn run_serial( let mut failed_case = vec![]; for file in files { - let engine = engines::connect(engine, &config).await?; - let mut runner = Runner::new_once(engine); + let mut runner = Runner::new(|| engines::connect(engine, &config)); for label in labels { runner.add_label(label); } @@ -418,8 +417,7 @@ async fn update_test_files( format: bool, ) -> Result<()> { for file in files { - let engine = engines::connect(engine, &config).await?; - let runner = Runner::new_once(engine); + let runner = Runner::new(|| engines::connect(engine, &config)); if let Err(e) = update_test_file(&mut std::io::stdout(), runner, &file, format).await { { @@ -443,8 +441,7 @@ async fn connect_and_run_test_file( config: DBConfig, labels: &[String], ) -> Result { - let engine = engines::connect(engine, &config).await?; - let mut runner = Runner::new_once(engine); + let mut runner = Runner::new(|| engines::connect(engine, &config)); for label in labels { runner.add_label(label); } diff --git a/sqllogictest/src/connection.rs b/sqllogictest/src/connection.rs new file mode 100644 index 0000000..99aeb07 --- /dev/null +++ b/sqllogictest/src/connection.rs @@ -0,0 +1,90 @@ +use std::collections::HashMap; +use std::future::IntoFuture; + +use futures::Future; + +use crate::{AsyncDB, Connection as ConnectionName, DBOutput}; + +/// Trait for making connections to an [`AsyncDB`]. +/// +/// This is introduced to allow querying the database with different connections +/// (then generally different sessions) in a single test file with `connection` records. +pub trait MakeConnection { + /// The database type. + type Conn: AsyncDB; + /// The future returned by [`MakeConnection::make`]. + type MakeFuture: Future::Error>>; + + /// Creates a new connection to the database. + fn make(&mut self) -> Self::MakeFuture; +} + +/// Make connections directly from a closure returning a future. +impl MakeConnection for F +where + F: FnMut() -> Fut, + Fut: IntoFuture>, +{ + type Conn = D; + type MakeFuture = Fut::IntoFuture; + + fn make(&mut self) -> Self::MakeFuture { + self().into_future() + } +} + +/// Make connections with a synchronous infallible function. +#[derive(Debug)] +pub struct MakeWith(pub F); + +impl D, D: AsyncDB> MakeConnection for MakeWith { + type Conn = D; + type MakeFuture = futures::future::Ready>; + + fn make(&mut self) -> Self::MakeFuture { + futures::future::ready(Ok((self.0)())) + } +} + +/// Connections established in a [`Runner`](crate::Runner). +pub(crate) struct Connections { + make_conn: M, + conns: HashMap, +} + +impl Connections { + pub fn new(make_conn: M) -> Self { + Connections { + make_conn, + conns: HashMap::new(), + } + } + + /// Get a connection by name. Make a new connection if it doesn't exist. + pub async fn get( + &mut self, + name: ConnectionName, + ) -> Result<&mut M::Conn, ::Error> { + use std::collections::hash_map::Entry; + + let conn = match self.conns.entry(name) { + Entry::Occupied(o) => o.into_mut(), + Entry::Vacant(v) => { + let conn = self.make_conn.make().await?; + v.insert(conn) + } + }; + + Ok(conn) + } + + /// Run a SQL statement on the default connection. + /// + /// This is a shortcut for calling `get(Default)` then `run`. + pub async fn run_default( + &mut self, + sql: &str, + ) -> Result::ColumnType>, ::Error> { + self.get(ConnectionName::Default).await?.run(sql).await + } +} diff --git a/sqllogictest/src/harness.rs b/sqllogictest/src/harness.rs index e52aa05..8d0ba65 100644 --- a/sqllogictest/src/harness.rs +++ b/sqllogictest/src/harness.rs @@ -3,7 +3,7 @@ use std::path::Path; pub use glob::glob; pub use libtest_mimic::{run, Arguments, Failed, Trial}; -use crate::{AsyncDB, MakeOnce, Runner}; +use crate::{MakeConnection, Runner}; /// * `db_fn`: `fn() -> sqllogictest::AsyncDB` /// * `pattern`: The glob used to match against and select each file to be tested. It is relative to @@ -19,7 +19,7 @@ macro_rules! harness { let path = entry.expect("failed to read glob entry"); tests.push($crate::harness::Trial::test( path.to_str().unwrap().to_string(), - move || $crate::harness::test(&path, $db_fn()), + move || $crate::harness::test(&path, $crate::MakeWith($db_fn)), )); } @@ -32,8 +32,8 @@ macro_rules! harness { }; } -pub fn test(filename: impl AsRef, db: impl AsyncDB) -> Result<(), Failed> { - let mut tester = Runner::new(MakeOnce::new(db)); +pub fn test(filename: impl AsRef, make_conn: impl MakeConnection) -> Result<(), Failed> { + let mut tester = Runner::new(make_conn); tester.run_file(filename)?; Ok(()) } diff --git a/sqllogictest/src/lib.rs b/sqllogictest/src/lib.rs index 1cc5441..ac80309 100644 --- a/sqllogictest/src/lib.rs +++ b/sqllogictest/src/lib.rs @@ -35,10 +35,12 @@ //! ``` pub mod column_type; +pub mod connection; pub mod parser; pub mod runner; pub use self::column_type::*; +pub use self::connection::*; pub use self::parser::*; pub use self::runner::*; diff --git a/sqllogictest/src/parser.rs b/sqllogictest/src/parser.rs index 2b91448..23077fe 100644 --- a/sqllogictest/src/parser.rs +++ b/sqllogictest/src/parser.rs @@ -1,6 +1,5 @@ //! Sqllogictest parser. -use std::collections::HashSet; use std::fmt; use std::path::Path; use std::sync::Arc; @@ -137,7 +136,9 @@ pub enum Record { loc: Location, threshold: u64, }, + /// Condition statements, including `onlyif` and `skipif`. Condition(Condition), + /// Connection statements to specify the connection to use for the following statement. Connection(Connection), Comment(Vec), Newline, @@ -304,18 +305,21 @@ pub enum Condition { impl Condition { /// Evaluate condition on given `label`, returns whether to skip this record. - pub(crate) fn should_skip(&self, labels: &HashSet) -> bool { + pub(crate) fn should_skip<'a>(&'a self, labels: impl IntoIterator) -> bool { match self { - Condition::OnlyIf { label } => !labels.contains(label), - Condition::SkipIf { label } => labels.contains(label), + Condition::OnlyIf { label } => !labels.into_iter().contains(&label.as_str()), + Condition::SkipIf { label } => labels.into_iter().contains(&label.as_str()), } } } +/// The connection to use for the following statement. #[derive(Default, Debug, PartialEq, Eq, Hash, Clone)] pub enum Connection { + /// The default connection if not specified or if the name is "default". #[default] Default, + /// A named connection. Named(String), } @@ -495,8 +499,8 @@ fn parse_inner(loc: &Location, script: &str) -> Result { - let conn = Connection::new(label); + ["connection", name] => { + let conn = Connection::new(name); connection = conn.clone(); records.push(Record::Connection(conn)); } diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index 0c3f287..afc88ae 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -1,6 +1,6 @@ //! Sqllogictest runner. -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::fmt::{Debug, Display}; use std::path::Path; use std::sync::Arc; @@ -18,7 +18,7 @@ use similar::{Change, ChangeTag, TextDiff}; use tempfile::{tempdir, TempDir}; use crate::parser::*; -use crate::ColumnType; +use crate::{ColumnType, Connections, MakeConnection}; #[derive(Debug, Clone)] pub enum RecordOutput { @@ -450,72 +450,6 @@ pub fn strict_column_validator(actual: &Vec, expected: &Vec .any(|(actual_column, expected_column)| actual_column != expected_column) } -pub trait MakeConnection { - type Conn: AsyncDB; - type MakeFuture: Future::Error>>; - - fn make(&mut self) -> Self::MakeFuture; -} - -impl MakeConnection for F -where - F: FnMut() -> Fut, - Fut: Future>, -{ - type Conn = D; - type MakeFuture = Fut; - - fn make(&mut self) -> Self::MakeFuture { - self() - } -} - -pub struct MakeOnce(Option); - -impl MakeOnce { - pub fn new(db: D) -> Self { - MakeOnce(Some(db)) - } -} - -impl MakeConnection for MakeOnce { - type Conn = D; - type MakeFuture = futures::future::Ready>; - - fn make(&mut self) -> Self::MakeFuture { - futures::future::ready(Ok(self.0.take().expect("MakeOnce can only be used once"))) - } -} - -struct Connections { - make_conn: M, - conns: HashMap, -} - -impl Connections { - fn new(make_conn: M) -> Self { - Connections { - make_conn, - conns: HashMap::new(), - } - } - - async fn get(&mut self, name: Connection) -> Result<&mut M::Conn, ::Error> { - if !self.conns.contains_key(&name) { - let conn = self.make_conn.make().await?; - self.conns.insert(name.clone(), conn); - } - Ok(self.conns.get_mut(&name).unwrap()) - } - - async fn run_default( - &mut self, - sql: &str, - ) -> Result::ColumnType>, ::Error> { - self.get(Connection::Default).await?.run(sql).await - } -} - /// Sqllogictest runner. pub struct Runner { conn: Connections, @@ -530,14 +464,10 @@ pub struct Runner { labels: HashSet, } -impl Runner> { - pub fn new_once(db: D) -> Self { - Self::new(MakeOnce::new(db)) - } -} - impl Runner { - /// Create a new test runner on the database. + /// Create a new test runner on the database, with the given connection maker. + /// + /// See [`MakeConnection`] for more details. pub fn new(make_conn: M) -> Self { Runner { validator: default_validator, @@ -545,7 +475,7 @@ impl Runner { testdir: None, sort_mode: None, hash_threshold: 0, - labels: HashSet::new(), // TODO + labels: HashSet::new(), conn: Connections::new(make_conn), } } @@ -582,12 +512,26 @@ impl Runner { &mut self, record: Record<::ColumnType>, ) -> RecordOutput<::ColumnType> { + /// Returns whether we should skip this record, according to given `conditions`. + fn should_skip( + labels: &HashSet, + engine_name: &str, + conditions: &[Condition], + ) -> bool { + conditions.iter().any(|c| { + c.should_skip( + labels + .iter() + .map(|l| l.as_str()) + // attach the engine name to the labels + .chain(Some(engine_name).filter(|n| !n.is_empty())), + ) + }) + } + match record { - Record::Statement { conditions, .. } if self.should_skip(&conditions) => { - RecordOutput::Nothing - } Record::Statement { - conditions: _, + conditions, connection, sql, @@ -607,6 +551,9 @@ impl Runner { } } }; + if should_skip(&self.labels, conn.engine_name(), &conditions) { + return RecordOutput::Nothing; + } let ret = conn.run(&sql).await; match ret { @@ -626,11 +573,8 @@ impl Runner { }, } } - Record::Query { conditions, .. } if self.should_skip(&conditions) => { - RecordOutput::Nothing - } Record::Query { - conditions: _, + conditions, connection, sql, sort_mode, @@ -656,6 +600,9 @@ impl Runner { } } }; + if should_skip(&self.labels, conn.engine_name(), &conditions) { + return RecordOutput::Nothing; + } let (types, mut rows) = match conn.run(&sql).await { Ok(out) => match out { @@ -1029,11 +976,6 @@ impl Runner { } } - /// Returns whether we should skip this record, according to given `conditions`. - fn should_skip(&self, conditions: &[Condition]) -> bool { - conditions.iter().any(|c| c.should_skip(&self.labels)) - } - /// Updates a test file with the output produced by a Database. It is an utility function /// wrapping [`update_test_file_with_runner`]. /// From 05eb08330af37d1e4af2c8fdcbdea6d732a11ed2 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 5 Jul 2023 16:44:36 +0800 Subject: [PATCH 04/10] add example Signed-off-by: Bugen Zhao --- Cargo.lock | 7 +++ examples/connection/Cargo.toml | 8 ++++ examples/connection/connection.slt | 38 ++++++++++++++++ examples/connection/examples/connection.rs | 53 ++++++++++++++++++++++ 4 files changed, 106 insertions(+) create mode 100644 examples/connection/Cargo.toml create mode 100644 examples/connection/connection.slt create mode 100644 examples/connection/examples/connection.rs diff --git a/Cargo.lock b/Cargo.lock index fb46f1f..52b98ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -314,6 +314,13 @@ dependencies = [ "sqllogictest", ] +[[package]] +name = "connection" +version = "0.1.0" +dependencies = [ + "sqllogictest", +] + [[package]] name = "console" version = "0.15.5" diff --git a/examples/connection/Cargo.toml b/examples/connection/Cargo.toml new file mode 100644 index 0000000..a6d7eed --- /dev/null +++ b/examples/connection/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "connection" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +sqllogictest = { path = "../../sqllogictest" } diff --git a/examples/connection/connection.slt b/examples/connection/connection.slt new file mode 100644 index 0000000..b6c12c4 --- /dev/null +++ b/examples/connection/connection.slt @@ -0,0 +1,38 @@ +query I +select counter() +---- +1 + +query I +select counter() +---- +2 + +connection another +query I +select counter() +---- +1 + +connection default +query I +select counter() +---- +3 + +connection another +query I +select counter() +---- +2 + +connection AnOtHeR +query I +select counter() +---- +1 + +query I +select counter() +---- +4 diff --git a/examples/connection/examples/connection.rs b/examples/connection/examples/connection.rs new file mode 100644 index 0000000..bd969d4 --- /dev/null +++ b/examples/connection/examples/connection.rs @@ -0,0 +1,53 @@ +use std::path::PathBuf; + +use sqllogictest::{DBOutput, DefaultColumnType}; + +pub struct FakeDB { + counter: u64, +} + +impl FakeDB { + #[allow(clippy::unused_async)] + async fn connect() -> Result { + Ok(Self { counter: 0 }) + } +} + +#[derive(Debug)] +pub struct FakeDBError; + +impl std::fmt::Display for FakeDBError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl std::error::Error for FakeDBError {} + +impl sqllogictest::DB for FakeDB { + type Error = FakeDBError; + type ColumnType = DefaultColumnType; + + fn run(&mut self, sql: &str) -> Result, FakeDBError> { + if sql == "select counter()" { + self.counter += 1; + Ok(DBOutput::Rows { + types: vec![DefaultColumnType::Integer], + rows: vec![vec![self.counter.to_string()]], + }) + } else { + Err(FakeDBError) + } + } +} + +fn main() { + let mut tester = sqllogictest::Runner::new(FakeDB::connect); + + let mut filename = PathBuf::from(file!()); + filename.pop(); + filename.pop(); + filename.push("connection.slt"); + + tester.run_file(filename).unwrap(); +} From 48925cb4ac412d03203cf2606166c69d90493adb Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 5 Jul 2023 16:52:55 +0800 Subject: [PATCH 05/10] add more comments Signed-off-by: Bugen Zhao --- examples/connection/connection.slt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/connection/connection.slt b/examples/connection/connection.slt index b6c12c4..d1010c4 100644 --- a/examples/connection/connection.slt +++ b/examples/connection/connection.slt @@ -14,6 +14,7 @@ select counter() ---- 1 +# `default` is the name of the default connection if not specified connection default query I select counter() @@ -26,6 +27,7 @@ select counter() ---- 2 +# connection names are case sensitive connection AnOtHeR query I select counter() From d9017b3863ec945125527a84777b918682b8d5ce Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Thu, 6 Jul 2023 17:34:03 +0800 Subject: [PATCH 06/10] add changelog and bump version Signed-off-by: Bugen Zhao --- CHANGELOG.md | 9 ++++++++- Cargo.lock | 6 +++--- Cargo.toml | 2 +- sqllogictest-bin/Cargo.toml | 4 ++-- sqllogictest-engines/Cargo.toml | 2 +- 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e5df2db..e0b9bb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,16 @@ All notable changes to this project will be documented in this file. -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.15.0] - 2023-07-06 + +* Allow multiple connections to the database in a single test case, which is useful for testing the transaction behavior. This can be achieved by attaching a `connection foo` record before the query or statement. + - (parser) Add `Record::Connection`. + - (runner) **Breaking change**: Since the runner may establish multiple connections at runtime, `Runner::new` now takes a `impl MakeConnection`, which is usually a closure that returns a try-future of the `AsyncDB` instance. + - (bin) The connection to the database is now established lazily on the first query or statement. + ## [0.14.0] - 2023-06-08 * We enhanced how `skipif` and `onlyif` works. Previously it checks against `DB::engine_name()`, and `sqllogictest-bin` didn't implement it. diff --git a/Cargo.lock b/Cargo.lock index 52b98ce..01915fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1431,7 +1431,7 @@ dependencies = [ [[package]] name = "sqllogictest" -version = "0.14.0" +version = "0.15.0" dependencies = [ "async-trait", "educe", @@ -1452,7 +1452,7 @@ dependencies = [ [[package]] name = "sqllogictest-bin" -version = "0.14.0" +version = "0.15.0" dependencies = [ "anyhow", "async-trait", @@ -1473,7 +1473,7 @@ dependencies = [ [[package]] name = "sqllogictest-engines" -version = "0.14.0" +version = "0.15.0" dependencies = [ "async-trait", "bytes", diff --git a/Cargo.toml b/Cargo.toml index b5e1aaf..2ac12e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ members = ["examples/*", "sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"] [workspace.package] -version = "0.14.0" +version = "0.15.0" edition = "2021" homepage = "https://github.com/risinglightdb/sqllogictest-rs" keywords = ["sql", "database", "parser", "cli"] diff --git a/sqllogictest-bin/Cargo.toml b/sqllogictest-bin/Cargo.toml index d3457de..9cc09d7 100644 --- a/sqllogictest-bin/Cargo.toml +++ b/sqllogictest-bin/Cargo.toml @@ -24,8 +24,8 @@ glob = "0.3" itertools = "0.10" quick-junit = { version = "0.2" } rand = "0.8" -sqllogictest = { path = "../sqllogictest", version = "0.14" } -sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.14" } +sqllogictest = { path = "../sqllogictest" } +sqllogictest-engines = { path = "../sqllogictest-engines" } tokio = { version = "1", features = [ "rt", "rt-multi-thread", diff --git a/sqllogictest-engines/Cargo.toml b/sqllogictest-engines/Cargo.toml index 846c0f6..4d1b048 100644 --- a/sqllogictest-engines/Cargo.toml +++ b/sqllogictest-engines/Cargo.toml @@ -19,7 +19,7 @@ postgres-types = { version = "0.2.3", features = ["derive", "with-chrono-0_4"] } rust_decimal = { version = "1.7.0", features = ["tokio-pg"] } serde = { version = "1", features = ["derive"] } serde_json = "1" -sqllogictest = { path = "../sqllogictest", version = "0.14" } +sqllogictest = { path = "../sqllogictest" } thiserror = "1" tokio = { version = "1", features = [ "rt", From 5d9e210336e1d47e67de31f337b4002b4926af0d Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 11 Jul 2023 13:24:56 +0800 Subject: [PATCH 07/10] fix dependency version Signed-off-by: Bugen Zhao --- sqllogictest-bin/Cargo.toml | 4 ++-- sqllogictest-engines/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sqllogictest-bin/Cargo.toml b/sqllogictest-bin/Cargo.toml index 9cc09d7..5d60060 100644 --- a/sqllogictest-bin/Cargo.toml +++ b/sqllogictest-bin/Cargo.toml @@ -24,8 +24,8 @@ glob = "0.3" itertools = "0.10" quick-junit = { version = "0.2" } rand = "0.8" -sqllogictest = { path = "../sqllogictest" } -sqllogictest-engines = { path = "../sqllogictest-engines" } +sqllogictest = { path = "../sqllogictest", version = "0.15" } +sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.15" } tokio = { version = "1", features = [ "rt", "rt-multi-thread", diff --git a/sqllogictest-engines/Cargo.toml b/sqllogictest-engines/Cargo.toml index 4d1b048..a74d44f 100644 --- a/sqllogictest-engines/Cargo.toml +++ b/sqllogictest-engines/Cargo.toml @@ -19,7 +19,7 @@ postgres-types = { version = "0.2.3", features = ["derive", "with-chrono-0_4"] } rust_decimal = { version = "1.7.0", features = ["tokio-pg"] } serde = { version = "1", features = ["derive"] } serde_json = "1" -sqllogictest = { path = "../sqllogictest" } +sqllogictest = { path = "../sqllogictest", version = "0.15" } thiserror = "1" tokio = { version = "1", features = [ "rt", From 8ce8d26f6a5f61671720a48a2e35f27e1559303b Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 11 Jul 2023 13:28:27 +0800 Subject: [PATCH 08/10] remove MakeWith Signed-off-by: Bugen Zhao --- examples/basic/examples/basic.rs | 4 ++-- examples/condition/examples/condition.rs | 4 ++-- examples/custom_type/examples/custom_type.rs | 4 ++-- .../examples/file_level_sort_mode.rs | 4 ++-- examples/include/examples/include.rs | 4 ++-- examples/rowsort/examples/rowsort.rs | 4 ++-- .../test_dir_escape/examples/test_dir_escape.rs | 4 ++-- examples/validator/examples/validator.rs | 4 ++-- sqllogictest/src/connection.rs | 13 ------------- sqllogictest/src/harness.rs | 2 +- 10 files changed, 17 insertions(+), 30 deletions(-) diff --git a/examples/basic/examples/basic.rs b/examples/basic/examples/basic.rs index 812eb76..389e59c 100644 --- a/examples/basic/examples/basic.rs +++ b/examples/basic/examples/basic.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; +use sqllogictest::{DBOutput, DefaultColumnType}; pub struct FakeDB; @@ -49,7 +49,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); + let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) }); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/condition/examples/condition.rs b/examples/condition/examples/condition.rs index ead96cb..8a0d107 100644 --- a/examples/condition/examples/condition.rs +++ b/examples/condition/examples/condition.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; +use sqllogictest::{DBOutput, DefaultColumnType}; pub struct FakeDB { engine_name: &'static str, @@ -43,7 +43,7 @@ impl sqllogictest::DB for FakeDB { fn main() { for engine_name in ["risinglight", "otherdb"] { - let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB { engine_name })); + let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB { engine_name }) }); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/custom_type/examples/custom_type.rs b/examples/custom_type/examples/custom_type.rs index 13123fb..ac9b14f 100644 --- a/examples/custom_type/examples/custom_type.rs +++ b/examples/custom_type/examples/custom_type.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{strict_column_validator, ColumnType, DBOutput, MakeWith}; +use sqllogictest::{strict_column_validator, ColumnType, DBOutput}; pub struct FakeDB; @@ -67,7 +67,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); + let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) }); tester.with_column_validator(strict_column_validator); let mut filename = PathBuf::from(file!()); diff --git a/examples/file_level_sort_mode/examples/file_level_sort_mode.rs b/examples/file_level_sort_mode/examples/file_level_sort_mode.rs index ab89a33..4a2a66f 100644 --- a/examples/file_level_sort_mode/examples/file_level_sort_mode.rs +++ b/examples/file_level_sort_mode/examples/file_level_sort_mode.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; +use sqllogictest::{DBOutput, DefaultColumnType}; pub struct FakeDB; @@ -41,7 +41,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); + let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) }); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/include/examples/include.rs b/examples/include/examples/include.rs index eba7393..4ab676b 100644 --- a/examples/include/examples/include.rs +++ b/examples/include/examples/include.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; +use sqllogictest::{DBOutput, DefaultColumnType}; pub struct FakeDB; @@ -44,7 +44,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); + let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) }); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/rowsort/examples/rowsort.rs b/examples/rowsort/examples/rowsort.rs index c8c16dd..baf9e2c 100644 --- a/examples/rowsort/examples/rowsort.rs +++ b/examples/rowsort/examples/rowsort.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; +use sqllogictest::{DBOutput, DefaultColumnType}; pub struct FakeDB; @@ -41,7 +41,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); + let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) }); let mut filename = PathBuf::from(file!()); filename.pop(); diff --git a/examples/test_dir_escape/examples/test_dir_escape.rs b/examples/test_dir_escape/examples/test_dir_escape.rs index 0156e2e..8a2a81c 100644 --- a/examples/test_dir_escape/examples/test_dir_escape.rs +++ b/examples/test_dir_escape/examples/test_dir_escape.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; +use sqllogictest::{DBOutput, DefaultColumnType}; pub struct FakeDB; @@ -28,7 +28,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); + let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) }); // enable `__TEST_DIR__` override tester.enable_testdir(); diff --git a/examples/validator/examples/validator.rs b/examples/validator/examples/validator.rs index 0347099..e8235a0 100644 --- a/examples/validator/examples/validator.rs +++ b/examples/validator/examples/validator.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use sqllogictest::{DBOutput, DefaultColumnType, MakeWith}; +use sqllogictest::{DBOutput, DefaultColumnType}; pub struct FakeDB; @@ -28,7 +28,7 @@ impl sqllogictest::DB for FakeDB { } fn main() { - let mut tester = sqllogictest::Runner::new(MakeWith(|| FakeDB)); + let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) }); // Validator will always return true. tester.with_validator(|_, _| true); diff --git a/sqllogictest/src/connection.rs b/sqllogictest/src/connection.rs index 99aeb07..9d794df 100644 --- a/sqllogictest/src/connection.rs +++ b/sqllogictest/src/connection.rs @@ -33,19 +33,6 @@ where } } -/// Make connections with a synchronous infallible function. -#[derive(Debug)] -pub struct MakeWith(pub F); - -impl D, D: AsyncDB> MakeConnection for MakeWith { - type Conn = D; - type MakeFuture = futures::future::Ready>; - - fn make(&mut self) -> Self::MakeFuture { - futures::future::ready(Ok((self.0)())) - } -} - /// Connections established in a [`Runner`](crate::Runner). pub(crate) struct Connections { make_conn: M, diff --git a/sqllogictest/src/harness.rs b/sqllogictest/src/harness.rs index 8d0ba65..4875278 100644 --- a/sqllogictest/src/harness.rs +++ b/sqllogictest/src/harness.rs @@ -19,7 +19,7 @@ macro_rules! harness { let path = entry.expect("failed to read glob entry"); tests.push($crate::harness::Trial::test( path.to_str().unwrap().to_string(), - move || $crate::harness::test(&path, $crate::MakeWith($db_fn)), + move || $crate::harness::test(&path, || async { Ok($db_fn()) }), )); } From ed600a8b8f8329babbb9d4c3759fcd4b7082b74a Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 11 Jul 2023 13:29:21 +0800 Subject: [PATCH 09/10] refine test Signed-off-by: Bugen Zhao --- examples/connection/connection.slt | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/connection/connection.slt b/examples/connection/connection.slt index d1010c4..0dbeae0 100644 --- a/examples/connection/connection.slt +++ b/examples/connection/connection.slt @@ -34,6 +34,7 @@ select counter() ---- 1 +# connection only works for one record, the next one will use `default` query I select counter() ---- From ffd63b6fc8329328bf143cb7a7ab7d1a2ae23d23 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 11 Jul 2023 13:39:35 +0800 Subject: [PATCH 10/10] refactor generic Signed-off-by: Bugen Zhao --- sqllogictest-bin/src/main.rs | 6 +++--- sqllogictest/src/connection.rs | 16 +++++--------- sqllogictest/src/runner.rs | 39 +++++++++++++--------------------- 3 files changed, 23 insertions(+), 38 deletions(-) diff --git a/sqllogictest-bin/src/main.rs b/sqllogictest-bin/src/main.rs index aacbd64..df5d0ea 100644 --- a/sqllogictest-bin/src/main.rs +++ b/sqllogictest-bin/src/main.rs @@ -454,7 +454,7 @@ async fn connect_and_run_test_file( /// information. async fn run_test_file( out: &mut T, - mut runner: Runner, + mut runner: Runner, filename: impl AsRef, ) -> Result { let filename = filename.as_ref(); @@ -557,7 +557,7 @@ fn finish_test_file( /// progress information. async fn update_test_file( out: &mut T, - mut runner: Runner, + mut runner: Runner, filename: impl AsRef, format: bool, ) -> Result<()> { @@ -712,7 +712,7 @@ async fn update_test_file( async fn update_record( outfile: &mut File, - runner: &mut Runner, + runner: &mut Runner, record: Record<::ColumnType>, format: bool, ) -> Result<()> { diff --git a/sqllogictest/src/connection.rs b/sqllogictest/src/connection.rs index 9d794df..22bc6eb 100644 --- a/sqllogictest/src/connection.rs +++ b/sqllogictest/src/connection.rs @@ -34,12 +34,12 @@ where } /// Connections established in a [`Runner`](crate::Runner). -pub(crate) struct Connections { +pub(crate) struct Connections { make_conn: M, - conns: HashMap, + conns: HashMap, } -impl Connections { +impl> Connections { pub fn new(make_conn: M) -> Self { Connections { make_conn, @@ -48,10 +48,7 @@ impl Connections { } /// Get a connection by name. Make a new connection if it doesn't exist. - pub async fn get( - &mut self, - name: ConnectionName, - ) -> Result<&mut M::Conn, ::Error> { + pub async fn get(&mut self, name: ConnectionName) -> Result<&mut D, D::Error> { use std::collections::hash_map::Entry; let conn = match self.conns.entry(name) { @@ -68,10 +65,7 @@ impl Connections { /// Run a SQL statement on the default connection. /// /// This is a shortcut for calling `get(Default)` then `run`. - pub async fn run_default( - &mut self, - sql: &str, - ) -> Result::ColumnType>, ::Error> { + pub async fn run_default(&mut self, sql: &str) -> Result, D::Error> { self.get(ConnectionName::Default).await?.run(sql).await } } diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index afc88ae..ca2e88c 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -451,11 +451,11 @@ pub fn strict_column_validator(actual: &Vec, expected: &Vec } /// Sqllogictest runner. -pub struct Runner { - conn: Connections, +pub struct Runner { + conn: Connections, // validator is used for validate if the result of query equals to expected. validator: Validator, - column_type_validator: ColumnTypeValidator<::ColumnType>, + column_type_validator: ColumnTypeValidator, testdir: Option, sort_mode: Option, /// 0 means never hashing @@ -464,7 +464,7 @@ pub struct Runner { labels: HashSet, } -impl Runner { +impl> Runner { /// Create a new test runner on the database, with the given connection maker. /// /// See [`MakeConnection`] for more details. @@ -497,10 +497,7 @@ impl Runner { self.validator = validator; } - pub fn with_column_validator( - &mut self, - validator: ColumnTypeValidator<::ColumnType>, - ) { + pub fn with_column_validator(&mut self, validator: ColumnTypeValidator) { self.column_type_validator = validator; } @@ -510,8 +507,8 @@ impl Runner { pub async fn apply_record( &mut self, - record: Record<::ColumnType>, - ) -> RecordOutput<::ColumnType> { + record: Record, + ) -> RecordOutput { /// Returns whether we should skip this record, according to given `conditions`. fn should_skip( labels: &HashSet, @@ -651,7 +648,7 @@ impl Runner { } } Record::Sleep { duration, .. } => { - ::sleep(duration).await; + D::sleep(duration).await; RecordOutput::Nothing } Record::Control(control) => match control { @@ -676,10 +673,7 @@ impl Runner { } /// Run a single record. - pub async fn run_async( - &mut self, - record: Record<::ColumnType>, - ) -> Result<(), TestError> { + pub async fn run_async(&mut self, record: Record) -> Result<(), TestError> { tracing::debug!(?record, "testing"); match (record.clone(), self.apply_record(record).await) { @@ -828,10 +822,7 @@ impl Runner { } /// Run a single record. - pub fn run( - &mut self, - record: Record<::ColumnType>, - ) -> Result<(), TestError> { + pub fn run(&mut self, record: Record) -> Result<(), TestError> { futures::executor::block_on(self.run_async(record)) } @@ -840,7 +831,7 @@ impl Runner { /// The runner will stop early once a halt record is seen. pub async fn run_multi_async( &mut self, - records: impl IntoIterator::ColumnType>>, + records: impl IntoIterator>, ) -> Result<(), TestError> { for record in records.into_iter() { if let Record::Halt { .. } = record { @@ -856,7 +847,7 @@ impl Runner { /// The runner will stop early once a halt record is seen. pub fn run_multi( &mut self, - records: impl IntoIterator::ColumnType>>, + records: impl IntoIterator>, ) -> Result<(), TestError> { block_on(self.run_multi_async(records)) } @@ -912,7 +903,7 @@ impl Runner { jobs: usize, ) -> Result<(), ParallelTestError> where - Fut: Future, + Fut: Future, { let files = glob::glob(glob).expect("failed to read glob pattern"); let mut tasks = vec![]; @@ -962,7 +953,7 @@ impl Runner { jobs: usize, ) -> Result<(), ParallelTestError> where - Fut: Future, + Fut: Future, { block_on(self.run_parallel_async(glob, hosts, conn_builder, jobs)) } @@ -990,7 +981,7 @@ impl Runner { filename: impl AsRef, col_separator: &str, validator: Validator, - column_type_validator: ColumnTypeValidator<::ColumnType>, + column_type_validator: ColumnTypeValidator, ) -> Result<(), Box> { use std::io::{Read, Seek, SeekFrom, Write}; use std::path::PathBuf;