diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 02714e67c0bf..a8f8d1606506 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -16,6 +16,7 @@ // under the License. use arrow_flight::sql::server::PeekableFlightDataStream; +use arrow_flight::sql::DoPutPreparedStatementResult; use base64::prelude::BASE64_STANDARD; use base64::Engine; use futures::{stream, Stream, TryStreamExt}; @@ -619,7 +620,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandPreparedStatementQuery, _request: Request, - ) -> Result::DoPutStream>, Status> { + ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_query not implemented", )) diff --git a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs index 2b2f4af7ac90..01ea9b61a8f7 100644 --- a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs +++ b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs @@ -808,6 +808,25 @@ pub struct DoPutUpdateResult { #[prost(int64, tag = "1")] pub record_count: i64, } +/// An *optional* response returned when `DoPut` is called with `CommandPreparedStatementQuery`. +/// +/// *Note on legacy behavior*: previous versions of the protocol did not return any result for +/// this command, and that behavior should still be supported by clients. In that case, the client +/// can continue as though the fields in this message were not provided or set to sensible default values. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DoPutPreparedStatementResult { + /// Represents a (potentially updated) opaque handle for the prepared statement on the server. + /// Because the handle could potentially be updated, any previous handles for this prepared + /// statement should be considered invalid, and all subsequent requests for this prepared + /// statement must use this new handle. + /// The updated handle allows implementing query parameters with stateless services. + /// + /// When an updated handle is not provided by the server, clients should contiue + /// using the previous handle provided by `ActionCreatePreparedStatementResonse`. + #[prost(bytes = "bytes", optional, tag = "1")] + pub prepared_statement_handle: ::core::option::Option<::prost::bytes::Bytes>, +} /// /// Request message for the "CancelQuery" action. /// diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index a014137f6fa9..44250fbe63e2 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -35,7 +35,8 @@ use crate::sql::{ CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, - CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo, + CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, + SqlInfo, }; use crate::trailers::extract_lazy_trailers; use crate::{ @@ -501,6 +502,7 @@ impl PreparedStatement { } /// Submit parameters to the server, if any have been set on this prepared statement instance + /// Updates our stored prepared statement handle with the handle given by the server response. async fn write_bind_params(&mut self) -> Result<(), ArrowError> { if let Some(ref params_batch) = self.parameter_binding { let cmd = CommandPreparedStatementQuery { @@ -519,17 +521,38 @@ impl PreparedStatement { .await .map_err(flight_error_to_arrow_error)?; - self.flight_sql_client + // Attempt to update the stored handle with any updated handle in the DoPut result. + // Older servers do not respond with a result for DoPut, so skip this step when + // the stream closes with no response. + if let Some(result) = self + .flight_sql_client .do_put(stream::iter(flight_data)) .await? - .try_collect::>() + .message() .await - .map_err(status_to_arrow_error)?; + .map_err(status_to_arrow_error)? + { + if let Some(handle) = self.unpack_prepared_statement_handle(&result)? { + self.handle = handle; + } + } } - Ok(()) } + /// Decodes the app_metadata stored in a [`PutResult`] as a + /// [`DoPutPreparedStatementResult`] and then returns + /// the inner prepared statement handle as [`Bytes`] + fn unpack_prepared_statement_handle( + &self, + put_result: &PutResult, + ) -> Result, ArrowError> { + let any = Any::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?; + Ok(any + .unpack::()? + .and_then(|result| result.prepared_statement_handle)) + } + /// Close the prepared statement, so that this PreparedStatement can not used /// anymore and server can free up any resources. pub async fn close(mut self) -> Result<(), ArrowError> { diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index 97645ae7840d..089ee4dd8c3e 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -75,6 +75,7 @@ pub use gen::CommandPreparedStatementUpdate; pub use gen::CommandStatementQuery; pub use gen::CommandStatementSubstraitPlan; pub use gen::CommandStatementUpdate; +pub use gen::DoPutPreparedStatementResult; pub use gen::DoPutUpdateResult; pub use gen::Nullable; pub use gen::Searchable; @@ -251,6 +252,7 @@ prost_message_ext!( CommandStatementSubstraitPlan, CommandStatementUpdate, DoPutUpdateResult, + DoPutPreparedStatementResult, TicketStatementQuery, ); diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 0431e58111a4..c18024cf068a 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -33,7 +33,8 @@ use super::{ CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, - DoPutUpdateResult, ProstMessageExt, SqlInfo, TicketStatementQuery, + DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo, + TicketStatementQuery, }; use crate::{ flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty, @@ -397,11 +398,15 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { } /// Bind parameters to given prepared statement. + /// + /// Returns an opaque handle that the client should pass + /// back to the server during subsequent requests with this + /// prepared statement. async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, _request: Request, - ) -> Result::DoPutStream>, Status> { + ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_query has no default implementation", )) @@ -709,7 +714,13 @@ where Ok(Response::new(Box::pin(output))) } Command::CommandPreparedStatementQuery(command) => { - self.do_put_prepared_statement_query(command, request).await + let result = self + .do_put_prepared_statement_query(command, request) + .await?; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) } Command::CommandStatementSubstraitPlan(command) => { let record_count = self.do_put_substrait_plan(command, request).await?; diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index cc270eeb6186..50a4ec0d8c66 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -32,17 +32,18 @@ use arrow_flight::{ CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan, - CommandStatementUpdate, ProstMessageExt, SqlInfo, TicketStatementQuery, + CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt, SqlInfo, + TicketStatementQuery, }, utils::batches_to_flight_data, Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, - HandshakeResponse, IpcMessage, PutResult, SchemaAsIpc, Ticket, + HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket, }; use arrow_ipc::writer::IpcWriteOptions; use arrow_schema::{ArrowError, DataType, Field, Schema}; use assert_cmd::Command; use bytes::Bytes; -use futures::{Stream, StreamExt, TryStreamExt}; +use futures::{Stream, TryStreamExt}; use prost::Message; use tokio::{net::TcpListener, task::JoinHandle}; use tonic::{Request, Response, Status, Streaming}; @@ -51,7 +52,7 @@ const QUERY: &str = "SELECT * FROM table;"; #[tokio::test] async fn test_simple() { - let test_server = FlightSqlServiceImpl {}; + let test_server = FlightSqlServiceImpl::default(); let fixture = TestFixture::new(&test_server).await; let addr = fixture.addr; @@ -92,10 +93,9 @@ async fn test_simple() { const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1"; const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle"; +const UPDATED_PREPARED_STATEMENT_HANDLE: &str = "updated_prepared_statement_handle"; -#[tokio::test] -async fn test_do_put_prepared_statement() { - let test_server = FlightSqlServiceImpl {}; +async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) { let fixture = TestFixture::new(&test_server).await; let addr = fixture.addr; @@ -136,11 +136,40 @@ async fn test_do_put_prepared_statement() { ); } +#[tokio::test] +pub async fn test_do_put_prepared_statement_stateless() { + test_do_put_prepared_statement(FlightSqlServiceImpl { + stateless_prepared_statements: true, + }) + .await +} + +#[tokio::test] +pub async fn test_do_put_prepared_statement_stateful() { + test_do_put_prepared_statement(FlightSqlServiceImpl { + stateless_prepared_statements: false, + }) + .await +} + /// All tests must complete within this many seconds or else the test server is shutdown const DEFAULT_TIMEOUT_SECONDS: u64 = 30; -#[derive(Clone, Default)] -pub struct FlightSqlServiceImpl {} +#[derive(Clone)] +pub struct FlightSqlServiceImpl { + /// Whether to emulate stateless (true) or stateful (false) behavior for + /// prepared statements. stateful servers will not return an updated + /// handle after executing `DoPut(CommandPreparedStatementQuery)` + stateless_prepared_statements: bool, +} + +impl Default for FlightSqlServiceImpl { + fn default() -> Self { + Self { + stateless_prepared_statements: true, + } + } +} impl FlightSqlServiceImpl { /// Return an [`FlightServiceServer`] that can be used with a @@ -274,10 +303,17 @@ impl FlightSqlService for FlightSqlServiceImpl { cmd: CommandPreparedStatementQuery, _request: Request, ) -> Result, Status> { - assert_eq!( - cmd.prepared_statement_handle, - PREPARED_STATEMENT_HANDLE.as_bytes() - ); + if self.stateless_prepared_statements { + assert_eq!( + cmd.prepared_statement_handle, + UPDATED_PREPARED_STATEMENT_HANDLE.as_bytes() + ); + } else { + assert_eq!( + cmd.prepared_statement_handle, + PREPARED_STATEMENT_HANDLE.as_bytes() + ); + } let resp = Response::new(self.fake_flight_info().unwrap()); Ok(resp) } @@ -524,7 +560,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandPreparedStatementQuery, request: Request, - ) -> Result::DoPutStream>, Status> { + ) -> Result { // just make sure decoding the parameters works let parameters = FlightRecordBatchStream::new_from_flight_data( request.into_inner().map_err(|e| e.into()), @@ -543,10 +579,15 @@ impl FlightSqlService for FlightSqlServiceImpl { ))); } } - - Ok(Response::new( - futures::stream::once(async { Ok(PutResult::default()) }).boxed(), - )) + let handle = if self.stateless_prepared_statements { + UPDATED_PREPARED_STATEMENT_HANDLE.to_string().into() + } else { + PREPARED_STATEMENT_HANDLE.to_string().into() + }; + let result = DoPutPreparedStatementResult { + prepared_statement_handle: Some(handle), + }; + Ok(result) } async fn do_put_prepared_statement_update( diff --git a/format/FlightSql.proto b/format/FlightSql.proto index f78e77e23278..4fc68f2a5db0 100644 --- a/format/FlightSql.proto +++ b/format/FlightSql.proto @@ -1796,7 +1796,27 @@ // an unknown updated record count. int64 record_count = 1; } - + + /* An *optional* response returned when `DoPut` is called with `CommandPreparedStatementQuery`. + * + * *Note on legacy behavior*: previous versions of the protocol did not return any result for + * this command, and that behavior should still be supported by clients. In that case, the client + * can continue as though the fields in this message were not provided or set to sensible default values. + */ + message DoPutPreparedStatementResult { + option (experimental) = true; + + // Represents a (potentially updated) opaque handle for the prepared statement on the server. + // Because the handle could potentially be updated, any previous handles for this prepared + // statement should be considered invalid, and all subsequent requests for this prepared + // statement must use this new handle. + // The updated handle allows implementing query parameters with stateless services. + // + // When an updated handle is not provided by the server, clients should contiue + // using the previous handle provided by `ActionCreatePreparedStatementResonse`. + optional bytes prepared_statement_handle = 1; + } + /* * Request message for the "CancelQuery" action. *