From 6a178cec22bc02454865ada2c83909b5b55be300 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Mon, 25 Mar 2024 13:28:32 -0500 Subject: [PATCH] refactor: move slt clients into own module --- crates/glaredb/src/args/slt.rs | 14 +- crates/slt/src/clients/flightsql.rs | 99 +++++++ crates/slt/src/clients/mod.rs | 85 ++++++ crates/slt/src/clients/postgres.rs | 95 +++++++ crates/slt/src/clients/rpc.rs | 134 ++++++++++ crates/slt/src/hooks.rs | 3 +- crates/slt/src/lib.rs | 1 + crates/slt/src/test.rs | 386 +--------------------------- crates/slt/src/tests.rs | 3 +- 9 files changed, 427 insertions(+), 393 deletions(-) create mode 100644 crates/slt/src/clients/flightsql.rs create mode 100644 crates/slt/src/clients/mod.rs create mode 100644 crates/slt/src/clients/postgres.rs create mode 100644 crates/slt/src/clients/rpc.rs diff --git a/crates/glaredb/src/args/slt.rs b/crates/glaredb/src/args/slt.rs index 915bc9d22..bcf57233d 100644 --- a/crates/glaredb/src/args/slt.rs +++ b/crates/glaredb/src/args/slt.rs @@ -6,15 +6,11 @@ use std::time::Duration; use anyhow::{anyhow, Result}; use clap::Args; use pgsrv::auth::SingleUserAuthenticator; -use slt::test::{ - ClientProtocol, - FlightSqlTestClient, - PgTestClient, - RpcTestClient, - Test, - TestClient, - TestHooks, -}; +use slt::clients::flightsql::FlightSqlTestClient; +use slt::clients::postgres::PgTestClient; +use slt::clients::rpc::RpcTestClient; +use slt::clients::{ClientProtocol, TestClient}; +use slt::test::{Test, TestHooks}; use tokio::net::TcpListener; use tokio::runtime::Builder; use tokio::sync::mpsc; diff --git a/crates/slt/src/clients/flightsql.rs b/crates/slt/src/clients/flightsql.rs new file mode 100644 index 000000000..3c613213f --- /dev/null +++ b/crates/slt/src/clients/flightsql.rs @@ -0,0 +1,99 @@ +use anyhow::Result; +use arrow_flight::sql::client::FlightSqlServiceClient; +use futures::StreamExt; +use pgrepr::format::Format; +use pgrepr::scalar::Scalar; +use pgrepr::types::arrow_to_pg_type; +use rpcsrv::flight::handler::FLIGHTSQL_DATABASE_HEADER; +use sqlexec::errors::ExecError; +use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType}; +use tokio_postgres::types::private::BytesMut; +use tokio_postgres::Config; +use tonic::async_trait; +use tonic::transport::{Channel, Endpoint}; +use uuid::Uuid; + +#[derive(Clone)] +pub struct FlightSqlTestClient { + pub client: FlightSqlServiceClient, +} + +impl FlightSqlTestClient { + pub async fn new(config: &Config) -> Result { + let port = config.get_ports().first().unwrap(); + let addr = format!("http://0.0.0.0:{port}"); + let conn = Endpoint::new(addr)?.connect().await?; + let dbid: Uuid = config.get_dbname().unwrap().parse().unwrap(); + + let mut client = FlightSqlServiceClient::new(conn); + client.set_header(FLIGHTSQL_DATABASE_HEADER, dbid.to_string()); + Ok(FlightSqlTestClient { client }) + } +} + +#[async_trait] +impl AsyncDB for FlightSqlTestClient { + type Error = sqlexec::errors::ExecError; + type ColumnType = DefaultColumnType; + async fn run(&mut self, sql: &str) -> Result, Self::Error> { + let mut output = Vec::new(); + let mut num_columns = 0; + + let mut client = self.client.clone(); + let ticket = client.execute(sql.to_string(), None).await?; + let ticket = ticket + .endpoint + .first() + .ok_or_else(|| ExecError::String("The server should support this".to_string()))? + .clone(); + let ticket = ticket.ticket.unwrap(); + let mut stream = client.do_get(ticket).await?; + + // all the remaining stream messages should be dictionary and record batches + while let Some(batch) = stream.next().await { + let batch = batch.map_err(|e| { + Self::Error::String(format!("error getting batch from flight: {e}")) + })?; + + if num_columns == 0 { + num_columns = batch.num_columns(); + } + + for row_idx in 0..batch.num_rows() { + let mut row_output = Vec::with_capacity(num_columns); + + for col in batch.columns() { + let pg_type = arrow_to_pg_type(col.data_type(), None); + let scalar = Scalar::try_from_array(col, row_idx, &pg_type)?; + + if scalar.is_null() { + row_output.push("NULL".to_string()); + } else { + let mut buf = BytesMut::new(); + scalar.encode_with_format(Format::Text, &mut buf)?; + + if buf.is_empty() { + row_output.push("(empty)".to_string()) + } else { + let scalar = String::from_utf8(buf.to_vec()).map_err(|e| { + ExecError::Internal(format!( + "invalid text formatted result from pg encoder: {e}" + )) + })?; + row_output.push(scalar.trim().to_owned()); + } + } + } + output.push(row_output); + } + } + if output.is_empty() && num_columns == 0 { + Ok(DBOutput::StatementComplete(0)) + } else { + Ok(DBOutput::Rows { + types: vec![DefaultColumnType::Text; num_columns], + rows: output, + }) + } + } +} diff --git a/crates/slt/src/clients/mod.rs b/crates/slt/src/clients/mod.rs new file mode 100644 index 000000000..d2309d95e --- /dev/null +++ b/crates/slt/src/clients/mod.rs @@ -0,0 +1,85 @@ +use std::time::Duration; + +use anyhow::Result; +use clap::builder::PossibleValue; +use clap::ValueEnum; +use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType}; +use tonic::async_trait; + +use self::flightsql::FlightSqlTestClient; +use self::postgres::PgTestClient; +use self::rpc::RpcTestClient; + +pub mod flightsql; +pub mod postgres; +pub mod rpc; + +#[derive(Clone)] +pub enum TestClient { + Pg(PgTestClient), + Rpc(RpcTestClient), + FlightSql(FlightSqlTestClient), +} + +impl TestClient { + pub async fn close(self) -> Result<()> { + match self { + Self::Pg(pg_client) => pg_client.close().await, + Self::Rpc(_) => Ok(()), + Self::FlightSql(_) => Ok(()), + } + } +} + +#[async_trait] +impl AsyncDB for TestClient { + type Error = sqlexec::errors::ExecError; + type ColumnType = DefaultColumnType; + + async fn run(&mut self, sql: &str) -> Result, Self::Error> { + match self { + Self::Pg(pg_client) => pg_client.run(sql).await, + Self::Rpc(rpc_client) => rpc_client.run(sql).await, + Self::FlightSql(flight_client) => flight_client.run(sql).await, + } + } + + fn engine_name(&self) -> &str { + match self { + Self::Pg { .. } => "glaredb_pg", + Self::Rpc { .. } => "glaredb_rpc", + Self::FlightSql { .. } => "glaredb_flight", + } + } + + async fn sleep(dur: Duration) { + tokio::time::sleep(dur).await; + } +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] +pub enum ClientProtocol { + // Connect over a local postgres instance + #[default] + Postgres, + // Connect over a local RPC instance + Rpc, + // Connect over a local FlightSql instance + FlightSql, +} + +impl ValueEnum for ClientProtocol { + fn value_variants<'a>() -> &'a [Self] { + &[Self::Postgres, Self::Rpc, Self::FlightSql] + } + + fn to_possible_value(&self) -> Option { + Some(match self { + ClientProtocol::Postgres => PossibleValue::new("postgres"), + ClientProtocol::Rpc => PossibleValue::new("rpc"), + ClientProtocol::FlightSql => PossibleValue::new("flightsql") + .alias("flight-sql") + .alias("flight"), + }) + } +} diff --git a/crates/slt/src/clients/postgres.rs b/crates/slt/src/clients/postgres.rs new file mode 100644 index 000000000..0bb108899 --- /dev/null +++ b/crates/slt/src/clients/postgres.rs @@ -0,0 +1,95 @@ +use std::ops::Deref; +use std::sync::Arc; + +use anyhow::{anyhow, Result}; +use sqlexec::errors::ExecError; +use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType}; +use tokio::sync::{oneshot, Mutex}; +use tokio_postgres::{Client, Config, NoTls, SimpleQueryMessage}; +use tonic::async_trait; + +#[derive(Clone)] +pub struct PgTestClient { + client: Arc, + conn_err_rx: Arc>>>, +} + +impl Deref for PgTestClient { + type Target = Client; + fn deref(&self) -> &Self::Target { + &self.client + } +} + +impl PgTestClient { + pub async fn new(client_config: &Config) -> Result { + let (client, conn) = client_config.connect(NoTls).await?; + let (conn_err_tx, conn_err_rx) = oneshot::channel(); + tokio::spawn(async move { conn_err_tx.send(conn.await) }); + Ok(Self { + client: Arc::new(client), + conn_err_rx: Arc::new(Mutex::new(conn_err_rx)), + }) + } + + pub(super) async fn close(&self) -> Result<()> { + let PgTestClient { conn_err_rx, .. } = self; + let mut conn_err_rx = conn_err_rx.lock().await; + + if let Ok(result) = conn_err_rx.try_recv() { + // Handle connection error + match result { + Ok(()) => Err(anyhow!("Client connection unexpectedly closed")), + Err(err) => Err(anyhow!("Client connection errored: {err}")), + } + } else { + Ok(()) + } + } +} + +#[async_trait] +impl AsyncDB for PgTestClient { + type Error = sqlexec::errors::ExecError; + type ColumnType = DefaultColumnType; + async fn run(&mut self, sql: &str) -> Result, Self::Error> { + let mut output = Vec::new(); + let mut num_columns = 0; + + let rows = self + .simple_query(sql) + .await + .map_err(|e| ExecError::Internal(format!("cannot execute simple query: {e}")))?; + for row in rows { + match row { + SimpleQueryMessage::Row(row) => { + num_columns = row.len(); + let mut row_output = Vec::with_capacity(row.len()); + for i in 0..row.len() { + match row.get(i) { + Some(v) => { + if v.is_empty() { + row_output.push("(empty)".to_string()); + } else { + row_output.push(v.to_string().trim().to_owned()); + } + } + None => row_output.push("NULL".to_string()), + } + } + output.push(row_output); + } + SimpleQueryMessage::CommandComplete(_) => {} + _ => unreachable!(), + } + } + if output.is_empty() && num_columns == 0 { + Ok(DBOutput::StatementComplete(0)) + } else { + Ok(DBOutput::Rows { + types: vec![DefaultColumnType::Text; num_columns], + rows: output, + }) + } + } +} diff --git a/crates/slt/src/clients/rpc.rs b/crates/slt/src/clients/rpc.rs new file mode 100644 index 000000000..3d948d946 --- /dev/null +++ b/crates/slt/src/clients/rpc.rs @@ -0,0 +1,134 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use anyhow::Result; +use datafusion_ext::vars::SessionVars; +use futures::StreamExt; +use metastore::util::MetastoreClientMode; +use pgrepr::format::Format; +use pgrepr::scalar::Scalar; +use pgrepr::types::arrow_to_pg_type; +use sqlexec::engine::{Engine, EngineStorageConfig, SessionStorageConfig, TrackedSession}; +use sqlexec::errors::ExecError; +use sqlexec::remote::client::RemoteClient; +use sqlexec::session::ExecutionResult; +use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType}; +use telemetry::Tracker; +use tokio::sync::Mutex; +use tokio_postgres::types::private::BytesMut; +use tokio_postgres::Config; +use tonic::async_trait; +use uuid::Uuid; + +#[derive(Clone)] +pub struct RpcTestClient { + pub session: Arc>, + _engine: Arc, +} + +impl RpcTestClient { + pub async fn new(data_dir: PathBuf, config: &Config) -> Result { + let metastore = MetastoreClientMode::LocalInMemory.into_client().await?; + let storage = EngineStorageConfig::try_from_path_buf(&data_dir)?; + let engine = Engine::new(metastore, storage, Arc::new(Tracker::Nop), None).await?; + let port = config.get_ports().first().unwrap(); + + let addr = format!("http://0.0.0.0:{port}"); + let remote_client = RemoteClient::connect(addr.parse().unwrap()).await?; + let mut session = engine + .new_local_session_context(SessionVars::default(), SessionStorageConfig::default()) + .await?; + let test_id = Uuid::new_v4(); + session + .attach_remote_session(remote_client, Some(test_id)) + .await?; + Ok(RpcTestClient { + session: Arc::new(Mutex::new(session)), + _engine: Arc::new(engine), + }) + } +} + +#[async_trait] +impl AsyncDB for RpcTestClient { + type Error = sqlexec::errors::ExecError; + type ColumnType = DefaultColumnType; + async fn run(&mut self, sql: &str) -> Result, Self::Error> { + let mut output = Vec::new(); + let mut num_columns = 0; + let RpcTestClient { session, .. } = self; + + let mut session = session.lock().await; + const UNNAMED: String = String::new(); + let statements = session.parse_query(sql)?; + + for stmt in statements { + session.prepare_statement(UNNAMED, stmt, Vec::new()).await?; + let prepared = session.get_prepared_statement(&UNNAMED)?; + let num_fields = prepared.output_fields().map(|f| f.len()).unwrap_or(0); + session.bind_statement( + UNNAMED, + &UNNAMED, + Vec::new(), + vec![Format::Text; num_fields], + )?; + let stream = session.execute_portal(&UNNAMED, 0).await?; + + match stream { + ExecutionResult::Query { stream, .. } => { + let batches = stream + .collect::>() + .await + .into_iter() + .collect::, _>>()?; + + for batch in batches { + if num_columns == 0 { + num_columns = batch.num_columns(); + } + + for row_idx in 0..batch.num_rows() { + let mut row_output = Vec::with_capacity(num_columns); + + for col in batch.columns() { + let pg_type = arrow_to_pg_type(col.data_type(), None); + let scalar = Scalar::try_from_array(col, row_idx, &pg_type)?; + + if scalar.is_null() { + row_output.push("NULL".to_string()); + } else { + let mut buf = BytesMut::new(); + scalar.encode_with_format(Format::Text, &mut buf)?; + + if buf.is_empty() { + row_output.push("(empty)".to_string()) + } else { + let scalar = + String::from_utf8(buf.to_vec()).map_err(|e| { + ExecError::Internal(format!( + "invalid text formatted result from pg encoder: {e}" + )) + })?; + row_output.push(scalar.trim().to_owned()); + } + } + } + output.push(row_output); + } + } + } + ExecutionResult::Error(e) => return Err(e.into()), + _ => (), + } + } + + if output.is_empty() && num_columns == 0 { + Ok(DBOutput::StatementComplete(0)) + } else { + Ok(DBOutput::Rows { + types: vec![DefaultColumnType::Text; num_columns], + rows: output, + }) + } + } +} diff --git a/crates/slt/src/hooks.rs b/crates/slt/src/hooks.rs index 388eae4a2..c2bbdde20 100644 --- a/crates/slt/src/hooks.rs +++ b/crates/slt/src/hooks.rs @@ -11,7 +11,8 @@ use tokio::time::{sleep as tokio_sleep, Instant}; use tokio_postgres::{Client, Config}; use tracing::{error, info, warn}; -use super::test::{Hook, TestClient}; +use super::test::Hook; +use crate::clients::TestClient; /// This [`Hook`] is used to set some local variables that might change for /// each test. diff --git a/crates/slt/src/lib.rs b/crates/slt/src/lib.rs index 6ea7e5963..38c5362be 100644 --- a/crates/slt/src/lib.rs +++ b/crates/slt/src/lib.rs @@ -1,3 +1,4 @@ +pub mod clients; pub mod discovery; pub mod hooks; pub mod test; diff --git a/crates/slt/src/test.rs b/crates/slt/src/test.rs index e5358a9b7..d0cffc539 100644 --- a/crates/slt/src/test.rs +++ b/crates/slt/src/test.rs @@ -1,44 +1,16 @@ use std::collections::HashMap; use std::fmt::Debug; -use std::ops::Deref; use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::time::Duration; use anyhow::{anyhow, Result}; -use arrow_flight::sql::client::FlightSqlServiceClient; use async_trait::async_trait; -use clap::builder::PossibleValue; -use clap::ValueEnum; -use datafusion_ext::vars::SessionVars; -use futures::StreamExt; use glob::Pattern; -use metastore::util::MetastoreClientMode; -use pgrepr::format::Format; -use pgrepr::scalar::Scalar; -use pgrepr::types::arrow_to_pg_type; use regex::{Captures, Regex}; -use rpcsrv::flight::handler::FLIGHTSQL_DATABASE_HEADER; -use sqlexec::engine::{Engine, EngineStorageConfig, SessionStorageConfig, TrackedSession}; -use sqlexec::errors::ExecError; -use sqlexec::remote::client::RemoteClient; -use sqlexec::session::ExecutionResult; -use sqllogictest::{ - parse_with_name, - AsyncDB, - ColumnType, - DBOutput, - DefaultColumnType, - Injected, - Record, - Runner, -}; -use telemetry::Tracker; -use tokio::sync::{oneshot, Mutex}; -use tokio_postgres::types::private::BytesMut; -use tokio_postgres::{Client, Config, NoTls, SimpleQueryMessage}; -use tonic::transport::{Channel, Endpoint}; -use uuid::Uuid; +use sqllogictest::{parse_with_name, ColumnType, Injected, Record, Runner}; +use tokio_postgres::Config; + +use crate::clients::TestClient; #[async_trait] pub trait Hook: Send + Sync { @@ -202,353 +174,3 @@ fn parse_file( } Ok(records) } - -#[derive(Clone)] -pub struct PgTestClient { - client: Arc, - conn_err_rx: Arc>>>, -} - -impl Deref for PgTestClient { - type Target = Client; - fn deref(&self) -> &Self::Target { - &self.client - } -} - -impl PgTestClient { - pub async fn new(client_config: &Config) -> Result { - let (client, conn) = client_config.connect(NoTls).await?; - let (conn_err_tx, conn_err_rx) = oneshot::channel(); - tokio::spawn(async move { conn_err_tx.send(conn.await) }); - Ok(Self { - client: Arc::new(client), - conn_err_rx: Arc::new(Mutex::new(conn_err_rx)), - }) - } - - async fn close(&self) -> Result<()> { - let PgTestClient { conn_err_rx, .. } = self; - let mut conn_err_rx = conn_err_rx.lock().await; - - if let Ok(result) = conn_err_rx.try_recv() { - // Handle connection error - match result { - Ok(()) => Err(anyhow!("Client connection unexpectedly closed")), - Err(err) => Err(anyhow!("Client connection errored: {err}")), - } - } else { - Ok(()) - } - } -} - -#[derive(Clone)] -pub struct RpcTestClient { - session: Arc>, - _engine: Arc, -} - -impl RpcTestClient { - pub async fn new(data_dir: PathBuf, config: &Config) -> Result { - let metastore = MetastoreClientMode::LocalInMemory.into_client().await?; - let storage = EngineStorageConfig::try_from_path_buf(&data_dir)?; - let engine = Engine::new(metastore, storage, Arc::new(Tracker::Nop), None).await?; - let port = config.get_ports().first().unwrap(); - - let addr = format!("http://0.0.0.0:{port}"); - let remote_client = RemoteClient::connect(addr.parse().unwrap()).await?; - let mut session = engine - .new_local_session_context(SessionVars::default(), SessionStorageConfig::default()) - .await?; - let test_id = Uuid::new_v4(); - session - .attach_remote_session(remote_client, Some(test_id)) - .await?; - Ok(RpcTestClient { - session: Arc::new(Mutex::new(session)), - _engine: Arc::new(engine), - }) - } -} -#[derive(Clone)] -pub struct FlightSqlTestClient { - client: FlightSqlServiceClient, -} -impl FlightSqlTestClient { - pub async fn new(config: &Config) -> Result { - let port = config.get_ports().first().unwrap(); - let addr = format!("http://0.0.0.0:{port}"); - let conn = Endpoint::new(addr)?.connect().await?; - let dbid: Uuid = config.get_dbname().unwrap().parse().unwrap(); - - let mut client = FlightSqlServiceClient::new(conn); - client.set_header(FLIGHTSQL_DATABASE_HEADER, dbid.to_string()); - Ok(FlightSqlTestClient { client }) - } -} - -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] -pub enum ClientProtocol { - // Connect over a local postgres instance - #[default] - Postgres, - // Connect over a local RPC instance - Rpc, - // Connect over a local FlightSql instance - FlightSql, -} - -impl ValueEnum for ClientProtocol { - fn value_variants<'a>() -> &'a [Self] { - &[Self::Postgres, Self::Rpc, Self::FlightSql] - } - - fn to_possible_value(&self) -> Option { - Some(match self { - ClientProtocol::Postgres => PossibleValue::new("postgres"), - ClientProtocol::Rpc => PossibleValue::new("rpc"), - ClientProtocol::FlightSql => PossibleValue::new("flightsql").alias("flight-sql"), - }) - } -} - -#[derive(Clone)] -pub enum TestClient { - Pg(PgTestClient), - Rpc(RpcTestClient), - FlightSql(FlightSqlTestClient), -} - -impl TestClient { - pub async fn close(self) -> Result<()> { - match self { - Self::Pg(pg_client) => pg_client.close().await, - Self::Rpc(_) => Ok(()), - Self::FlightSql(_) => Ok(()), - } - } -} - -#[async_trait] -impl AsyncDB for PgTestClient { - type Error = sqlexec::errors::ExecError; - type ColumnType = DefaultColumnType; - async fn run(&mut self, sql: &str) -> Result, Self::Error> { - let mut output = Vec::new(); - let mut num_columns = 0; - - let rows = self - .simple_query(sql) - .await - .map_err(|e| ExecError::Internal(format!("cannot execute simple query: {e}")))?; - for row in rows { - match row { - SimpleQueryMessage::Row(row) => { - num_columns = row.len(); - let mut row_output = Vec::with_capacity(row.len()); - for i in 0..row.len() { - match row.get(i) { - Some(v) => { - if v.is_empty() { - row_output.push("(empty)".to_string()); - } else { - row_output.push(v.to_string().trim().to_owned()); - } - } - None => row_output.push("NULL".to_string()), - } - } - output.push(row_output); - } - SimpleQueryMessage::CommandComplete(_) => {} - _ => unreachable!(), - } - } - if output.is_empty() && num_columns == 0 { - Ok(DBOutput::StatementComplete(0)) - } else { - Ok(DBOutput::Rows { - types: vec![DefaultColumnType::Text; num_columns], - rows: output, - }) - } - } -} - -#[async_trait] -impl AsyncDB for RpcTestClient { - type Error = sqlexec::errors::ExecError; - type ColumnType = DefaultColumnType; - async fn run(&mut self, sql: &str) -> Result, Self::Error> { - let mut output = Vec::new(); - let mut num_columns = 0; - let RpcTestClient { session, .. } = self; - - let mut session = session.lock().await; - const UNNAMED: String = String::new(); - let statements = session.parse_query(sql)?; - - for stmt in statements { - session.prepare_statement(UNNAMED, stmt, Vec::new()).await?; - let prepared = session.get_prepared_statement(&UNNAMED)?; - let num_fields = prepared.output_fields().map(|f| f.len()).unwrap_or(0); - session.bind_statement( - UNNAMED, - &UNNAMED, - Vec::new(), - vec![Format::Text; num_fields], - )?; - let stream = session.execute_portal(&UNNAMED, 0).await?; - - match stream { - ExecutionResult::Query { stream, .. } => { - let batches = stream - .collect::>() - .await - .into_iter() - .collect::, _>>()?; - - for batch in batches { - if num_columns == 0 { - num_columns = batch.num_columns(); - } - - for row_idx in 0..batch.num_rows() { - let mut row_output = Vec::with_capacity(num_columns); - - for col in batch.columns() { - let pg_type = arrow_to_pg_type(col.data_type(), None); - let scalar = Scalar::try_from_array(col, row_idx, &pg_type)?; - - if scalar.is_null() { - row_output.push("NULL".to_string()); - } else { - let mut buf = BytesMut::new(); - scalar.encode_with_format(Format::Text, &mut buf)?; - - if buf.is_empty() { - row_output.push("(empty)".to_string()) - } else { - let scalar = - String::from_utf8(buf.to_vec()).map_err(|e| { - ExecError::Internal(format!( - "invalid text formatted result from pg encoder: {e}" - )) - })?; - row_output.push(scalar.trim().to_owned()); - } - } - } - output.push(row_output); - } - } - } - ExecutionResult::Error(e) => return Err(e.into()), - _ => (), - } - } - - if output.is_empty() && num_columns == 0 { - Ok(DBOutput::StatementComplete(0)) - } else { - Ok(DBOutput::Rows { - types: vec![DefaultColumnType::Text; num_columns], - rows: output, - }) - } - } -} - -#[async_trait] -impl AsyncDB for FlightSqlTestClient { - type Error = sqlexec::errors::ExecError; - type ColumnType = DefaultColumnType; - async fn run(&mut self, sql: &str) -> Result, Self::Error> { - let mut output = Vec::new(); - let mut num_columns = 0; - - let mut client = self.client.clone(); - let ticket = client.execute(sql.to_string(), None).await?; - let ticket = ticket - .endpoint - .first() - .ok_or_else(|| ExecError::String("The server should support this".to_string()))? - .clone(); - let ticket = ticket.ticket.unwrap(); - let mut stream = client.do_get(ticket).await?; - - // all the remaining stream messages should be dictionary and record batches - while let Some(batch) = stream.next().await { - let batch = batch.map_err(|e| { - Self::Error::String(format!("error getting batch from flight: {e}")) - })?; - - if num_columns == 0 { - num_columns = batch.num_columns(); - } - - for row_idx in 0..batch.num_rows() { - let mut row_output = Vec::with_capacity(num_columns); - - for col in batch.columns() { - let pg_type = arrow_to_pg_type(col.data_type(), None); - let scalar = Scalar::try_from_array(col, row_idx, &pg_type)?; - - if scalar.is_null() { - row_output.push("NULL".to_string()); - } else { - let mut buf = BytesMut::new(); - scalar.encode_with_format(Format::Text, &mut buf)?; - - if buf.is_empty() { - row_output.push("(empty)".to_string()) - } else { - let scalar = String::from_utf8(buf.to_vec()).map_err(|e| { - ExecError::Internal(format!( - "invalid text formatted result from pg encoder: {e}" - )) - })?; - row_output.push(scalar.trim().to_owned()); - } - } - } - output.push(row_output); - } - } - if output.is_empty() && num_columns == 0 { - Ok(DBOutput::StatementComplete(0)) - } else { - Ok(DBOutput::Rows { - types: vec![DefaultColumnType::Text; num_columns], - rows: output, - }) - } - } -} - -#[async_trait] -impl AsyncDB for TestClient { - type Error = sqlexec::errors::ExecError; - type ColumnType = DefaultColumnType; - - async fn run(&mut self, sql: &str) -> Result, Self::Error> { - match self { - Self::Pg(pg_client) => pg_client.run(sql).await, - Self::Rpc(rpc_client) => rpc_client.run(sql).await, - Self::FlightSql(flight_client) => flight_client.run(sql).await, - } - } - - fn engine_name(&self) -> &str { - match self { - Self::Pg { .. } => "glaredb_pg", - Self::Rpc { .. } => "glaredb_rpc", - Self::FlightSql { .. } => "glaredb_flight", - } - } - - async fn sleep(dur: Duration) { - tokio::time::sleep(dur).await; - } -} diff --git a/crates/slt/src/tests.rs b/crates/slt/src/tests.rs index 7d7d6c562..91c762196 100644 --- a/crates/slt/src/tests.rs +++ b/crates/slt/src/tests.rs @@ -5,7 +5,8 @@ use async_trait::async_trait; use tokio_postgres::Config; use tracing::warn; -use crate::test::{FnTest, TestClient}; +use crate::clients::TestClient; +use crate::test::FnTest; macro_rules! test_assert { ($e:expr, $err:expr) => {