Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

spi: Add support for specifying mutating mode to PreparedStatement #1275

Merged
merged 2 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions pgrx-tests/src/tests/spi_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod tests {

use pgrx::prelude::*;
use pgrx::spi;
use pgrx::spi::Query;

#[pg_test(error = "syntax error at or near \"THIS\"")]
fn test_spi_failure() -> Result<(), spi::Error> {
Expand Down Expand Up @@ -419,6 +420,26 @@ mod tests {
})
}

#[pg_test(error = "CREATE TABLE is not allowed in a non-volatile function")]
fn test_execute_prepared_statement_in_readonly() -> Result<(), spi::Error> {
Spi::connect(|client| {
let stmt = client.prepare("CREATE TABLE a ()", None)?;
// This is supposed to run in read-only
stmt.execute(&client, Some(1), None)?;
Ok(())
})
}

#[pg_test]
fn test_execute_prepared_statement_in_readwrite() -> Result<(), spi::Error> {
Spi::connect(|client| {
let stmt = client.prepare_mut("CREATE TABLE a ()", None)?;
// This is supposed to run in read-write
stmt.execute(&client, Some(1), None)?;
Ok(())
})
}

#[pg_test]
fn test_spi_select_sees_update() -> spi::Result<()> {
let with_select = Spi::connect(|mut client| {
Expand Down
23 changes: 23 additions & 0 deletions pgrx/src/spi/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,28 @@ impl<'conn> SpiClient<'conn> {
&self,
query: &str,
args: Option<Vec<PgOid>>,
) -> SpiResult<PreparedStatement<'conn>> {
self.make_prepare_statement(query, args, false)
}

/// Prepares a mutating statement that is valid for the lifetime of the client
///
/// # Panics
///
/// This function will panic if the supplied `query` string contained a NULL byte
pub fn prepare_mut(
&self,
query: &str,
args: Option<Vec<PgOid>>,
) -> SpiResult<PreparedStatement<'conn>> {
self.make_prepare_statement(query, args, true)
}

fn make_prepare_statement(
&self,
query: &str,
args: Option<Vec<PgOid>>,
mutating: bool,
) -> SpiResult<PreparedStatement<'conn>> {
let src = CString::new(query).expect("query contained a null byte");
let args = args.unwrap_or_default();
Expand All @@ -43,6 +65,7 @@ impl<'conn> SpiClient<'conn> {
.unwrap()
})?,
__marker: PhantomData,
mutating,
})
}

Expand Down
11 changes: 8 additions & 3 deletions pgrx/src/spi/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ impl<'conn> Query<'conn> for &str {
pub struct PreparedStatement<'conn> {
pub(super) plan: NonNull<pg_sys::_SPI_plan>,
pub(super) __marker: PhantomData<&'conn ()>,
pub(super) mutating: bool,
}

/// Static lifetime-bound prepared statement
Expand Down Expand Up @@ -214,7 +215,11 @@ impl<'conn> PreparedStatement<'conn> {
unsafe {
pg_sys::SPI_keepplan(self.plan.as_ptr());
}
OwnedPreparedStatement(PreparedStatement { __marker: PhantomData, plan: self.plan })
OwnedPreparedStatement(PreparedStatement {
__marker: PhantomData,
plan: self.plan,
mutating: self.mutating,
})
}
}

Expand Down Expand Up @@ -248,7 +253,7 @@ impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
self.plan.as_ptr(),
datums.as_mut_ptr(),
nulls.as_mut_ptr(),
Spi::is_xact_still_immutable(),
!self.mutating && Spi::is_xact_still_immutable(),
limit.unwrap_or(0),
)
};
Expand All @@ -269,7 +274,7 @@ impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
self.plan.as_ptr(),
datums.as_mut_ptr(),
nulls.as_ptr(),
Spi::is_xact_still_immutable(),
!self.mutating && Spi::is_xact_still_immutable(),
))
};
SpiCursor { ptr, __marker: PhantomData }
Expand Down
Loading