Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: move slt clients into own module #2824

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions crates/glaredb/src/args/slt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@ use std::time::Duration;
use anyhow::{anyhow, Result};
use clap::Args;
use pgsrv::auth::SingleUserAuthenticator;
use slt::test::{
ClientProtocol,
FlightSqlTestClient,
PgTestClient,
RpcTestClient,
Test,
TestClient,
TestHooks,
};
use slt::clients::flightsql::FlightSqlTestClient;
use slt::clients::postgres::PgTestClient;
use slt::clients::rpc::RpcTestClient;
use slt::clients::{ClientProtocol, TestClient};
use slt::test::{Test, TestHooks};
use tokio::net::TcpListener;
use tokio::runtime::Builder;
use tokio::sync::mpsc;
Expand Down
99 changes: 99 additions & 0 deletions crates/slt/src/clients/flightsql.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use anyhow::Result;
use arrow_flight::sql::client::FlightSqlServiceClient;
use futures::StreamExt;
use pgrepr::format::Format;
use pgrepr::scalar::Scalar;
use pgrepr::types::arrow_to_pg_type;
use rpcsrv::flight::handler::FLIGHTSQL_DATABASE_HEADER;
use sqlexec::errors::ExecError;
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType};
use tokio_postgres::types::private::BytesMut;
use tokio_postgres::Config;
use tonic::async_trait;
use tonic::transport::{Channel, Endpoint};
use uuid::Uuid;

#[derive(Clone)]
pub struct FlightSqlTestClient {
pub client: FlightSqlServiceClient<Channel>,
}

impl FlightSqlTestClient {
pub async fn new(config: &Config) -> Result<Self> {
let port = config.get_ports().first().unwrap();
let addr = format!("http://0.0.0.0:{port}");
let conn = Endpoint::new(addr)?.connect().await?;
let dbid: Uuid = config.get_dbname().unwrap().parse().unwrap();

let mut client = FlightSqlServiceClient::new(conn);
client.set_header(FLIGHTSQL_DATABASE_HEADER, dbid.to_string());
Ok(FlightSqlTestClient { client })
}
}

#[async_trait]
impl AsyncDB for FlightSqlTestClient {
type Error = sqlexec::errors::ExecError;
type ColumnType = DefaultColumnType;
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
let mut output = Vec::new();
let mut num_columns = 0;

let mut client = self.client.clone();
let ticket = client.execute(sql.to_string(), None).await?;
let ticket = ticket
.endpoint
.first()
.ok_or_else(|| ExecError::String("The server should support this".to_string()))?
.clone();
let ticket = ticket.ticket.unwrap();
let mut stream = client.do_get(ticket).await?;

// all the remaining stream messages should be dictionary and record batches
while let Some(batch) = stream.next().await {
let batch = batch.map_err(|e| {
Self::Error::String(format!("error getting batch from flight: {e}"))
})?;

if num_columns == 0 {
num_columns = batch.num_columns();
}

for row_idx in 0..batch.num_rows() {
let mut row_output = Vec::with_capacity(num_columns);

for col in batch.columns() {
let pg_type = arrow_to_pg_type(col.data_type(), None);
let scalar = Scalar::try_from_array(col, row_idx, &pg_type)?;

if scalar.is_null() {
row_output.push("NULL".to_string());
} else {
let mut buf = BytesMut::new();
scalar.encode_with_format(Format::Text, &mut buf)?;

if buf.is_empty() {
row_output.push("(empty)".to_string())
} else {
let scalar = String::from_utf8(buf.to_vec()).map_err(|e| {
ExecError::Internal(format!(
"invalid text formatted result from pg encoder: {e}"
))
})?;
row_output.push(scalar.trim().to_owned());
}
}
}
output.push(row_output);
}
}
if output.is_empty() && num_columns == 0 {
Ok(DBOutput::StatementComplete(0))
} else {
Ok(DBOutput::Rows {
types: vec![DefaultColumnType::Text; num_columns],
rows: output,
})
}
}
}
85 changes: 85 additions & 0 deletions crates/slt/src/clients/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use std::time::Duration;

use anyhow::Result;
use clap::builder::PossibleValue;
use clap::ValueEnum;
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType};
use tonic::async_trait;

use self::flightsql::FlightSqlTestClient;
use self::postgres::PgTestClient;
use self::rpc::RpcTestClient;

pub mod flightsql;
pub mod postgres;
pub mod rpc;

