Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix flight sql do put handling, add bind parameter support to FlightSQL cli client #4797

Merged
merged 7 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions arrow-flight/examples/flight_sql_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use futures::stream::Peekable;
use futures::{stream, Stream, TryStreamExt};
use once_cell::sync::Lazy;
use prost::Message;
Expand Down Expand Up @@ -602,15 +603,15 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_statement_update(
&self,
_ticket: CommandStatementUpdate,
_request: Request<Streaming<FlightData>>,
_request: Request<Peekable<Streaming<FlightData>>>,
) -> Result<i64, Status> {
Ok(FAKE_UPDATE_RESULT)
}

async fn do_put_substrait_plan(
&self,
_ticket: CommandStatementSubstraitPlan,
_request: Request<Streaming<FlightData>>,
_request: Request<Peekable<Streaming<FlightData>>>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_substrait_plan not implemented",
Expand All @@ -620,7 +621,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
_request: Request<Streaming<FlightData>>,
_request: Request<Peekable<Streaming<FlightData>>>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_query not implemented",
Expand All @@ -630,7 +631,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_prepared_statement_update(
&self,
_query: CommandPreparedStatementUpdate,
_request: Request<Streaming<FlightData>>,
_request: Request<Peekable<Streaming<FlightData>>>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_update not implemented",
Expand Down
104 changes: 92 additions & 12 deletions arrow-flight/src/bin/flight_sql_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use std::{sync::Arc, time::Duration};
use std::{error::Error, sync::Arc, time::Duration};

use arrow_array::RecordBatch;
use arrow_cast::pretty::pretty_format_batches;
use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray};
use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions};
use arrow_flight::{
sql::client::FlightSqlServiceClient, utils::flight_data_to_batches, FlightData,
FlightInfo,
};
use arrow_schema::{ArrowError, Schema};
use clap::Parser;
use clap::{Parser, Subcommand};
use futures::TryStreamExt;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
use tracing_log::log::info;
Expand Down Expand Up @@ -98,8 +99,20 @@ struct Args {
#[clap(flatten)]
client_args: ClientArgs,

/// SQL query.
query: String,
#[clap(subcommand)]
cmd: Command,
}

// Subcommands accepted by the FlightSQL CLI client; `main` matches on the
// selected variant to decide how the query is sent to the server.
//
// NOTE(review): plain `//` comments are used here on purpose. clap's derive
// macros normally turn `///` doc comments into help text, so doc comments on
// these items would change the CLI output — confirm before converting them.
#[derive(Debug, Subcommand)]
enum Command {
// Run `query` directly: `main` passes it to `client.execute`.
StatementQuery {
query: String,
},
// Prepare `query` first (`client.prepare` in `main`), bind `params` if any
// were given, then execute the prepared statement.
PreparedStatementQuery {
query: String,
// Repeated `KEY=value` pairs, each parsed by `parse_key_val`.
// `short` presumably exposes this as `-p` (derived from the field name) — verify.
#[clap(short, value_parser = parse_key_val)]
params: Vec<(String, String)>,
},
}

#[tokio::main]
Expand All @@ -108,12 +121,50 @@ async fn main() {
setup_logging();
let mut client = setup_client(args.client_args).await.expect("setup client");

let info = client
.execute(args.query, None)
let flight_info = match args.cmd {
Command::StatementQuery { query } => client
.execute(query, None)
.await
.expect("execute statement"),
Command::PreparedStatementQuery { query, params } => {
let mut prepared_stmt = client
.prepare(query, None)
.await
.expect("prepare statement");

if !params.is_empty() {
prepared_stmt
.set_parameters(
construct_record_batch_from_params(
&params,
prepared_stmt
.parameter_schema()
.expect("get parameter schema"),
)
.expect("construct parameters"),
)
.expect("bind parameters")
}

prepared_stmt
.execute()
.await
.expect("execute prepared statement")
}
};

let batches = execute_flight(&mut client, flight_info)
.await
.expect("prepare statement");
info!("got flight info");
.expect("read flight data");

let res = pretty_format_batches(batches.as_slice()).expect("format results");
println!("{res}");
}

