Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore(mmu): add overflow handling #452

Merged
merged 12 commits into from
Nov 4, 2024
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "jolt"
version = "0.1.0"
authors = [
# author of original Spartan paper and code base
# author of original Spartan paper and code base
"Srinath Setty <srinath@microsoft.com>",
# authors who contributed to the Arkworks Spartan fork
"Zhenfei Zhang <zhenfei.zhang@hotmail.com>",
Expand Down Expand Up @@ -46,6 +46,10 @@ members = [
"examples/stdlib/guest",
"examples/muldiv",
"examples/muldiv/guest",
"examples/overflow",
"examples/overflow/guest",


]

[features]
Expand Down
8 changes: 8 additions & 0 deletions examples/overflow/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[package]
name = "overflow"
version = "0.1.0"
edition = "2021"

[dependencies]
jolt-sdk = { path = "../../jolt-sdk", features = ["host"] }
guest = { package = "overflow-guest", path = "./guest" }
10 changes: 10 additions & 0 deletions examples/overflow/guest/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "overflow-guest"
version = "0.1.0"
edition = "2021"

[features]
guest = []

[dependencies]
jolt = { package = "jolt-sdk", path = "../../../jolt-sdk" }
26 changes: 26 additions & 0 deletions examples/overflow/guest/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#![cfg_attr(feature = "guest", no_std)]

extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;

#[jolt::provable(stack_size = 1024)]
fn overflow_stack() -> u32 {
let arr = [1u32; 1024];
arr.iter().sum()
}

#[jolt::provable(stack_size = 8192)]
fn allocate_stack_with_increased_size() -> u32 {
overflow_stack()
}

#[jolt::provable(memory_size = 4096)]
fn overflow_heap() -> u32 {
let mut vectors = Vec::new();

loop {
let v = vec![1u32; 1024];
vectors.extend(v);
}
}
5 changes: 5 additions & 0 deletions examples/overflow/guest/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#![cfg_attr(feature = "guest", no_std)]
#![no_main]

#[allow(unused_imports)]
use overflow_guest::*;
40 changes: 40 additions & 0 deletions examples/overflow/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use std::any::Any;
use std::panic;

pub fn main() {
let (prove_overflow_stack, _) = guest::build_overflow_stack();

let res = panic::catch_unwind(|| {
// trying to allocate 1024 elems array and sum it up
// with stack_size=1024, should panic
let (_, _) = prove_overflow_stack();
});
handle_result(res);

// now lets try to overflow the heap, should also panic
let (prove_overflow_heap, _) = guest::build_overflow_heap();

let res = panic::catch_unwind(|| {
let (_, _) = prove_overflow_heap();
});
handle_result(res);

// valid case for stack allocation, calls overflow_stack() under the hood
// but with stack_size=8192
let (prove_allocate_stack_with_increased_size, verfiy_allocate_stack_with_increased_size) =
guest::build_allocate_stack_with_increased_size();

let (output, proof) = prove_allocate_stack_with_increased_size();
let is_valid = verfiy_allocate_stack_with_increased_size(proof);

println!("output: {}", output);
println!("valid: {}", is_valid);
}

fn handle_result(res: Result<(), Box<dyn Any + Send>>) {
if let Err(e) = &res {
if let Some(msg) = e.downcast_ref::<String>() {
println!("--> Panic occurred with message: {}\n", msg);
}
}
}
114 changes: 88 additions & 26 deletions tracer/src/emulator/mmu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,38 @@ impl Mmu {
}
}

/// Asserts the validity of an effective memory address.
/// Panics if the address is invalid.
///
/// # Arguments
/// * `effective_address` Effective memory address to validate
#[inline]
fn assert_effective_address(&self, effective_address: u64) {
if effective_address < DRAM_BASE {
// less then DRAM_BASE and greater then panic => zero_padding region
assert!(
effective_address <= self.jolt_device.memory_layout.termination,
"Stack overflow: Attempted to write to 0x{:X}",
effective_address
);
// less then panic => jolt_device region (i.e. input/output)
assert!(
self.jolt_device.is_output(effective_address)
|| self.jolt_device.is_panic(effective_address)
|| self.jolt_device.is_termination(effective_address),
"Unknown memory mapping: 0x{:X}",
effective_address
);
} else {
// greater then memory capacity
assert!(
self.memory.validate_address(effective_address),
"Heap overflow: Attempted to write to 0x{:X}",
effective_address
);
}
}

/// Fetches an instruction byte. This method takes virtual address
/// and translates into physical address inside.
///
Expand Down Expand Up @@ -540,24 +572,12 @@ impl Mmu {
}

fn trace_store(&mut self, effective_address: u64, value: u64) {
if effective_address < DRAM_BASE {
if self.jolt_device.is_output(effective_address)
|| self.jolt_device.is_panic(effective_address)
|| self.jolt_device.is_termination(effective_address)
{
self.tracer.push_memory(MemoryState::Write {
address: effective_address,
post_value: value,
});
} else {
panic!("Unknown memory mapping {:X}.", effective_address);
}
} else {
self.tracer.push_memory(MemoryState::Write {
address: effective_address,
post_value: value,
});
}
self.assert_effective_address(effective_address);

self.tracer.push_memory(MemoryState::Write {
address: effective_address,
post_value: value,
});
}

/// Loads two bytes from main memory or peripheral devices depending on
Expand Down Expand Up @@ -643,14 +663,8 @@ impl Mmu {
0x10000000..=0x100000ff => self.uart.store(effective_address, value),
0x10001000..=0x10001FFF => self.disk.store(effective_address, value),
_ => {
if self.jolt_device.is_output(effective_address)
|| self.jolt_device.is_panic(effective_address)
|| self.jolt_device.is_termination(effective_address)
{
self.jolt_device.store(effective_address, value);
} else {
panic!("Unknown memory mapping {:X}.", effective_address);
}
self.assert_effective_address(effective_address);
self.jolt_device.store(effective_address, value);
}
},
};
Expand Down Expand Up @@ -1119,3 +1133,51 @@ impl MemoryWrapper {
self.memory.validate_address(address - DRAM_BASE)
}
}

#[cfg(test)]
mod test_mmu {
use super::*;
use crate::emulator::terminal::DummyTerminal;
use std::rc::Rc;

const MEM_CAPACITY: u64 = 1024 * 1024;

fn setup_mmu(capacity: u64) -> Mmu {
let terminal = Box::new(DummyTerminal::new());
let tracer = Rc::new(Tracer::new());
let mut mmu = Mmu::new(Xlen::Bit64, terminal, tracer);

mmu.init_memory(capacity);

mmu
}

#[test]
#[should_panic(expected = "Heap overflow")]
fn test_heap_overflow() {
let mut mmu = setup_mmu(MEM_CAPACITY);

// Try to write beyond the allocated memory
let overflow_address = DRAM_BASE + MEM_CAPACITY + 1;
mmu.trace_store(overflow_address, 0xc50513);
}

#[test]
#[should_panic(expected = "Stack overflow")]
fn test_stack_overflow() {
let mut mmu = setup_mmu(MEM_CAPACITY);

// Try to write to an address below DRAM_BASE
let invalid_address = DRAM_BASE - 1;
mmu.trace_store(invalid_address, 0xc50513);
}

#[test]
#[should_panic(expected = "Unknown memory mapping")]
fn test_unknown_memory_mapping() {
let mut mmu = setup_mmu(MEM_CAPACITY);

let invalid_address = 1234;
mmu.trace_store(invalid_address, 0xc50513);
}
}
Loading