-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: move slt clients into own module (#2824)
small PR to just move all of the clients inside `crates/slt/src/test.rs` into their own module `clients`. No functional changes, just some refactoring to make the code a bit more organized.
- Loading branch information
1 parent
ebac267
commit 8a943f5
Showing
9 changed files
with
427 additions
and
393 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
use anyhow::Result; | ||
use arrow_flight::sql::client::FlightSqlServiceClient; | ||
use futures::StreamExt; | ||
use pgrepr::format::Format; | ||
use pgrepr::scalar::Scalar; | ||
use pgrepr::types::arrow_to_pg_type; | ||
use rpcsrv::flight::handler::FLIGHTSQL_DATABASE_HEADER; | ||
use sqlexec::errors::ExecError; | ||
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType}; | ||
use tokio_postgres::types::private::BytesMut; | ||
use tokio_postgres::Config; | ||
use tonic::async_trait; | ||
use tonic::transport::{Channel, Endpoint}; | ||
use uuid::Uuid; | ||
|
||
#[derive(Clone)] | ||
pub struct FlightSqlTestClient { | ||
pub client: FlightSqlServiceClient<Channel>, | ||
} | ||
|
||
impl FlightSqlTestClient { | ||
pub async fn new(config: &Config) -> Result<Self> { | ||
let port = config.get_ports().first().unwrap(); | ||
let addr = format!("http://0.0.0.0:{port}"); | ||
let conn = Endpoint::new(addr)?.connect().await?; | ||
let dbid: Uuid = config.get_dbname().unwrap().parse().unwrap(); | ||
|
||
let mut client = FlightSqlServiceClient::new(conn); | ||
client.set_header(FLIGHTSQL_DATABASE_HEADER, dbid.to_string()); | ||
Ok(FlightSqlTestClient { client }) | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl AsyncDB for FlightSqlTestClient { | ||
type Error = sqlexec::errors::ExecError; | ||
type ColumnType = DefaultColumnType; | ||
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> { | ||
let mut output = Vec::new(); | ||
let mut num_columns = 0; | ||
|
||
let mut client = self.client.clone(); | ||
let ticket = client.execute(sql.to_string(), None).await?; | ||
let ticket = ticket | ||
.endpoint | ||
.first() | ||
.ok_or_else(|| ExecError::String("The server should support this".to_string()))? | ||
.clone(); | ||
let ticket = ticket.ticket.unwrap(); | ||
let mut stream = client.do_get(ticket).await?; | ||
|
||
// all the remaining stream messages should be dictionary and record batches | ||
while let Some(batch) = stream.next().await { | ||
let batch = batch.map_err(|e| { | ||
Self::Error::String(format!("error getting batch from flight: {e}")) | ||
})?; | ||
|
||
if num_columns == 0 { | ||
num_columns = batch.num_columns(); | ||
} | ||
|
||
for row_idx in 0..batch.num_rows() { | ||
let mut row_output = Vec::with_capacity(num_columns); | ||
|
||
for col in batch.columns() { | ||
let pg_type = arrow_to_pg_type(col.data_type(), None); | ||
let scalar = Scalar::try_from_array(col, row_idx, &pg_type)?; | ||
|
||
if scalar.is_null() { | ||
row_output.push("NULL".to_string()); | ||
} else { | ||
let mut buf = BytesMut::new(); | ||
scalar.encode_with_format(Format::Text, &mut buf)?; | ||
|
||
if buf.is_empty() { | ||
row_output.push("(empty)".to_string()) | ||
} else { | ||
let scalar = String::from_utf8(buf.to_vec()).map_err(|e| { | ||
ExecError::Internal(format!( | ||
"invalid text formatted result from pg encoder: {e}" | ||
)) | ||
})?; | ||
row_output.push(scalar.trim().to_owned()); | ||
} | ||
} | ||
} | ||
output.push(row_output); | ||
} | ||
} | ||
if output.is_empty() && num_columns == 0 { | ||
Ok(DBOutput::StatementComplete(0)) | ||
} else { | ||
Ok(DBOutput::Rows { | ||
types: vec![DefaultColumnType::Text; num_columns], | ||
rows: output, | ||
}) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
use std::time::Duration; | ||
|
||
use anyhow::Result; | ||
use clap::builder::PossibleValue; | ||
use clap::ValueEnum; | ||
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType}; | ||
use tonic::async_trait; | ||
|
||
use self::flightsql::FlightSqlTestClient; | ||
use self::postgres::PgTestClient; | ||
use self::rpc::RpcTestClient; | ||
|
||
pub mod flightsql; | ||
pub mod postgres; | ||
pub mod rpc; | ||
|
||
#[derive(Clone)] | ||
pub enum TestClient { | ||
Pg(PgTestClient), | ||
Rpc(RpcTestClient), | ||
FlightSql(FlightSqlTestClient), | ||
} | ||
|
||
impl TestClient { | ||
pub async fn close(self) -> Result<()> { | ||
match self { | ||
Self::Pg(pg_client) => pg_client.close().await, | ||
Self::Rpc(_) => Ok(()), | ||
Self::FlightSql(_) => Ok(()), | ||
} | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl AsyncDB for TestClient { | ||
type Error = sqlexec::errors::ExecError; | ||
type ColumnType = DefaultColumnType; | ||
|
||
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> { | ||
match self { | ||
Self::Pg(pg_client) => pg_client.run(sql).await, | ||
Self::Rpc(rpc_client) => rpc_client.run(sql).await, | ||
Self::FlightSql(flight_client) => flight_client.run(sql).await, | ||
} | ||
} | ||
|
||
fn engine_name(&self) -> &str { | ||
match self { | ||
Self::Pg { .. } => "glaredb_pg", | ||
Self::Rpc { .. } => "glaredb_rpc", | ||
Self::FlightSql { .. } => "glaredb_flight", | ||
} | ||
} | ||
|
||
async fn sleep(dur: Duration) { | ||
tokio::time::sleep(dur).await; | ||
} | ||
} | ||
|
||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] | ||
pub enum ClientProtocol { | ||
// Connect over a local postgres instance | ||
#[default] | ||
Postgres, | ||
// Connect over a local RPC instance | ||
Rpc, | ||
// Connect over a local FlightSql instance | ||
FlightSql, | ||
} | ||
|
||
impl ValueEnum for ClientProtocol { | ||
fn value_variants<'a>() -> &'a [Self] { | ||
&[Self::Postgres, Self::Rpc, Self::FlightSql] | ||
} | ||
|
||
fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> { | ||
Some(match self { | ||
ClientProtocol::Postgres => PossibleValue::new("postgres"), | ||
ClientProtocol::Rpc => PossibleValue::new("rpc"), | ||
ClientProtocol::FlightSql => PossibleValue::new("flightsql") | ||
.alias("flight-sql") | ||
.alias("flight"), | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
use std::ops::Deref; | ||
use std::sync::Arc; | ||
|
||
use anyhow::{anyhow, Result}; | ||
use sqlexec::errors::ExecError; | ||
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType}; | ||
use tokio::sync::{oneshot, Mutex}; | ||
use tokio_postgres::{Client, Config, NoTls, SimpleQueryMessage}; | ||
use tonic::async_trait; | ||
|
||
#[derive(Clone)] | ||
pub struct PgTestClient { | ||
client: Arc<Client>, | ||
conn_err_rx: Arc<Mutex<oneshot::Receiver<Result<(), tokio_postgres::Error>>>>, | ||
} | ||
|
||
impl Deref for PgTestClient { | ||
type Target = Client; | ||
fn deref(&self) -> &Self::Target { | ||
&self.client | ||
} | ||
} | ||
|
||
impl PgTestClient { | ||
pub async fn new(client_config: &Config) -> Result<Self> { | ||
let (client, conn) = client_config.connect(NoTls).await?; | ||
let (conn_err_tx, conn_err_rx) = oneshot::channel(); | ||
tokio::spawn(async move { conn_err_tx.send(conn.await) }); | ||
Ok(Self { | ||
client: Arc::new(client), | ||
conn_err_rx: Arc::new(Mutex::new(conn_err_rx)), | ||
}) | ||
} | ||
|
||
pub(super) async fn close(&self) -> Result<()> { | ||
let PgTestClient { conn_err_rx, .. } = self; | ||
let mut conn_err_rx = conn_err_rx.lock().await; | ||
|
||
if let Ok(result) = conn_err_rx.try_recv() { | ||
// Handle connection error | ||
match result { | ||
Ok(()) => Err(anyhow!("Client connection unexpectedly closed")), | ||
Err(err) => Err(anyhow!("Client connection errored: {err}")), | ||
} | ||
} else { | ||
Ok(()) | ||
} | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl AsyncDB for PgTestClient { | ||
type Error = sqlexec::errors::ExecError; | ||
type ColumnType = DefaultColumnType; | ||
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> { | ||
let mut output = Vec::new(); | ||
let mut num_columns = 0; | ||
|
||
let rows = self | ||
.simple_query(sql) | ||
.await | ||
.map_err(|e| ExecError::Internal(format!("cannot execute simple query: {e}")))?; | ||
for row in rows { | ||
match row { | ||
SimpleQueryMessage::Row(row) => { | ||
num_columns = row.len(); | ||
let mut row_output = Vec::with_capacity(row.len()); | ||
for i in 0..row.len() { | ||
match row.get(i) { | ||
Some(v) => { | ||
if v.is_empty() { | ||
row_output.push("(empty)".to_string()); | ||
} else { | ||
row_output.push(v.to_string().trim().to_owned()); | ||
} | ||
} | ||
None => row_output.push("NULL".to_string()), | ||
} | ||
} | ||
output.push(row_output); | ||
} | ||
SimpleQueryMessage::CommandComplete(_) => {} | ||
_ => unreachable!(), | ||
} | ||
} | ||
if output.is_empty() && num_columns == 0 { | ||
Ok(DBOutput::StatementComplete(0)) | ||
} else { | ||
Ok(DBOutput::Rows { | ||
types: vec![DefaultColumnType::Text; num_columns], | ||
rows: output, | ||
}) | ||
} | ||
} | ||
} |
Oops, something went wrong.