async fn execute_flight(
client: &mut FlightSqlServiceClient<Channel>,
info: FlightInfo,
) -> Result<Vec<RecordBatch>, ArrowError> {
let schema = Arc::new(Schema::try_from(info.clone()).expect("valid schema"));
let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
batches.push(RecordBatch::new_empty(schema));
Expand All @@ -134,8 +185,27 @@ async fn main() {
}
info!("received data");

let res = pretty_format_batches(batches.as_slice()).expect("format results");
println!("{res}");
Ok(batches)
}

/// Builds the bind-parameter [`RecordBatch`] for a prepared statement from
/// textual `(name, value)` pairs.
///
/// Each pair becomes one column named `name`: the value is wrapped in a string
/// scalar and cast (with default [`CastOptions`]) to the data type that
/// `parameter_schema` declares for the field of the same name.
///
/// # Errors
///
/// Returns the first error encountered, in parameter order: either the field
/// lookup in `parameter_schema` failed, or the cast to the field's data type
/// failed. Any error from assembling the final batch is returned as well.
fn construct_record_batch_from_params(
    params: &[(String, String)],
    parameter_schema: &Schema,
) -> Result<RecordBatch, ArrowError> {
    let columns = params
        .iter()
        .map(|(name, value)| {
            // Look up the declared type first so an unknown name fails before any casting.
            let target_type = parameter_schema.field_with_name(name)?.data_type();
            let scalar = StringArray::new_scalar(value);
            let column =
                cast_with_options(scalar.get().0, target_type, &CastOptions::default())?;
            Ok::<_, ArrowError>((name, column))
        })
        .collect::<Result<Vec<(&String, ArrayRef)>, ArrowError>>()?;

    RecordBatch::try_from_iter(columns)
}

fn setup_logging() {
Expand Down Expand Up @@ -203,3 +273,13 @@ async fn setup_client(

Ok(client)
}

/// Splits a `KEY=value` command-line argument into an owned `(key, value)` pair.
///
/// The split happens at the first `=`; any later `=` stays in the value.
/// Either side may be empty.
///
/// # Errors
///
/// Returns an error message if `s` contains no `=` at all.
fn parse_key_val(
    s: &str,
) -> Result<(String, String), Box<dyn Error + Send + Sync + 'static>> {
    match s.split_once('=') {
        Some((key, value)) => Ok((key.to_owned(), value.to_owned())),
        None => Err(format!("invalid KEY=value: no `=` found in `{s}`").into()),
    }
}
54 changes: 51 additions & 3 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use std::collections::HashMap;
use std::str::FromStr;
use tonic::metadata::AsciiMetadataKey;

