Skip to content

Commit

Permalink
Merge pull request #1929 from Skgland/master
Browse files Browse the repository at this point in the history
detect and prevent concurrent AtomTable use
  • Loading branch information
mthom authored Jul 29, 2023
2 parents 31d17f1 + a701570 commit ca28c76
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 33 deletions.
36 changes: 29 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ rand = "0.8.5"
[dev-dependencies]
assert_cmd = "1.0.3"
predicates-core = "1.0.2"
serial_test = "0.5.1"
serial_test = "2.0.0"

[patch.crates-io]
modular-bitfield = { git = "https://github.com/mthom/modular-bitfield" }
Expand Down
81 changes: 58 additions & 23 deletions src/atom_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,42 +37,61 @@ impl From<bool> for Atom {
}
}

#[cfg(test)]
use std::cell::RefCell;

const ATOM_TABLE_INIT_SIZE: usize = 1 << 16;
const ATOM_TABLE_ALIGN: usize = 8;

#[cfg(test)]
thread_local! {
static ATOM_TABLE_BUF_BASE: RefCell<*const u8> = RefCell::new(ptr::null_mut());
static ATOM_TABLE_BUF_BASE: std::cell::RefCell<*const u8> = std::cell::RefCell::new(ptr::null_mut());
}

#[cfg(not(test))]
static mut ATOM_TABLE_BUF_BASE: *const u8 = ptr::null_mut();
static ATOM_TABLE_BUF_BASE: std::sync::atomic::AtomicPtr<u8> =
std::sync::atomic::AtomicPtr::new(ptr::null_mut());

fn set_atom_tbl_buf_base(old_ptr: *const u8, new_ptr: *const u8) -> Result<(), *const u8> {
#[cfg(test)]
fn set_atom_tbl_buf_base(ptr: *const u8) {
{
ATOM_TABLE_BUF_BASE.with(|atom_table_buf_base| {
*atom_table_buf_base.borrow_mut() = ptr;
});
let mut borrow = atom_table_buf_base.borrow_mut();
if *borrow != old_ptr {
Err(*borrow)
} else {
*borrow = new_ptr;
Ok(())
}
})?;
};
#[cfg(not(test))]
{
ATOM_TABLE_BUF_BASE
.compare_exchange(
old_ptr.cast_mut(),
new_ptr.cast_mut(),
std::sync::atomic::Ordering::Relaxed,
std::sync::atomic::Ordering::Relaxed,
)
.map_err(|ptr| ptr.cast_const())
}?;
Ok(())
}

#[cfg(test)]
pub(crate) fn get_atom_tbl_buf_base() -> *const u8 {
#[cfg(test)]
{
ATOM_TABLE_BUF_BASE.with(|atom_table_buf_base| *atom_table_buf_base.borrow())
}

#[cfg(not(test))]
fn set_atom_tbl_buf_base(ptr: *const u8) {
unsafe {
ATOM_TABLE_BUF_BASE = ptr;
{
ATOM_TABLE_BUF_BASE.load(std::sync::atomic::Ordering::Relaxed)
}
}

#[cfg(not(test))]
pub(crate) fn get_atom_tbl_buf_base() -> *const u8 {
unsafe { ATOM_TABLE_BUF_BASE }
#[test]
#[should_panic(expected = "Overwriting atom table base pointer")]
fn atomtable_is_not_concurrency_safe() {
let _table_a = AtomTable::new();
let _table_b = AtomTable::new();
}

impl RawBlockTraits for AtomTable {
Expand Down Expand Up @@ -239,22 +258,35 @@ pub struct AtomTable {
pub table: IndexSet<Atom>,
}

#[cold]
fn atom_table_base_pointer_mismatch(expected: *const u8, got: *const u8) -> ! {
assert_eq!(expected, got, "Overwriting atom table base pointer, expected old value to be {expected:p}, but found {got:p}");
unreachable!("This should only be called in a case of a mismatch as such the assert_eq should have failed!")
}

impl Drop for AtomTable {
fn drop(&mut self) {
if let Err(got) = set_atom_tbl_buf_base(self.block.base, ptr::null()) {
atom_table_base_pointer_mismatch(self.block.base, got);
}
self.block.deallocate();
}
}

impl AtomTable {
#[inline]
pub fn new() -> Self {
let table = Self {
block: RawBlock::new(),
table: IndexSet::new(),
};
let mut block = RawBlock::new();

set_atom_tbl_buf_base(table.block.base);
table
if let Err(got) = set_atom_tbl_buf_base(ptr::null(), block.base) {
block.deallocate();
atom_table_base_pointer_mismatch(ptr::null(), got);
}

Self {
block,
table: IndexSet::new(),
}
}

#[inline]
Expand Down Expand Up @@ -289,8 +321,11 @@ impl AtomTable {
ptr = self.block.alloc(size);

if ptr.is_null() {
let old_base = self.block.base;
self.block.grow();
set_atom_tbl_buf_base(self.block.base);
if let Err(got) = set_atom_tbl_buf_base(old_base, self.block.base) {
atom_table_base_pointer_mismatch(old_base, got);
}
} else {
break;
}
Expand Down
20 changes: 18 additions & 2 deletions tests/scryer/issues.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::helper::{load_module_test, run_top_level_test_no_args, run_top_level_test_with_args};
use scryer_prolog::machine::Machine;
use serial_test::serial;

// issue #857
Expand Down Expand Up @@ -128,10 +129,12 @@ fn compound_goal() {
// issue #815
#[test]
fn no_stutter() {
run_top_level_test_no_args("write(a), write(b), false.\n\
run_top_level_test_no_args(
"write(a), write(b), false.\n\
halt.\n\
",
"ab false.\n")
"ab false.\n",
)
}

/*
Expand Down Expand Up @@ -168,3 +171,16 @@ fn call_0() {
" error(existence_error(procedure,call/0),call/0).\n",
);
}

// issue #1206
#[serial]
#[test]
#[should_panic(expected = "Overwriting atom table base pointer")]
fn atomtable_is_not_concurrency_safe() {
// this is basically the same test as scryer_prolog::atom_table::atomtable_is_not_concurrency_safe
// but for this integration test scryer_prolog is compiled with cfg!(not(test)) while for the unit test it is compiled with cfg!(test)
// as the atom table implementation differ between cfg!(test) and cfg!(not(test)) both test serve a pourpose
// Note: this integration test itself is compiled with cfg!(test) independent of scryer_prolog itself
let _machine_a = Machine::with_test_streams();
let _machine_b = Machine::with_test_streams();
}

0 comments on commit ca28c76

Please sign in to comment.