Skip to content

Commit

Permalink
refactor: move slt clients into own module (#2824)
Browse files Browse the repository at this point in the history
small PR to just move all of the clients inside `crates/slt/src/test.rs`
into their own module `clients`.

No functional changes, just some refactoring to make the code a bit more
organized.
  • Loading branch information
universalmind303 authored Mar 25, 2024
1 parent ebac267 commit 8a943f5
Show file tree
Hide file tree
Showing 9 changed files with 427 additions and 393 deletions.
14 changes: 5 additions & 9 deletions crates/glaredb/src/args/slt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@ use std::time::Duration;
use anyhow::{anyhow, Result};
use clap::Args;
use pgsrv::auth::SingleUserAuthenticator;
use slt::test::{
ClientProtocol,
FlightSqlTestClient,
PgTestClient,
RpcTestClient,
Test,
TestClient,
TestHooks,
};
use slt::clients::flightsql::FlightSqlTestClient;
use slt::clients::postgres::PgTestClient;
use slt::clients::rpc::RpcTestClient;
use slt::clients::{ClientProtocol, TestClient};
use slt::test::{Test, TestHooks};
use tokio::net::TcpListener;
use tokio::runtime::Builder;
use tokio::sync::mpsc;
Expand Down
99 changes: 99 additions & 0 deletions crates/slt/src/clients/flightsql.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use anyhow::Result;
use arrow_flight::sql::client::FlightSqlServiceClient;
use futures::StreamExt;
use pgrepr::format::Format;
use pgrepr::scalar::Scalar;
use pgrepr::types::arrow_to_pg_type;
use rpcsrv::flight::handler::FLIGHTSQL_DATABASE_HEADER;
use sqlexec::errors::ExecError;
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType};
use tokio_postgres::types::private::BytesMut;
use tokio_postgres::Config;
use tonic::async_trait;
use tonic::transport::{Channel, Endpoint};
use uuid::Uuid;

#[derive(Clone)]
pub struct FlightSqlTestClient {
pub client: FlightSqlServiceClient<Channel>,
}

impl FlightSqlTestClient {
pub async fn new(config: &Config) -> Result<Self> {
let port = config.get_ports().first().unwrap();
let addr = format!("http://0.0.0.0:{port}");
let conn = Endpoint::new(addr)?.connect().await?;
let dbid: Uuid = config.get_dbname().unwrap().parse().unwrap();

let mut client = FlightSqlServiceClient::new(conn);
client.set_header(FLIGHTSQL_DATABASE_HEADER, dbid.to_string());
Ok(FlightSqlTestClient { client })
}
}

#[async_trait]
impl AsyncDB for FlightSqlTestClient {
type Error = sqlexec::errors::ExecError;
type ColumnType = DefaultColumnType;
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
let mut output = Vec::new();
let mut num_columns = 0;

let mut client = self.client.clone();
let ticket = client.execute(sql.to_string(), None).await?;
let ticket = ticket
.endpoint
.first()
.ok_or_else(|| ExecError::String("The server should support this".to_string()))?
.clone();
let ticket = ticket.ticket.unwrap();
let mut stream = client.do_get(ticket).await?;

// all the remaining stream messages should be dictionary and record batches
while let Some(batch) = stream.next().await {
let batch = batch.map_err(|e| {
Self::Error::String(format!("error getting batch from flight: {e}"))
})?;

if num_columns == 0 {
num_columns = batch.num_columns();
}

for row_idx in 0..batch.num_rows() {
let mut row_output = Vec::with_capacity(num_columns);

for col in batch.columns() {
let pg_type = arrow_to_pg_type(col.data_type(), None);
let scalar = Scalar::try_from_array(col, row_idx, &pg_type)?;

if scalar.is_null() {
row_output.push("NULL".to_string());
} else {
let mut buf = BytesMut::new();
scalar.encode_with_format(Format::Text, &mut buf)?;

if buf.is_empty() {
row_output.push("(empty)".to_string())
} else {
let scalar = String::from_utf8(buf.to_vec()).map_err(|e| {
ExecError::Internal(format!(
"invalid text formatted result from pg encoder: {e}"
))
})?;
row_output.push(scalar.trim().to_owned());
}
}
}
output.push(row_output);
}
}
if output.is_empty() && num_columns == 0 {
Ok(DBOutput::StatementComplete(0))
} else {
Ok(DBOutput::Rows {
types: vec![DefaultColumnType::Text; num_columns],
rows: output,
})
}
}
}
85 changes: 85 additions & 0 deletions crates/slt/src/clients/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use std::time::Duration;

use anyhow::Result;
use clap::builder::PossibleValue;
use clap::ValueEnum;
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType};
use tonic::async_trait;

