Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: break out helper methods for writing foreign call results #5181

Merged
merged 3 commits into from
Jun 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 62 additions & 46 deletions acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@
self.set_program_counter(*location)
}
Opcode::Const { destination, value, bit_size } => {
// Consts are not checked in runtime to fit in the bit size, since they can safely be checked statically.

Check warning on line 339 in acvm-repo/brillig_vm/src/lib.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Consts)
self.memory.write(*destination, MemoryValue::new_from_field(*value, *bit_size));
self.increment_program_counter()
}
Expand Down Expand Up @@ -468,7 +468,7 @@
destination_value_types: &[HeapValueType],
foreign_call_index: usize,
) -> Result<(), String> {
let values = &self.foreign_call_results[foreign_call_index].values;
let values = std::mem::take(&mut self.foreign_call_results[foreign_call_index].values);

if destinations.len() != values.len() {
return Err(format!(
Expand All @@ -479,22 +479,13 @@
}

for ((destination, value_type), output) in
destinations.iter().zip(destination_value_types).zip(values)
destinations.iter().zip(destination_value_types).zip(&values)
{
match (destination, value_type) {
(ValueOrArray::MemoryAddress(value_index), HeapValueType::Simple(bit_size)) => {
match output {
ForeignCallParam::Single(value) => {
let memory_value = MemoryValue::new_checked(*value, *bit_size);
if let Some(memory_value) = memory_value {
self.memory.write(*value_index, memory_value);
} else {
return Err(format!(
"Foreign call result value {} does not fit in bit size {}",
value,
bit_size
));
}
self.write_value_to_memory(*value_index, value, *bit_size)?;
}
_ => return Err(format!(
"Function result size does not match brillig bytecode. Expected 1 result but got {output:?}")
Expand All @@ -506,28 +497,12 @@
HeapValueType::Array { value_types, size: type_size },
) if size == type_size => {
if HeapValueType::all_simple(value_types) {
let bit_sizes_iterator = value_types.iter().map(|typ| match typ {
HeapValueType::Simple(bit_size) => *bit_size,
_ => unreachable!("Expected simple value type"),
}).cycle();
match output {
match output {
ForeignCallParam::Array(values) => {
if values.len() != *size {
return Err("Foreign call result array doesn't match expected size".to_string());
}
// Convert the destination pointer to a usize
let destination = self.memory.read_ref(*pointer_index);
// Write to our destination memory
let memory_values: Option<Vec<_>> = values.iter().zip(bit_sizes_iterator).map(
|(value, bit_size)| MemoryValue::new_checked(*value, bit_size)).collect();
if let Some(memory_values) = memory_values {
self.memory.write_slice(destination, &memory_values);
} else {
return Err(format!(
"Foreign call result values {:?} do not match expected bit sizes",
values,
));
}
self.write_values_to_memory_slice(*pointer_index, values, value_types)?;
}
_ => {
return Err("Function result size does not match brillig bytecode size".to_string());
Expand All @@ -542,26 +517,12 @@
HeapValueType::Vector { value_types },
) => {
if HeapValueType::all_simple(value_types) {
let bit_sizes_iterator = value_types.iter().map(|typ| match typ {
HeapValueType::Simple(bit_size) => *bit_size,
_ => unreachable!("Expected simple value type"),
}).cycle();
match output {
ForeignCallParam::Array(values) => {
// Set our size in the size address
self.memory.write(*size_index, values.len().into());
// Convert the destination pointer to a usize
let destination = self.memory.read_ref(*pointer_index);
// Write to our destination memory
let memory_values: Option<Vec<_>> = values.iter().zip(bit_sizes_iterator).map(|(value, bit_size)| MemoryValue::new_checked(*value, bit_size)).collect();
if let Some(memory_values) = memory_values {
self.memory.write_slice(destination, &memory_values);
}else{
return Err(format!(
"Foreign call result values {:?} do not match expected bit sizes",
values,
));
}

self.write_values_to_memory_slice(*pointer_index, values, value_types)?;
}
_ => {
return Err("Function result size does not match brillig bytecode size".to_string());
Expand All @@ -577,6 +538,61 @@
}
}

let _ =
std::mem::replace(&mut self.foreign_call_results[foreign_call_index].values, values);

Ok(())
}

fn write_value_to_memory(
&mut self,
destination: MemoryAddress,
value: &F,
value_bit_size: u32,
) -> Result<(), String> {
let memory_value = MemoryValue::new_checked(*value, value_bit_size);

if let Some(memory_value) = memory_value {
self.memory.write(destination, memory_value);
} else {
return Err(format!(
"Foreign call result value {} does not fit in bit size {}",
value, value_bit_size
));
}
Ok(())
}

fn write_values_to_memory_slice(
&mut self,
pointer_index: MemoryAddress,
values: &[F],
value_types: &[HeapValueType],
) -> Result<(), String> {
let bit_sizes_iterator = value_types
.iter()
.map(|typ| match typ {
HeapValueType::Simple(bit_size) => *bit_size,
_ => unreachable!("Expected simple value type"),
})
.cycle();

// Convert the destination pointer to a usize
let destination = self.memory.read_ref(pointer_index);
// Write to our destination memory
let memory_values: Option<Vec<_>> = values
.iter()
.zip(bit_sizes_iterator)
.map(|(value, bit_size)| MemoryValue::new_checked(*value, bit_size))
.collect();
if let Some(memory_values) = memory_values {
self.memory.write_slice(destination, &memory_values);
} else {
return Err(format!(
"Foreign call result values {:?} do not match expected bit sizes",
values,
));
}
Ok(())
}

Expand Down Expand Up @@ -709,7 +725,7 @@
}

#[test]
fn jmpifnot_opcode() {

Check warning on line 728 in acvm-repo/brillig_vm/src/lib.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (jmpifnot)
let calldata: Vec<FieldElement> = vec![1u128.into(), 2u128.into()];

let calldata_copy = Opcode::CalldataCopy {
Expand Down Expand Up @@ -844,7 +860,7 @@
}

#[test]
fn cmov_opcode() {

Check warning on line 863 in acvm-repo/brillig_vm/src/lib.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (cmov)
let calldata: Vec<FieldElement> =
vec![(0u128).into(), (1u128).into(), (2u128).into(), (3u128).into()];

Expand Down
Loading