#[derive(Clone)]
pub enum TestClient {
Pg(PgTestClient),
Rpc(RpcTestClient),
FlightSql(FlightSqlTestClient),
}

impl TestClient {
pub async fn close(self) -> Result<()> {
match self {
Self::Pg(pg_client) => pg_client.close().await,
Self::Rpc(_) => Ok(()),
Self::FlightSql(_) => Ok(()),
}
}
}

#[async_trait]
impl AsyncDB for TestClient {
type Error = sqlexec::errors::ExecError;
type ColumnType = DefaultColumnType;

async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
match self {
Self::Pg(pg_client) => pg_client.run(sql).await,
Self::Rpc(rpc_client) => rpc_client.run(sql).await,
Self::FlightSql(flight_client) => flight_client.run(sql).await,
}
}

fn engine_name(&self) -> &str {
match self {
Self::Pg { .. } => "glaredb_pg",
Self::Rpc { .. } => "glaredb_rpc",
Self::FlightSql { .. } => "glaredb_flight",
}
}

async fn sleep(dur: Duration) {
tokio::time::sleep(dur).await;
}
}

#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum ClientProtocol {
// Connect over a local postgres instance
#[default]
Postgres,
// Connect over a local RPC instance
Rpc,
// Connect over a local FlightSql instance
FlightSql,
}

impl ValueEnum for ClientProtocol {
fn value_variants<'a>() -> &'a [Self] {
&[Self::Postgres, Self::Rpc, Self::FlightSql]
}

fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
Some(match self {
ClientProtocol::Postgres => PossibleValue::new("postgres"),
ClientProtocol::Rpc => PossibleValue::new("rpc"),
ClientProtocol::FlightSql => PossibleValue::new("flightsql")
.alias("flight-sql")
.alias("flight"),
})
}
}
95 changes: 95 additions & 0 deletions crates/slt/src/clients/postgres.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use std::ops::Deref;
use std::sync::Arc;

use anyhow::{anyhow, Result};
use sqlexec::errors::ExecError;
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType};
use tokio::sync::{oneshot, Mutex};
use tokio_postgres::{Client, Config, NoTls, SimpleQueryMessage};
use tonic::async_trait;

#[derive(Clone)]
pub struct PgTestClient {
client: Arc<Client>,
conn_err_rx: Arc<Mutex<oneshot::Receiver<Result<(), tokio_postgres::Error>>>>,
}

impl Deref for PgTestClient {
type Target = Client;
fn deref(&self) -> &Self::Target {
&self.client
}
}

impl PgTestClient {
pub async fn new(client_config: &Config) -> Result<Self> {
let (client, conn) = client_config.connect(NoTls).await?;
let (conn_err_tx, conn_err_rx) = oneshot::channel();
tokio::spawn(async move { conn_err_tx.send(conn.await) });
Ok(Self {
client: Arc::new(client),
conn_err_rx: Arc::new(Mutex::new(conn_err_rx)),
})
}

pub(super) async fn close(&self) -> Result<()> {
let PgTestClient { conn_err_rx, .. } = self;
let mut conn_err_rx = conn_err_rx.lock().await;

if let Ok(result) = conn_err_rx.try_recv() {
// Handle connection error
match result {
Ok(()) => Err(anyhow!("Client connection unexpectedly closed")),
Err(err) => Err(anyhow!("Client connection errored: {err}")),
}
} else {
Ok(())
}
}
}

#[async_trait]
impl AsyncDB for PgTestClient {
type Error = sqlexec::errors::ExecError;
type ColumnType = DefaultColumnType;
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
let mut output = Vec::new();
let mut num_columns = 0;

let rows = self
.simple_query(sql)
.await
.map_err(|e| ExecError::Internal(format!("cannot execute simple query: {e}")))?;
for row in rows {
match row {
SimpleQueryMessage::Row(row) => {
num_columns = row.len();
let mut row_output = Vec::with_capacity(row.len());
for i in 0..row.len() {
match row.get(i) {
Some(v) => {
if v.is_empty() {
row_output.push("(empty)".to_string());
} else {
row_output.push(v.to_string().trim().to_owned());
}
}
None => row_output.push("NULL".to_string()),
}
}
output.push(row_output);
}
SimpleQueryMessage::CommandComplete(_) => {}
_ => unreachable!(),
}
}
if output.is_empty() && num_columns == 0 {
Ok(DBOutput::StatementComplete(0))
} else {
Ok(DBOutput::Rows {
types: vec![DefaultColumnType::Text; num_columns],
rows: output,
})
}
}
}
Loading