diff --git a/src/snapshot2.rs b/src/snapshot2.rs index 1fe8bb23..27bf471a 100644 --- a/src/snapshot2.rs +++ b/src/snapshot2.rs @@ -2,7 +2,7 @@ use crate::{ bits::roundup, elf::{LoadingAction, ProgramMetadata}, machine::SupportMachine, - memory::{Memory, FLAG_DIRTY}, + memory::{get_page_indices, Memory, FLAG_DIRTY}, Error, Register, RISCV_GENERAL_REGISTER_NUMBER, RISCV_PAGESIZE, }; use bytes::Bytes; @@ -125,6 +125,7 @@ impl> Snapshot2Context { length: u64, ) -> Result<(u64, u64), Error> { let (data, full_length) = self.data_source.load_data(id, offset, length)?; + self.untrack_pages(machine, addr, data.len() as u64)?; machine.memory_mut().store_bytes(addr, &data)?; self.track_pages(machine, addr, data.len() as u64, id, offset)?; Ok((data.len() as u64, full_length)) @@ -231,7 +232,7 @@ impl> Snapshot2Context { self.track_pages(machine, start, length, id, offset + action.source.start) } - /// This is only made public for advanced usages, but make sure to exercise more + /// The followings are only made public for advanced usages, but make sure to exercise more /// cautions when calling it! pub fn track_pages( &mut self, @@ -259,6 +260,20 @@ impl> Snapshot2Context { } Ok(()) } + + pub fn untrack_pages( + &mut self, + machine: &mut M, + start: u64, + length: u64, + ) -> Result<(), Error> { + let page_indices = get_page_indices(start, length); + for page in page_indices.0..=page_indices.1 { + machine.memory_mut().set_flag(page, FLAG_DIRTY)?; + self.pages.remove(&page); + } + Ok(()) + } } #[derive(Clone, Debug, Deserialize, Serialize)] diff --git a/tests/test_resume2.rs b/tests/test_resume2.rs index 5a9ee042..99cda4c2 100644 --- a/tests/test_resume2.rs +++ b/tests/test_resume2.rs @@ -309,7 +309,11 @@ fn load_program(name: &str) -> TestSource { file.read_to_end(&mut buffer).unwrap(); let program = buffer.into(); - let data = vec![7; 16 * 4096]; + let mut data = vec![0; 16 * 4096]; + for i in 0..data.len() { + data[i] = i as u8; + } + let mut m = HashMap::default(); m.insert(DATA_ID, data.into()); m.insert(PROGRAM_ID, program); @@ -622,3 +626,36 @@ pub fn test_sc_after_snapshot2() { assert!(result2.is_ok()); assert_eq!(result2.unwrap(), 0); } + +#[cfg(not(feature = "enable-chaos-mode-by-default"))] +#[test] +pub fn test_store_bytes_twice() { + let data_source = load_program("tests/programs/sc_after_snapshot"); + + let mut machine = MachineTy::Asm.build(data_source.clone(), VERSION2); + machine.set_max_cycles(u64::MAX); + machine.load_program(&vec!["main".into()]).unwrap(); + + match machine { + Machine::Asm(ref mut inner, ref ctx) => { + ctx.lock() + .unwrap() + .store_bytes(&mut inner.machine, 0, &DATA_ID, 2, 29186) + .unwrap(); + ctx.lock() + .unwrap() + .store_bytes(&mut inner.machine, 0, &DATA_ID, 0, 11008) + .unwrap(); + } + _ => unimplemented!(), + } + let a = machine.full_memory().unwrap()[4096 * 2]; + + let snapshot = machine.snapshot().unwrap(); + let mut machine2 = MachineTy::Asm.build(data_source.clone(), VERSION2); + machine2.resume(snapshot).unwrap(); + machine2.set_max_cycles(u64::MAX); + let b = machine2.full_memory().unwrap()[4096 * 2]; + + assert_eq!(a, b); +}