use self::flightsql::FlightSqlTestClient;
use self::postgres::PgTestClient;
use self::rpc::RpcTestClient;

pub mod flightsql;
pub mod postgres;
pub mod rpc;

#[derive(Clone)]
pub enum TestClient {
Pg(PgTestClient),
Rpc(RpcTestClient),
FlightSql(FlightSqlTestClient),
}

impl TestClient {
pub async fn close(self) -> Result<()> {
match self {
Self::Pg(pg_client) => pg_client.close().await,
Self::Rpc(_) => Ok(()),
Self::FlightSql(_) => Ok(()),
}
}
}

#[async_trait]
impl AsyncDB for TestClient {
type Error = sqlexec::errors::ExecError;
type ColumnType = DefaultColumnType;

async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
match self {
Self::Pg(pg_client) => pg_client.run(sql).await,
Self::Rpc(rpc_client) => rpc_client.run(sql).await,
Self::FlightSql(flight_client) => flight_client.run(sql).await,
}
}

fn engine_name(&self) -> &str {
match self {
Self::Pg { .. } => "glaredb_pg",
Self::Rpc { .. } => "glaredb_rpc",
Self::FlightSql { .. } => "glaredb_flight",
}
}

async fn sleep(dur: Duration) {
tokio::time::sleep(dur).await;
}
}

#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum ClientProtocol {
// Connect over a local postgres instance
#[default]
Postgres,
// Connect over a local RPC instance
Rpc,
// Connect over a local FlightSql instance
FlightSql,
}

impl ValueEnum for ClientProtocol {
fn value_variants<'a>() -> &'a [Self] {
&[Self::Postgres, Self::Rpc, Self::FlightSql]
}

fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
Some(match self {
ClientProtocol::Postgres => PossibleValue::new("postgres"),
ClientProtocol::Rpc => PossibleValue::new("rpc"),
ClientProtocol::FlightSql => PossibleValue::new("flightsql")
.alias("flight-sql")
.alias("flight"),
})
}
}
95 changes: 95 additions & 0 deletions crates/slt/src/clients/postgres.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use std::ops::Deref;
use std::sync::Arc;

use anyhow::{anyhow, Result};
use sqlexec::errors::ExecError;
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType};
use tokio::sync::{oneshot, Mutex};
use tokio_postgres::{Client, Config, NoTls, SimpleQueryMessage};
use tonic::async_trait;

#[derive(Clone)]
pub struct PgTestClient {
client: Arc<Client>,
conn_err_rx: Arc<Mutex<oneshot::Receiver<Result<(), tokio_postgres::Error>>>>,
}

impl Deref for PgTestClient {
type Target = Client;
fn deref(&self) -> &Self::Target {
&self.client
}
}

impl PgTestClient {
pub async fn new(client_config: &Config) -> Result<Self> {
let (client, conn) = client_config.connect(NoTls).await?;
let (conn_err_tx, conn_err_rx) = oneshot::channel();
tokio::spawn(async move { conn_err_tx.send(conn.await) });
Ok(Self {
client: Arc::new(client),
conn_err_rx: Arc::new(Mutex::new(conn_err_rx)),
})
}

pub(super) async fn close(&self) -> Result<()> {
let PgTestClient { conn_err_rx, .. } = self;
let mut conn_err_rx = conn_err_rx.lock().await;

if let Ok(result) = conn_err_rx.try_recv() {
// Handle connection error
match result {
Ok(()) => Err(anyhow!("Client connection unexpectedly closed")),
Err(err) => Err(anyhow!("Client connection errored: {err}")),
}
} else {
Ok(())
}
}
}

#[async_trait]
impl AsyncDB for PgTestClient {
type Error = sqlexec::errors::ExecError;
type ColumnType = DefaultColumnType;
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
let mut output = Vec::new();
let mut num_columns = 0;

let rows = self
.simple_query(sql)
.await
.map_err(|e| ExecError::Internal(format!("cannot execute simple query: {e}")))?;
for row in rows {
match row {
SimpleQueryMessage::Row(row) => {
num_columns = row.len();
let mut row_output = Vec::with_capacity(row.len());
for i in 0..row.len() {
match row.get(i) {
Some(v) => {
if v.is_empty() {
row_output.push("(empty)".to_string());
} else {
row_output.push(v.to_string().trim().to_owned());
}
}
None => row_output.push("NULL".to_string()),
}
}
output.push(row_output);
}
SimpleQueryMessage::CommandComplete(_) => {}
_ => unreachable!(),
}
}
if output.is_empty() && num_columns == 0 {
Ok(DBOutput::StatementComplete(0))
} else {
Ok(DBOutput::Rows {
types: vec![DefaultColumnType::Text; num_columns],
rows: output,
})
}
}
}
Loading

0 comments on commit 8a943f5

Please sign in to comment.