diff --git a/lucet-runtime/lucet-runtime-internals/src/hostcall_macros.rs b/lucet-runtime/lucet-runtime-internals/src/hostcall_macros.rs index 1ff975fc4..8445f6a7d 100644 --- a/lucet-runtime/lucet-runtime-internals/src/hostcall_macros.rs +++ b/lucet-runtime/lucet-runtime-internals/src/hostcall_macros.rs @@ -38,6 +38,8 @@ macro_rules! lucet_hostcalls { vmctx_raw: *mut $crate::vmctx::lucet_vmctx, $( $arg: $arg_ty ),* ) -> $ret_ty { + use $crate::vmctx::VmctxInternal; + #[inline(always)] unsafe fn hostcall_impl( $vmctx: &mut $crate::vmctx::Vmctx, @@ -45,12 +47,29 @@ macro_rules! lucet_hostcalls { ) -> $ret_ty { $($body)* } + + let mut vmctx = $crate::vmctx::Vmctx::from_raw(vmctx_raw); + + // increment the nesting level before calling the implementation... + vmctx.increment_hostcall_nesting(); let res = std::panic::catch_unwind(move || { hostcall_impl(&mut $crate::vmctx::Vmctx::from_raw(vmctx_raw), $( $arg ),*) }); + // and decrement it afterwards, whether or not there was a panic + vmctx.decrement_hostcall_nesting(); + + // get this as a stack variable so that vmctx doesn't leak if we terminate below + let in_nested_hostcall = vmctx.in_nested_hostcall(); + drop(vmctx); + match res { Ok(res) => res, Err(e) => { + // only terminate once we've unwound through all hostcall segments of the + // guest stack + if in_nested_hostcall { + std::panic::resume_unwind(e); + } if let Some(details) = e.downcast_ref::<$crate::instance::TerminationDetails>() { let mut vmctx = $crate::vmctx::Vmctx::from_raw(vmctx_raw); vmctx.terminate_no_unwind(details.clone()); diff --git a/lucet-runtime/lucet-runtime-internals/src/instance.rs b/lucet-runtime/lucet-runtime-internals/src/instance.rs index add2e1d0d..6e35987e0 100644 --- a/lucet-runtime/lucet-runtime-internals/src/instance.rs +++ b/lucet-runtime/lucet-runtime-internals/src/instance.rs @@ -216,6 +216,16 @@ pub struct Instance { /// Pointer to the function used as the entrypoint (for use in backtraces) entrypoint: Option, + /// The number of nested hostcalls currently present on the guest stack. + /// + /// Primarily used when implementing instance termination, this represents the number of times + /// unwinding must continue in order to unwind through all hostcall segments of the guest stack. + /// + /// For example, if the guest calls `host_fn1()`, which in turn calls back into `guest_fn1()`, + /// which calls `host_fn2()`, the value of this field must be `2` while `host_fn2()` is + /// executing. This number is automatically managed by the `lucet_hostcalls!` macro. + pub(crate) hostcall_nesting: usize, + /// `_padding` must be the last member of the structure. /// This marks where the padding starts to make the structure exactly 4096 bytes long. /// It is also used to compute the size of the structure up to that point, i.e. without padding. @@ -517,6 +527,7 @@ impl Instance { c_fatal_handler: None, signal_handler: Box::new(signal_handler_none) as Box, entrypoint: None, + hostcall_nesting: 0, _padding: (), }; inst.set_globals_ptr(globals_ptr); diff --git a/lucet-runtime/lucet-runtime-internals/src/vmctx.rs b/lucet-runtime/lucet-runtime-internals/src/vmctx.rs index 5c7d5bab4..e3f871929 100644 --- a/lucet-runtime/lucet-runtime-internals/src/vmctx.rs +++ b/lucet-runtime/lucet-runtime-internals/src/vmctx.rs @@ -58,6 +58,20 @@ pub trait VmctxInternal { /// you could not use orthogonal `&mut` refs that come from `Vmctx`, like the heap or /// terminating the instance. unsafe fn instance_mut(&self) -> &mut Instance; + + /// Increment the hostcall nesting level. + /// + /// This must be done whenever entering a hostcall implementation. + fn increment_hostcall_nesting(&self); + + /// Decrement the hostcall nesting level. + /// + /// This must be done whenever a hostcall implementation returns, but before any unwinding logic + /// is evaluated. + fn decrement_hostcall_nesting(&self); + + /// Returns `true` if there are hostcall stack segments present on the guest stack. + fn in_nested_hostcall(&self) -> bool; } impl VmctxInternal for Vmctx { @@ -68,6 +82,27 @@ impl VmctxInternal for Vmctx { unsafe fn instance_mut(&self) -> &mut Instance { instance_from_vmctx(self.vmctx) } + + fn increment_hostcall_nesting(&self) { + let inst = unsafe { self.instance_mut() }; + inst.hostcall_nesting = inst + .hostcall_nesting + .checked_add(1) + .expect("hostcall nesting level overflowed"); + } + + fn decrement_hostcall_nesting(&self) { + let inst = unsafe { self.instance_mut() }; + debug_assert!(inst.hostcall_nesting > 0); + inst.hostcall_nesting = inst + .hostcall_nesting + .checked_sub(1) + .expect("hostcall nesting level underflowed"); + } + + fn in_nested_hostcall(&self) -> bool { + self.instance().hostcall_nesting > 0 + } } impl Vmctx { diff --git a/lucet-runtime/lucet-runtime-tests/guests/host/bindings.json b/lucet-runtime/lucet-runtime-tests/guests/host/bindings.json index 383bcf316..5ab5464ef 100644 --- a/lucet-runtime/lucet-runtime-tests/guests/host/bindings.json +++ b/lucet-runtime/lucet-runtime-tests/guests/host/bindings.json @@ -2,6 +2,8 @@ "env": { "hostcall_test_func_hello": "hostcall_test_func_hello", "hostcall_test_func_hostcall_error": "hostcall_test_func_hostcall_error", - "hostcall_test_func_hostcall_error_unwind": "hostcall_test_func_hostcall_error_unwind" + "hostcall_test_func_hostcall_error_unwind": "hostcall_test_func_hostcall_error_unwind", + "hostcall_test_func_hostcall_nested_error_unwind1": "hostcall_test_func_hostcall_nested_error_unwind1", + "hostcall_test_func_hostcall_nested_error_unwind2": "hostcall_test_func_hostcall_nested_error_unwind2" } } diff --git a/lucet-runtime/lucet-runtime-tests/guests/host/hostcall_nested_error_unwind.c b/lucet-runtime/lucet-runtime-tests/guests/host/hostcall_nested_error_unwind.c new file mode 100644 index 000000000..b40fb1627 --- /dev/null +++ b/lucet-runtime/lucet-runtime-tests/guests/host/hostcall_nested_error_unwind.c @@ -0,0 +1,14 @@ +#include + +extern void hostcall_test_func_hostcall_nested_error_unwind1(void (*)(void)); +extern void hostcall_test_func_hostcall_nested_error_unwind2(void); + +void guest_func(void) { + hostcall_test_func_hostcall_nested_error_unwind2(); +} + +int main(void) +{ + hostcall_test_func_hostcall_nested_error_unwind1(guest_func); + return 0; +} diff --git a/lucet-runtime/lucet-runtime-tests/src/host.rs b/lucet-runtime/lucet-runtime-tests/src/host.rs index b84f48c32..c95ab5341 100644 --- a/lucet-runtime/lucet-runtime-tests/src/host.rs +++ b/lucet-runtime/lucet-runtime-tests/src/host.rs @@ -27,6 +27,8 @@ macro_rules! host_tests { lazy_static! { static ref HOSTCALL_MUTEX: Mutex<()> = Mutex::new(()); + static ref HOSTCALL_MUTEX_1: Mutex<()> = Mutex::new(()); + static ref HOSTCALL_MUTEX_2: Mutex<()> = Mutex::new(()); } lucet_hostcalls! { @@ -66,6 +68,37 @@ macro_rules! host_tests { drop(lock); } + #[no_mangle] + pub unsafe extern "C" fn hostcall_test_func_hostcall_nested_error_unwind1( + &mut vmctx, + cb_idx: u32, + ) -> () { + let lock = HOSTCALL_MUTEX_1.lock().unwrap(); + + let func = vmctx + .get_func_from_idx(0, cb_idx) + .expect("can get function by index"); + let func = std::mem::transmute::< + usize, + extern "C" fn(*mut lucet_vmctx) + >(func.ptr.as_usize()); + (func)(vmctx.as_raw()); + + drop(lock); + } + + #[allow(unreachable_code)] + #[no_mangle] + pub unsafe extern "C" fn hostcall_test_func_hostcall_nested_error_unwind2( + &mut vmctx, + ) -> () { + let lock = HOSTCALL_MUTEX_2.lock().unwrap(); + unsafe { + lucet_hostcall_terminate!(ERROR_MESSAGE); + } + drop(lock); + } + #[no_mangle] pub unsafe extern "C" fn hostcall_bad_borrow( &mut vmctx, @@ -188,6 +221,33 @@ macro_rules! host_tests { assert!(HOSTCALL_MUTEX.is_poisoned()); } + #[test] + fn run_hostcall_nested_error_unwind() { + let module = + test_module_c("host", "hostcall_nested_error_unwind.c").expect("build and load module"); + let region = TestRegion::create(1, &Limits::default()).expect("region can be created"); + let mut inst = region + .new_instance(module) + .expect("instance can be created"); + + match inst.run("main", &[0u32.into(), 0u32.into()]) { + Err(Error::RuntimeTerminated(term)) => { + assert_eq!( + *term + .provided_details() + .expect("user provided termination reason") + .downcast_ref::<&'static str>() + .expect("error was static str"), + ERROR_MESSAGE + ); + } + res => panic!("unexpected result: {:?}", res), + } + + assert!(HOSTCALL_MUTEX_1.is_poisoned()); + assert!(HOSTCALL_MUTEX_2.is_poisoned()); + } + #[test] fn run_fpe() { let module = test_module_c("host", "fpe.c").expect("build and load module"); diff --git a/lucet-runtime/src/c_api.rs b/lucet-runtime/src/c_api.rs index 8d4f79652..7de9e222e 100644 --- a/lucet-runtime/src/c_api.rs +++ b/lucet-runtime/src/c_api.rs @@ -5,7 +5,6 @@ use lucet_runtime_internals::c_api::*; use lucet_runtime_internals::instance::{ instance_handle_from_raw, instance_handle_to_raw, InstanceInternal, }; -use lucet_runtime_internals::vmctx::VmctxInternal; use lucet_runtime_internals::WASM_PAGE_SIZE; use lucet_runtime_internals::{ assert_nonnull, lucet_hostcall_terminate, lucet_hostcalls, with_ffi_arcs,