Skip to content

Commit

Permalink
Optimize code and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
XeniaLu committed Feb 13, 2024
1 parent 77b491a commit c9c8937
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
21 changes: 21 additions & 0 deletions pgrx-tests/src/tests/spi_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod tests {

use pgrx::prelude::*;
use pgrx::spi;
use pgrx::spi::Query;

#[pg_test(error = "syntax error at or near \"THIS\"")]
fn test_spi_failure() -> Result<(), spi::Error> {
Expand Down Expand Up @@ -419,6 +420,26 @@ mod tests {
})
}

#[pg_test(error = "CREATE TABLE is not allowed in a non-volatile function")]
fn test_execute_prepared_statement_in_readonly() -> Result<(), spi::Error> {
Spi::connect(|client| {
let stmt = client.prepare("CREATE TABLE a ()", None)?;
// This is supposed to run in read-only
stmt.execute(&client, Some(1), None)?;
Ok(())
})
}

#[pg_test]
fn test_execute_prepared_statement_in_readwrite() -> Result<(), spi::Error> {
Spi::connect(|client| {
let stmt = client.prepare_mut("CREATE TABLE a ()", None)?;
// This is supposed to run in read-write
stmt.execute(&client, Some(1), None)?;
Ok(())
})
}

#[pg_test]
fn test_spi_select_sees_update() -> spi::Result<()> {
let with_select = Spi::connect(|mut client| {
Expand Down
10 changes: 2 additions & 8 deletions pgrx/src/spi/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,6 @@ impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
limit: Option<libc::c_long>,
arguments: Self::Arguments,
) -> SpiResult<SpiTupleTable<'conn>> {
if self.mutating {
Spi::mark_mutable();
}
// SAFETY: no concurrent access
unsafe {
pg_sys::SPI_tuptable = std::ptr::null_mut();
Expand All @@ -256,7 +253,7 @@ impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
self.plan.as_ptr(),
datums.as_mut_ptr(),
nulls.as_mut_ptr(),
Spi::is_xact_still_immutable(),
!self.mutating && Spi::is_xact_still_immutable(),
limit.unwrap_or(0),
)
};
Expand All @@ -265,9 +262,6 @@ impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
}

fn open_cursor(self, _client: &SpiClient<'conn>, args: Self::Arguments) -> SpiCursor<'conn> {
if self.mutating {
Spi::mark_mutable();
}
let args = args.unwrap_or_default();

let (mut datums, nulls): (Vec<_>, Vec<_>) = args.into_iter().map(prepare_datum).unzip();
Expand All @@ -280,7 +274,7 @@ impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
self.plan.as_ptr(),
datums.as_mut_ptr(),
nulls.as_ptr(),
Spi::is_xact_still_immutable(),
!self.mutating && Spi::is_xact_still_immutable(),
))
};
SpiCursor { ptr, __marker: PhantomData }
Expand Down

0 comments on commit c9c8937

Please sign in to comment.