use crate::encode::FlightDataEncoderBuilder;
use crate::error::FlightError;
use crate::flight_service_client::FlightServiceClient;
use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT};
use crate::sql::{
Expand All @@ -32,8 +34,8 @@ use crate::sql::{
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo,
CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery, CommandStatementQuery, CommandStatementUpdate,
DoPutUpdateResult, ProstMessageExt, SqlInfo,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo,
};
use crate::{
Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest,
Expand Down Expand Up @@ -439,9 +441,12 @@ impl PreparedStatement<Channel> {

/// Executes the prepared statement query on the server.
pub async fn execute(&mut self) -> Result<FlightInfo, ArrowError> {
self.write_bind_params().await?;

let cmd = CommandPreparedStatementQuery {
prepared_statement_handle: self.handle.clone(),
};

let result = self
.flight_sql_client
.get_flight_info_for_command(cmd)
Expand All @@ -451,7 +456,9 @@ impl PreparedStatement<Channel> {

/// Executes the prepared statement update query on the server.
pub async fn execute_update(&mut self) -> Result<i64, ArrowError> {
let cmd = CommandPreparedStatementQuery {
self.write_bind_params().await?;

let cmd = CommandPreparedStatementUpdate {
Comment on lines -454 to +461
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also forgot to mention, I think this was a bug in the existing implementation. ExecuteUpdate should be performed with a CommandPreparedStatementUpdate command, not a CommandPreparedStatementQuery.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prepared_statement_handle: self.handle.clone(),
};
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
Expand Down Expand Up @@ -492,6 +499,36 @@ impl PreparedStatement<Channel> {
Ok(())
}

/// Submit parameters to the server, if any have been set on this prepared statement instance
async fn write_bind_params(&mut self) -> Result<(), ArrowError> {
if let Some(ref params_batch) = self.parameter_binding {
let cmd = CommandPreparedStatementQuery {
prepared_statement_handle: self.handle.clone(),
};

let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let flight_stream_builder = FlightDataEncoderBuilder::new()
.with_flight_descriptor(Some(descriptor))
.with_schema(params_batch.schema());
let flight_data = flight_stream_builder
.build(futures::stream::iter(
self.parameter_binding.clone().map(Ok),
))
.try_collect::<Vec<_>>()
.await
.map_err(flight_error_to_arrow_error)?;

self.flight_sql_client
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This appears consistent with the FlightSQL specification: it uses do_put to bind the parameter arguments. What isn't clear to me is whether the result should be used in some way.

This would seem to imply some sort of server-side state which I had perhaps expected FlightSQL to not rely on

Copy link
Contributor Author

@suremarc suremarc Sep 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think we are in agreement about it implying server-side state. FWIW FlightSQL also supports transactions which I think (maybe wrongly) would also require state. There was also some discussion happening about adding new RPC's for managing session state at some point (like a close RPC or something)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a fundamental flaw in FlightSQL tbh, gRPC is not a connection-oriented protocol and so the lifetime of any server state is non-deterministic... I believe @alamb plans to start a discussion to see if we can't fix this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I filed apache/arrow#37720 and will circulate this around

.do_put(stream::iter(flight_data))
.await?
.try_collect::<Vec<_>>()
.await
.map_err(status_to_arrow_error)?;
}

Ok(())
}

/// Close the prepared statement, so that this PreparedStatement can not be used
/// anymore and server can free up any resources.
pub async fn close(mut self) -> Result<(), ArrowError> {
Expand All @@ -515,6 +552,17 @@ fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
ArrowError::IpcError(format!("{status:?}"))
}

fn flight_error_to_arrow_error(err: FlightError) -> ArrowError {
match err {
FlightError::Arrow(e) => e,
FlightError::NotYetImplemented(s) => ArrowError::NotYetImplemented(s),
FlightError::Tonic(status) => status_to_arrow_error(status),
FlightError::ProtocolError(e) => ArrowError::IpcError(e),
FlightError::DecodeError(s) => ArrowError::IpcError(s),
FlightError::ExternalError(e) => ArrowError::ExternalError(e),
}
suremarc marked this conversation as resolved.
Show resolved Hide resolved
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure why we return ArrowError from the Flight client (instead of FlightError), but I am trying to keep this PR scoped, so I just decided to stay consistent

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this should probably be FlightError


// A polymorphic structure to natively represent different types of data contained in `FlightData`
pub enum ArrowFlightData {
RecordBatch(RecordBatch),
Expand Down
24 changes: 16 additions & 8 deletions arrow-flight/src/sql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use std::pin::Pin;

use futures::Stream;
use futures::{stream::Peekable, Stream};
use prost::Message;
use tonic::{Request, Response, Status, Streaming};

Expand Down Expand Up @@ -366,7 +366,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
/// Implementors may override to handle additional calls to do_put()
async fn do_put_fallback(
&self,
_request: Request<Streaming<FlightData>>,
_request: Request<Peekable<Streaming<FlightData>>>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option might be to pass the first ticket request as a separate argument. I don't feel strongly either way

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a tough decision for me. I prefer using Peekable as it can be used as if it were the original stream, but I hate the fact that we have to leak its usage. We could pass the first FlightData as a separate argument, but it would require the user to chain it with the stream, if they wanted to use any APIs expecting a stream of FlightData. So I think I would stick with Peekable in the absence of any preference from others.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This server is the scaffolding to help people build flightsql servers -- they can always use the raw FlightService if they prefer. Thus I think the change in API is less critical and given the requirements it seems inevitable.

I think the only thing we should try to improve in this PR is the documentation, explaining why Peekable is used (to lower the cognitive burden on people trying to use this).

One potential option to document this: rather than using Peekable<Streaming<...>> directly, we could make our own wrapper, PeekableStreaming or something. While this would require duplicating a bunch of the Peekable API, there would be a natural place to document what it is for and how to use it, which I think would lower the barrier to usage.

For example:

/// A wrapper around `Streaming` that allows inspection of the first message. 
/// This is needed because sometimes the first request in the stream will contain 
/// a [`FlightDescriptor`] in addition to potentially any data and the dispatch logic
/// must inspect this information. 
///
/// # example:
/// <show an example here of calling `into_inner()` to get the original data back
struct PeekableStreaming {
  inner: Peekable<Streaming<FlightData>>
}

impl PeekableStreaming {
  /// return the inner stream
  pub fn into_inner(self) -> Streaming<FlightData> { self.inner.into_inner() }
...
}

We could also potentially use something like BoxStream<FlightData>, but that would lose the gRPC-specific stuff like status codes and trailers exposed by Streaming, and it would be an API change as well.

Thus I think this design is the best of several less than ideal solutions. To proceed perhaps we can add some documentation on the do_*_fallback methods that mentions the stream comes from peekable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a new type, PeekableFlightDataStream, which exposes into_inner and peek, similarly to Peekable. I think this is a good enough subset of functionality for FlightSQL use cases, and if users need access to more of the lower-level functionality of Peekable, they can call PeekableFlightDataStream::into_peekable.

message: Any,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
Err(Status::unimplemented(format!(
Expand All @@ -379,7 +379,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
async fn do_put_statement_update(
&self,
_ticket: CommandStatementUpdate,
_request: Request<Streaming<FlightData>>,
_request: Request<Peekable<Streaming<FlightData>>>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_statement_update has no default implementation",
Expand All @@ -390,7 +390,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
_request: Request<Streaming<FlightData>>,
_request: Request<Peekable<Streaming<FlightData>>>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_query has no default implementation",
Expand All @@ -401,7 +401,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
async fn do_put_prepared_statement_update(
&self,
_query: CommandPreparedStatementUpdate,
_request: Request<Streaming<FlightData>>,
_request: Request<Peekable<Streaming<FlightData>>>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_update has no default implementation",
Expand All @@ -412,7 +412,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
async fn do_put_substrait_plan(
&self,
_query: CommandStatementSubstraitPlan,
_request: Request<Streaming<FlightData>>,
_request: Request<Peekable<Streaming<FlightData>>>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_substrait_plan has no default implementation",
Expand Down Expand Up @@ -688,9 +688,17 @@ where

async fn do_put(
&self,
mut request: Request<Streaming<FlightData>>,
request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoPutStream>, Status> {
let cmd = request.get_mut().message().await?.unwrap();
// See issue #4658: https://github.com/apache/arrow-rs/issues/4658
// To dispatch to the correct `do_put` method, we cannot discard the first message,
// as it may contain the Arrow schema, which the `do_put` handler may need.
// To allow the first message to be reused by the `do_put` handler,
// we wrap this stream in a `Peekable` one, which allows us to peek at
// the first message without discarding it.
let mut request = request.map(futures::StreamExt::peekable);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if I am following correctly, the issue is do_put accepts a FlightData stream, but the first request in the stream will contain a FlightDescriptor in addition to potentially any data. I continue to be utterly baffled by the design of Flight 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is my understanding as well. Prior to this change, decoding a flight stream inside one of the do_put methods would give you an error like Received RecordBatch prior to schema

let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?;

let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd)
.map_err(decode_error_to_status)?;
match Command::try_from(message).map_err(arrow_error_to_status)? {
Expand Down
Loading
Loading