Add MemCx<'mcx> for List<'mcx, T> #1342

Merged
30 changes: 28 additions & 2 deletions docs/src/pg-internal/memory-context.md
@@ -9,9 +9,15 @@ will quickly be cleaned up, even in C extensions that don't have the power of Ru
memory management. However, this is incompatible with certain assumptions Rust makes about safety,
thus making it tricky to correctly bind this code.

<!-- TODO: finish out `MemCx` drafts and provide alternatives to worrying about allocations -->
Because the memory context's lifetime *also* defines the lifetime of every allocation made in it,
any code that returns something allocated in a memory context should be bound by lifetime
parameters that prevent its use after the context is gone. As usual, a qualifying covariant
lifetime is often acceptable, meaning any "shorter" lifetime that ends before the memory is
actually deallocated. The in-progress type that describes this is `MemCx<'mcx>`, but
understanding how to extend it requires understanding what C does and why copying those idioms
to Rust can cause problems.
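
As a rough illustration of binding an allocation to its context's lifetime, here is a minimal sketch in safe Rust. `MockContext`, `InCx`, and `sum_in_context` are hypothetical names invented for this example and are not part of pgrx; the mock stores plain integers instead of calling into Postgres.

```rust
use std::cell::RefCell;

/// Hypothetical stand-in for a memory context: it owns every value allocated in it.
struct MockContext {
    slots: RefCell<Vec<i32>>,
}

/// A handle to one allocation. The borrow `'mcx` keeps the context alive
/// for as long as the handle exists.
struct InCx<'mcx> {
    cx: &'mcx MockContext,
    index: usize,
}

impl MockContext {
    fn new() -> Self {
        MockContext { slots: RefCell::new(Vec::new()) }
    }

    /// "Allocate" a value; the returned handle cannot outlive `self`.
    fn alloc<'mcx>(&'mcx self, value: i32) -> InCx<'mcx> {
        let mut slots = self.slots.borrow_mut();
        slots.push(value);
        InCx { cx: self, index: slots.len() - 1 }
    }
}

impl InCx<'_> {
    fn get(&self) -> i32 {
        self.cx.slots.borrow()[self.index]
    }
}

/// Uses two allocations and returns a plain value.
/// Returning `a` or `b` instead would be rejected by the borrow checker,
/// because `cx` is dropped at the end of this function.
fn sum_in_context() -> i32 {
    let cx = MockContext::new();
    let a = cx.alloc(40);
    let b = cx.alloc(2);
    a.get() + b.get()
}
```

The handle may be given any lifetime shorter than the context's, which is the "shorter lifetime" case mentioned above.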

## What `palloc` calls to
## `palloc` and `CurrentMemoryContext`
In extension code, especially that written in C, you may notice calls to the following functions
for allocation and deallocation, instead of the usual `malloc` and `free`:

@@ -52,6 +58,26 @@ extern void *MemoryContextAllocAligned(MemoryContext context,
Notice that `pfree` only takes the pointer as an argument, effectively meaning every allocation
must know what context it belongs to in some way.
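
To make the consequence concrete, the following sketch models an allocator whose `free` takes only the chunk, so the owner has to be recoverable from the chunk itself. Everything here (`MockAllocator`, `ContextId`, `ChunkId`) is hypothetical, and the side table is an assumption chosen for brevity; it says nothing about how Postgres actually records ownership.

```rust
use std::collections::HashMap;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ContextId(u32);

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ChunkId(u32);

/// Hypothetical allocator that remembers which context owns each chunk.
#[derive(Default)]
struct MockAllocator {
    next_chunk: u32,
    owner_of: HashMap<ChunkId, ContextId>,
}

impl MockAllocator {
    /// Analogous in spirit to allocating in an explicit context.
    fn alloc_in(&mut self, cx: ContextId) -> ChunkId {
        let chunk = ChunkId(self.next_chunk);
        self.next_chunk += 1;
        self.owner_of.insert(chunk, cx);
        chunk
    }

    /// Like `pfree`, this takes only the chunk; the owning context is looked up.
    /// Returns `None` if the chunk is not live (for example, already freed).
    fn free(&mut self, chunk: ChunkId) -> Option<ContextId> {
        self.owner_of.remove(&chunk)
    }

    /// Number of live chunks owned by `cx`.
    fn live_in(&self, cx: ContextId) -> usize {
        self.owner_of.values().filter(|owner| **owner == cx).count()
    }
}
```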

This also means that `palloc` is slightly troublesome. It uses global mutable state to dynamically
determine the lifetime of the allocations it returns, which means that you may only be able to
assign them the lifetime of the current scope! This does not mean it is simply a menace, however:
a common C idiom is creating a new memory context, using it as the global `CurrentMemoryContext`,
then switching back to the previous context and destroying the transient context, so as to prevent
any memory leaks or other objects from escaping. This is incredibly useful, and it also means that
any concerns about the current context only lasting for the current scope are quite realistic!
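
The switch-and-restore idiom can be sketched in Rust with a drop guard. This is a mock under stated assumptions: `CURRENT`, `SwitchGuard`, `current`, and `with_transient_context` are hypothetical names, the "context" is just an integer id, and destroying the transient context is left out.

```rust
use std::cell::Cell;

thread_local! {
    /// Hypothetical stand-in for the global `CurrentMemoryContext`.
    static CURRENT: Cell<u32> = Cell::new(0);
}

/// Restores the previous "current context" when dropped.
struct SwitchGuard {
    previous: u32,
}

impl SwitchGuard {
    fn switch_to(id: u32) -> Self {
        let previous = CURRENT.with(|c| c.replace(id));
        SwitchGuard { previous }
    }
}

impl Drop for SwitchGuard {
    fn drop(&mut self) {
        CURRENT.with(|c| c.set(self.previous));
    }
}

fn current() -> u32 {
    CURRENT.with(|c| c.get())
}

/// Runs `f` with context `id` as the current one, then switches back,
/// even if `f` unwinds.
fn with_transient_context<R>(id: u32, f: impl FnOnce() -> R) -> R {
    let _guard = SwitchGuard::switch_to(id);
    f()
}
```

Anything that read `current()` inside the closure saw a context that stopped being current as soon as the closure returned, which is why a scope-limited lifetime is the honest description.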

## `MemCx<'mcx>`

Fortunately, the behavior of memory contexts is still deterministic and strictly scope-based,
rather than using a dynamic graph of references as a garbage-collected language might. So,
we have a strategy available to make headway on translating ambiguous dynamic lifetimes into a
world of static lifetime constraints:
- Internally prefer the `MemoryContextAlloc` family of functions
- Use the `MemCx<'mcx>` type to "infect" newly allocated types with the lifetime `'mcx`
- Use functions accepting closures like `current_context` to allow obtaining temporary access to
the desired memory context, while constraining allocated lifetimes to the closure's scope!
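
The closure-based strategy above can be sketched with a mock. This is not the pgrx implementation: `MockMemCx`, `Allocated`, `mock_current_context`, and `total_of_three` are hypothetical names, and the mock stores integers rather than allocating through Postgres. The point is the higher-ranked bound on the closure, which prevents anything tied to the context from appearing in the return type.

```rust
use std::cell::RefCell;
use std::marker::PhantomData;

/// Hypothetical stand-in for `MemCx<'mcx>`.
struct MockMemCx<'mcx> {
    values: RefCell<Vec<i64>>,
    _marker: PhantomData<&'mcx ()>,
}

/// A value "infected" with the lifetime of the borrow of the context.
struct Allocated<'a> {
    values: &'a RefCell<Vec<i64>>,
    index: usize,
}

impl<'mcx> MockMemCx<'mcx> {
    fn alloc(&self, value: i64) -> Allocated<'_> {
        let index = {
            let mut values = self.values.borrow_mut();
            values.push(value);
            values.len() - 1
        };
        Allocated { values: &self.values, index }
    }
}

impl Allocated<'_> {
    fn get(&self) -> i64 {
        self.values.borrow()[self.index]
    }
}

/// Grants temporary access to a context. Because the closure must work for
/// every `'curr`, its return type `T` cannot mention that lifetime, so no
/// `Allocated` handle can leave the closure.
fn mock_current_context<F, T>(f: F) -> T
where
    F: for<'curr> FnOnce(&MockMemCx<'curr>) -> T,
{
    let cx = MockMemCx { values: RefCell::new(Vec::new()), _marker: PhantomData };
    f(&cx)
}

fn total_of_three() -> i64 {
    mock_current_context(|mcx| {
        let a = mcx.alloc(1);
        let b = mcx.alloc(2);
        let c = mcx.alloc(3);
        // Only a plain `i64` escapes; returning `a` would fail to compile.
        a.get() + b.get() + c.get()
    })
}
```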

### `CurrentMemoryContext` makes `impl Deref` hard

<!-- TODO: this segment. -->
51 changes: 26 additions & 25 deletions pgrx-tests/src/tests/list_tests.rs
@@ -9,48 +9,49 @@
mod tests {
use crate as pgrx_tests;
use pgrx::list::List;
use pgrx::memcx;
use pgrx::prelude::*;

#[pg_test]
fn list_length_10() {
let mut list = List::Nil;
// Make sure the list length grows correctly:
for i in 0..10 {
unsafe {
list.unstable_push_in_context(i, pg_sys::CurrentMemoryContext);
memcx::current_context(|mcx| {
let mut list = List::Nil;
// Make sure the list length grows correctly:
for i in 0..10 {
list.unstable_push_in_context(i, mcx);
assert_eq!(i as usize + 1, list.len());
}
}
})
}

#[pg_test]
fn list_length_1000() {
let mut list = List::Nil;
// Make sure the list length grows correctly:
for i in 0..1000 {
unsafe {
list.unstable_push_in_context(i, pg_sys::CurrentMemoryContext);
memcx::current_context(|mcx| {
let mut list = List::Nil;
// Make sure the list length grows correctly:
for i in 0..1000 {
list.unstable_push_in_context(i, mcx);
assert_eq!(i as usize + 1, list.len());
}
}
})
}

#[pg_test]
fn list_length_drained() {
let mut list = List::Nil;
for i in 0..100 {
unsafe {
list.unstable_push_in_context(i, pg_sys::CurrentMemoryContext);
memcx::current_context(|mcx| {
let mut list = List::Nil;
for i in 0..100 {
list.unstable_push_in_context(i, mcx);
}
}

// Want to make sure the list length updates properly in the three major drain cases:
// from start of list, from inside the middle of the list, and from middle to tail.
let _ = list.drain(0..10);
assert_eq!(90, list.len());
let _ = list.drain(10..30);
assert_eq!(70, list.len());
let _ = list.drain(50..);
assert_eq!(50, list.len());
// Want to make sure the list length updates properly in the three major drain cases:
// from start of list, from inside the middle of the list, and from middle to tail.
let _ = list.drain(0..10);
assert_eq!(90, list.len());
let _ = list.drain(10..30);
assert_eq!(70, list.len());
let _ = list.drain(50..);
assert_eq!(50, list.len());
})
}
}
138 changes: 77 additions & 61 deletions pgrx/src/fn_call.rs
@@ -13,6 +13,7 @@ use pgrx_pg_sys::ffi::pg_guard_ffi_boundary;
use pgrx_pg_sys::PgTryBuilder;
use std::panic::AssertUnwindSafe;

use crate::memcx;
use crate::pg_catalog::pg_proc::{PgProc, ProArgMode, ProKind};
use crate::seal::Sealed;
use crate::{
@@ -293,7 +294,7 @@ pub fn fn_call_with_collation<R: FromDatum + IntoDatum>(

// setup the argument array
// SAFETY: `fcinfo_ref.args` is the over-allocated space we palloc0'd above. it's an array
// of `nargs` `NulalbleDatum` instances.
// of `nargs` `NullableDatum` instances.
let args_slice = fcinfo_ref.args.as_mut_slice(nargs);
for (i, datum) in arg_datums.into_iter().enumerate() {
assert!(!isstrict || (isstrict && datum.is_some())); // no NULL datums if this function is STRICT
@@ -332,69 +333,79 @@ fn lookup_fn(fname: &str, args: &[&dyn FnCallArg]) -> Result<pg_sys::Oid> {
// function following the normal SEARCH_PATH rules, ensuring its argument type Oids
// exactly match the ones from the user's input arguments. It does not evaluate the
// return type, so we'll have to do that later
let mut parts_list = List::<*mut std::ffi::c_void>::default();
let result = PgTryBuilder::new(AssertUnwindSafe(|| unsafe {
let arg_types = args.iter().map(|a| a.type_oid()).collect::<Vec<_>>();
let nargs: i16 = arg_types.len().try_into().map_err(|_| FnCallError::TooManyArguments)?;

// parse the function name into its possibly-qualified name parts
let ident_parts = parse_sql_ident(fname)?;
ident_parts
.iter_deny_null()
.map(|part| {
// SAFETY: `.as_pg_cstr()` palloc's a char* and `makeString` just takes ownership of it
pg_sys::makeString(part.as_pg_cstr())
})
.for_each(|part| {
parts_list.unstable_push_in_context(part.cast(), pg_sys::CurrentMemoryContext);
});

// look up an exact match based on the exact number of arguments we have
//
// SAFETY: we've allocated a PgList with the proper String node elements representing its name
// and we've allocated Vec of argument type oids which can be represented as a pointer.
let mut fnoid =
pg_sys::LookupFuncName(parts_list.as_mut_ptr(), nargs as _, arg_types.as_ptr(), true);

if fnoid == pg_sys::InvalidOid {
// if that didn't find a function, maybe we've got some defaults in there, so do a lookup
// where Postgres will consider that
fnoid = pg_sys::LookupFuncName(
memcx::current_context(|mcx| {
let mut parts_list = List::<*mut std::ffi::c_void>::default();
let result = PgTryBuilder::new(AssertUnwindSafe(|| unsafe {
let arg_types = args.iter().map(|a| a.type_oid()).collect::<Vec<_>>();
let nargs: i16 =
arg_types.len().try_into().map_err(|_| FnCallError::TooManyArguments)?;

// parse the function name into its possibly-qualified name parts
let ident_parts = parse_sql_ident(fname)?;
ident_parts
.iter_deny_null()
.map(|part| {
// SAFETY: `.as_pg_cstr()` palloc's a char* and `makeString` just takes ownership of it
pg_sys::makeString(part.as_pg_cstr())
})
.for_each(|part| {
parts_list.unstable_push_in_context(part.cast(), mcx);
});

// look up an exact match based on the exact number of arguments we have
//
// SAFETY: we've allocated a PgList with the proper String node elements representing its name
// and we've allocated Vec of argument type oids which can be represented as a pointer.
let mut fnoid = pg_sys::LookupFuncName(
parts_list.as_mut_ptr(),
-1,
nargs as _,
arg_types.as_ptr(),
false, // we want the ERROR here -- could be UNDEFINED_FUNCTION or AMBIGUOUS_FUNCTION
true,
);
}

Ok(fnoid)
}))
.catch_when(PgSqlErrorCode::ERRCODE_INVALID_PARAMETER_VALUE, |_| {
Err(FnCallError::InvalidIdentifier(fname.to_string()))
})
.catch_when(PgSqlErrorCode::ERRCODE_AMBIGUOUS_FUNCTION, |_| Err(FnCallError::AmbiguousFunction))
.catch_when(PgSqlErrorCode::ERRCODE_UNDEFINED_FUNCTION, |_| Err(FnCallError::UndefinedFunction))
.execute();

unsafe {
// SAFETY: we palloc'd the `pg_sys::String` elements of `parts_list` above and so it's
// safe for us to free them now that they're no longer being used
parts_list.drain(..).for_each(|s| {
#[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14"))]
{
let s = s.cast::<pg_sys::Value>();
pg_sys::pfree((*s).val.str_.cast());
}

#[cfg(any(feature = "pg15", feature = "pg16"))]
{
let s = s.cast::<pg_sys::String>();
pg_sys::pfree((*s).sval.cast());
if fnoid == pg_sys::InvalidOid {
// if that didn't find a function, maybe we've got some defaults in there, so do a lookup
// where Postgres will consider that
fnoid = pg_sys::LookupFuncName(
parts_list.as_mut_ptr(),
-1,
arg_types.as_ptr(),
false, // we want the ERROR here -- could be UNDEFINED_FUNCTION or AMBIGUOUS_FUNCTION
);
}
});
}

result
Ok(fnoid)
}))
.catch_when(PgSqlErrorCode::ERRCODE_INVALID_PARAMETER_VALUE, |_| {
Err(FnCallError::InvalidIdentifier(fname.to_string()))
})
.catch_when(PgSqlErrorCode::ERRCODE_AMBIGUOUS_FUNCTION, |_| {
Err(FnCallError::AmbiguousFunction)
})
.catch_when(PgSqlErrorCode::ERRCODE_UNDEFINED_FUNCTION, |_| {
Err(FnCallError::UndefinedFunction)
})
.execute();

unsafe {
// SAFETY: we palloc'd the `pg_sys::String` elements of `parts_list` above and so it's
// safe for us to free them now that they're no longer being used
parts_list.drain(..).for_each(|s| {
#[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14"))]
{
let s = s.cast::<pg_sys::Value>();
pg_sys::pfree((*s).val.str_.cast());
}

#[cfg(any(feature = "pg15", feature = "pg16"))]
{
let s = s.cast::<pg_sys::String>();
pg_sys::pfree((*s).sval.cast());
}
});
}
result
})
}

/// Parses an arbitrary string as if it is a SQL identifier. If it's not, [`FnCallError::InvalidIdentifier`]
@@ -428,9 +439,14 @@ fn create_default_value(pg_proc: &PgProc, argnum: usize) -> Result<Option<pg_sys
}

let default_argnum = argnum - non_default_args_cnt;
let default_value_tree = pg_proc.proargdefaults().ok_or(FnCallError::NoDefaultArguments)?;
let node =
default_value_tree.get(default_argnum).ok_or(FnCallError::NotDefaultArgument(argnum))?;
let node = memcx::current_context(|mcx| {
let default_value_tree =
pg_proc.proargdefaults(mcx).ok_or(FnCallError::NoDefaultArguments)?;
default_value_tree
.get(default_argnum)
.ok_or(FnCallError::NotDefaultArgument(argnum))
.copied()
})?;

unsafe {
// SAFETY: `arg_root` is okay to be the null pointer here, which indicates we don't care
1 change: 1 addition & 0 deletions pgrx/src/lib.rs
@@ -62,6 +62,7 @@ pub mod itemptr;
pub mod iter;
pub mod list;
pub mod lwlock;
pub mod memcx;
pub mod memcxt;
pub mod misc;
#[cfg(feature = "cshim")]
37 changes: 23 additions & 14 deletions pgrx/src/list.rs
@@ -12,6 +12,7 @@
//! It functions similarly to a Rust [`Vec`], including iterator support, but provides separate
//! understandings of [`List`][crate::pg_sys::List]s of [`pg_sys::Oid`]s, Integers, and Pointers.

use crate::memcx::MemCx;
use crate::pg_sys;
use crate::seal::Sealed;
use core::marker::PhantomData;
@@ -31,15 +32,15 @@ pub use old_list::*;
/// The List type from Postgres, lifted into Rust
/// Note: you may want the ListHead type
#[derive(Debug)]
pub enum List<T> {
pub enum List<'cx, T> {
Nil,
Cons(ListHead<T>),
Cons(ListHead<'cx, T>),
}

#[derive(Debug)]
pub struct ListHead<T> {
pub struct ListHead<'cx, T> {
list: NonNull<pg_sys::List>,
_type: PhantomData<[T]>,
_type: PhantomData<&'cx [T]>,
}

/// A strongly-typed ListCell
@@ -101,13 +102,13 @@ pub unsafe trait Enlist: Sealed + Sized {

/// Note the absence of `impl Default for ListHead`:
/// it must initialize at least 1 element to be created at all
impl<T> Default for List<T> {
fn default() -> List<T> {
impl<'cx, T> Default for List<'cx, T> {
fn default() -> List<'cx, T> {
List::Nil
}
}

impl<T: Enlist> List<T> {
impl<T: Enlist> List<'_, T> {
/// Attempt to obtain a `List<T>` from a `*mut pg_sys::List`
///
/// This may be somewhat confusing:
@@ -124,19 +125,23 @@ impl<T: Enlist> List<T> {
///
/// If it returns as `Some` and the List is more than zero length, it also asserts
/// that the entire List's `elements: *mut ListCell` is validly initialized as `T`
/// in each ListCell and that the List is allocated from a Postgres memory context.
/// in each ListCell and that the List is allocated from a MemCx that lasts
/// at least as long as the current context.
///
/// **Note:** This memory context must last long enough for your purposes.
/// YOU are responsible for bounding its lifetime correctly.
pub unsafe fn downcast_ptr(ptr: *mut pg_sys::List) -> Option<List<T>> {
pub unsafe fn downcast_ptr_in_memcx<'cx>(
ptr: *mut pg_sys::List,
memcx: &'cx MemCx<'_>,
) -> Option<List<'cx, T>> {
match NonNull::new(ptr) {
None => Some(List::Nil),
Some(list) => ListHead::downcast_ptr(list).map(|head| List::Cons(head)),
Some(list) => ListHead::downcast_ptr_in_memcx(list, memcx).map(|head| List::Cons(head)),
}
}
}

impl<T> List<T> {
impl<'cx, T> List<'cx, T> {
#[inline]
pub fn len(&self) -> usize {
match self {
@@ -172,7 +177,7 @@ impl<T> List<T> {
}
}

impl<T: Enlist> ListHead<T> {
impl<T: Enlist> ListHead<'_, T> {
/// From a non-nullable pointer that points to a valid List, produce a ListHead of the correct type
///
/// # Safety
@@ -181,11 +186,15 @@ impl<T: Enlist> ListHead<T> {
///
/// If it returns as `Some`, it also asserts the entire List is, across its length,
/// validly initialized as `T` in each ListCell.
pub unsafe fn downcast_ptr(list: NonNull<pg_sys::List>) -> Option<ListHead<T>> {
pub unsafe fn downcast_ptr_in_memcx<'cx>(
list: NonNull<pg_sys::List>,
_memcx: &'cx MemCx<'_>,
) -> Option<ListHead<'cx, T>> {
(T::LIST_TAG == (*list.as_ptr()).type_).then_some(ListHead { list, _type: PhantomData })
}
}
impl<T> ListHead<T> {

impl<T> ListHead<'_, T> {
#[inline]
pub fn len(&self) -> usize {
unsafe { self.list.as_ref().length as usize }