Skip to content

Commit

Permalink
refactor generic
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <i@bugenzhao.com>
  • Loading branch information
BugenZhao committed Jul 11, 2023
1 parent ed600a8 commit ffd63b6
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 38 deletions.
6 changes: 3 additions & 3 deletions sqllogictest-bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ async fn connect_and_run_test_file(
/// information.
async fn run_test_file<T: std::io::Write, M: MakeConnection>(
out: &mut T,
mut runner: Runner<M>,
mut runner: Runner<M::Conn, M>,
filename: impl AsRef<Path>,
) -> Result<Duration> {
let filename = filename.as_ref();
Expand Down Expand Up @@ -557,7 +557,7 @@ fn finish_test_file<T: std::io::Write>(
/// progress information.
async fn update_test_file<T: std::io::Write, M: MakeConnection>(
out: &mut T,
mut runner: Runner<M>,
mut runner: Runner<M::Conn, M>,
filename: impl AsRef<Path>,
format: bool,
) -> Result<()> {
Expand Down Expand Up @@ -712,7 +712,7 @@ async fn update_test_file<T: std::io::Write, M: MakeConnection>(

async fn update_record<M: MakeConnection>(
outfile: &mut File,
runner: &mut Runner<M>,
runner: &mut Runner<M::Conn, M>,
record: Record<<M::Conn as AsyncDB>::ColumnType>,
format: bool,
) -> Result<()> {
Expand Down
16 changes: 5 additions & 11 deletions sqllogictest/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ where
}

/// Connections established in a [`Runner`](crate::Runner).
pub(crate) struct Connections<M: MakeConnection> {
pub(crate) struct Connections<D, M> {
make_conn: M,
conns: HashMap<ConnectionName, M::Conn>,
conns: HashMap<ConnectionName, D>,
}

impl<M: MakeConnection> Connections<M> {
impl<D: AsyncDB, M: MakeConnection<Conn = D>> Connections<D, M> {
pub fn new(make_conn: M) -> Self {
Connections {
make_conn,
Expand All @@ -48,10 +48,7 @@ impl<M: MakeConnection> Connections<M> {
}

/// Get a connection by name. Make a new connection if it doesn't exist.
pub async fn get(
&mut self,
name: ConnectionName,
) -> Result<&mut M::Conn, <M::Conn as AsyncDB>::Error> {
pub async fn get(&mut self, name: ConnectionName) -> Result<&mut D, D::Error> {
use std::collections::hash_map::Entry;

let conn = match self.conns.entry(name) {
Expand All @@ -68,10 +65,7 @@ impl<M: MakeConnection> Connections<M> {
/// Run a SQL statement on the default connection.
///
/// This is a shortcut for calling `get(Default)` then `run`.
pub async fn run_default(
&mut self,
sql: &str,
) -> Result<DBOutput<<M::Conn as AsyncDB>::ColumnType>, <M::Conn as AsyncDB>::Error> {
pub async fn run_default(&mut self, sql: &str) -> Result<DBOutput<D::ColumnType>, D::Error> {
self.get(ConnectionName::Default).await?.run(sql).await
}
}
39 changes: 15 additions & 24 deletions sqllogictest/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,11 +451,11 @@ pub fn strict_column_validator<T: ColumnType>(actual: &Vec<T>, expected: &Vec<T>
}

/// Sqllogictest runner.
pub struct Runner<M: MakeConnection> {
conn: Connections<M>,
pub struct Runner<D: AsyncDB, M: MakeConnection> {
conn: Connections<D, M>,
// validator is used for validate if the result of query equals to expected.
validator: Validator,
column_type_validator: ColumnTypeValidator<<M::Conn as AsyncDB>::ColumnType>,
column_type_validator: ColumnTypeValidator<D::ColumnType>,
testdir: Option<TempDir>,
sort_mode: Option<SortMode>,
/// 0 means never hashing
Expand All @@ -464,7 +464,7 @@ pub struct Runner<M: MakeConnection> {
labels: HashSet<String>,
}

impl<M: MakeConnection> Runner<M> {
impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
/// Create a new test runner on the database, with the given connection maker.
///
/// See [`MakeConnection`] for more details.
Expand Down Expand Up @@ -497,10 +497,7 @@ impl<M: MakeConnection> Runner<M> {
self.validator = validator;
}

pub fn with_column_validator(
&mut self,
validator: ColumnTypeValidator<<M::Conn as AsyncDB>::ColumnType>,
) {
pub fn with_column_validator(&mut self, validator: ColumnTypeValidator<D::ColumnType>) {
self.column_type_validator = validator;
}

Expand All @@ -510,8 +507,8 @@ impl<M: MakeConnection> Runner<M> {

pub async fn apply_record(
&mut self,
record: Record<<M::Conn as AsyncDB>::ColumnType>,
) -> RecordOutput<<M::Conn as AsyncDB>::ColumnType> {
record: Record<D::ColumnType>,
) -> RecordOutput<D::ColumnType> {
/// Returns whether we should skip this record, according to given `conditions`.
fn should_skip(
labels: &HashSet<String>,
Expand Down Expand Up @@ -651,7 +648,7 @@ impl<M: MakeConnection> Runner<M> {
}
}
Record::Sleep { duration, .. } => {
<M::Conn as AsyncDB>::sleep(duration).await;
D::sleep(duration).await;
RecordOutput::Nothing
}
Record::Control(control) => match control {
Expand All @@ -676,10 +673,7 @@ impl<M: MakeConnection> Runner<M> {
}

/// Run a single record.
pub async fn run_async(
&mut self,
record: Record<<M::Conn as AsyncDB>::ColumnType>,
) -> Result<(), TestError> {
pub async fn run_async(&mut self, record: Record<D::ColumnType>) -> Result<(), TestError> {
tracing::debug!(?record, "testing");

match (record.clone(), self.apply_record(record).await) {
Expand Down Expand Up @@ -828,10 +822,7 @@ impl<M: MakeConnection> Runner<M> {
}

/// Run a single record.
pub fn run(
&mut self,
record: Record<<M::Conn as AsyncDB>::ColumnType>,
) -> Result<(), TestError> {
pub fn run(&mut self, record: Record<D::ColumnType>) -> Result<(), TestError> {
futures::executor::block_on(self.run_async(record))
}

Expand All @@ -840,7 +831,7 @@ impl<M: MakeConnection> Runner<M> {
/// The runner will stop early once a halt record is seen.
pub async fn run_multi_async(
&mut self,
records: impl IntoIterator<Item = Record<<M::Conn as AsyncDB>::ColumnType>>,
records: impl IntoIterator<Item = Record<D::ColumnType>>,
) -> Result<(), TestError> {
for record in records.into_iter() {
if let Record::Halt { .. } = record {
Expand All @@ -856,7 +847,7 @@ impl<M: MakeConnection> Runner<M> {
/// The runner will stop early once a halt record is seen.
pub fn run_multi(
&mut self,
records: impl IntoIterator<Item = Record<<M::Conn as AsyncDB>::ColumnType>>,
records: impl IntoIterator<Item = Record<D::ColumnType>>,
) -> Result<(), TestError> {
block_on(self.run_multi_async(records))
}
Expand Down Expand Up @@ -912,7 +903,7 @@ impl<M: MakeConnection> Runner<M> {
jobs: usize,
) -> Result<(), ParallelTestError>
where
Fut: Future<Output = M::Conn>,
Fut: Future<Output = D>,
{
let files = glob::glob(glob).expect("failed to read glob pattern");
let mut tasks = vec![];
Expand Down Expand Up @@ -962,7 +953,7 @@ impl<M: MakeConnection> Runner<M> {
jobs: usize,
) -> Result<(), ParallelTestError>
where
Fut: Future<Output = M::Conn>,
Fut: Future<Output = D>,
{
block_on(self.run_parallel_async(glob, hosts, conn_builder, jobs))
}
Expand Down Expand Up @@ -990,7 +981,7 @@ impl<M: MakeConnection> Runner<M> {
filename: impl AsRef<Path>,
col_separator: &str,
validator: Validator,
column_type_validator: ColumnTypeValidator<<M::Conn as AsyncDB>::ColumnType>,
column_type_validator: ColumnTypeValidator<D::ColumnType>,
) -> Result<(), Box<dyn std::error::Error>> {
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::PathBuf;
Expand Down

0 comments on commit ffd63b6

Please sign in to comment.