diff --git a/sqlx-postgres/src/copy.rs b/sqlx-postgres/src/copy.rs index c6daea7774..6adb0fe3cd 100644 --- a/sqlx-postgres/src/copy.rs +++ b/sqlx-postgres/src/copy.rs @@ -342,16 +342,21 @@ async fn pg_begin_copy_out<'c, C: DerefMut + Send + 'c>( let stream: TryAsyncStream<'c, Bytes> = try_stream! { loop { - let msg = conn.stream.recv().await?; - match msg.format { - MessageFormat::CopyData => r#yield!(msg.decode::>()?.0), - MessageFormat::CopyDone => { - let _ = msg.decode::()?; - conn.stream.recv_expect(MessageFormat::CommandComplete).await?; + match conn.stream.recv().await { + Err(e) => { conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; - return Ok(()) + return Err(e); }, - _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format)) + Ok(msg) => match msg.format { + MessageFormat::CopyData => r#yield!(msg.decode::>()?.0), + MessageFormat::CopyDone => { + let _ = msg.decode::()?; + conn.stream.recv_expect(MessageFormat::CommandComplete).await?; + conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; + return Ok(()) + }, + _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format)) + } } } }; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 91728bde8d..07f5e53188 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1,12 +1,15 @@ -use futures::{StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; + use sqlx::postgres::types::Oid; use sqlx::postgres::{ PgAdvisoryLock, PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgListener, PgPoolOptions, PgRow, PgSeverity, Postgres, }; use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo}; +use sqlx_core::bytes::Bytes; use sqlx_test::{new, pool, setup_if_needed}; use std::env; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -382,6 +385,67 @@ async fn it_can_query_all_scalar() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn copy_can_work_with_failed_transactions() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // We're using a (local) statement_timeout to simulate a runtime failure, as opposed to + // a parse/plan failure. + let mut tx = conn.begin().await?; + let _ = sqlx::query("SELECT pg_catalog.set_config($1, $2, true)") + .bind("statement_timeout") + .bind("1ms") + .execute(tx.as_mut()) + .await?; + + let mut copy_out: Pin< + Box> + Send>, + > = (&mut tx) + .copy_out_raw("COPY (SELECT nspname FROM pg_catalog.pg_namespace WHERE pg_sleep(0.001) IS NULL) TO STDOUT") + .await?; + + while copy_out.try_next().await.is_ok() {} + drop(copy_out); + + tx.rollback().await?; + + // conn should be usable again, as we explictly rolled back the transaction + let got: i32 = sqlx::query_scalar("SELECT 1") + .fetch_one(conn.as_mut()) + .await?; + assert_eq!(1, got); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_work_with_failed_transactions() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // We're using a (local) statement_timeout to simulate a runtime failure, as opposed to + // a parse/plan failure. + let mut tx = conn.begin().await?; + let _ = sqlx::query("SELECT pg_catalog.set_config($1, $2, true)") + .bind("statement_timeout") + .bind("1ms") + .execute(tx.as_mut()) + .await?; + + assert!(sqlx::query("SELECT 1 WHERE pg_sleep(0.30) IS NULL") + .fetch_one(tx.as_mut()) + .await + .is_err()); + tx.rollback().await?; + + // conn should be usable again, as we explictly rolled back the transaction + let got: i32 = sqlx::query_scalar("SELECT 1") + .fetch_one(conn.as_mut()) + .await?; + assert_eq!(1, got); + + Ok(()) +} + #[sqlx_macros::test] async fn it_can_work_with_transactions() -> anyhow::Result<()> { let mut conn = new::().await?;