Skip to content

Commit

Permalink
feat: Brillig pointer codegen and execution (#5737)
Browse files Browse the repository at this point in the history
Resolves noir-lang/noir#3907

This PR builds upon
#5709.

These changes do not yet include a Brillig stdlib and removal of the
`Brillig` opcode itself. The generated stdlib Brillig (such as for
quotient) is not created in the same manner as other Brillig calls which
are generated during SSA. I have decided to leave this for another
follow-up where we can actually remove `Brillig`.

---------

Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
Co-authored-by: jfecher <jfecher11@gmail.com>
  • Loading branch information
3 people authored Apr 16, 2024
1 parent 67077a1 commit a7b9d20
Show file tree
Hide file tree
Showing 23 changed files with 615 additions and 88 deletions.
4 changes: 4 additions & 0 deletions noir/noir-repo/acvm-repo/acir/src/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ impl std::fmt::Display for Program {
writeln!(f, "func {}", func_index)?;
writeln!(f, "{}", function)?;
}
for (func_index, function) in self.unconstrained_functions.iter().enumerate() {
writeln!(f, "unconstrained func {}", func_index)?;
writeln!(f, "{:?}", function.bytecode)?;
}
Ok(())
}
}
Expand Down
54 changes: 44 additions & 10 deletions noir/noir-repo/acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashMap;

use acir::{
brillig::{ForeignCallParam, ForeignCallResult},
brillig::{ForeignCallParam, ForeignCallResult, Opcode as BrilligOpcode},
circuit::{
brillig::{Brillig, BrilligInputs, BrilligOutputs},
opcodes::BlockId,
Expand Down Expand Up @@ -46,9 +46,9 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
/// Assigns the zero value to all outputs of the given [`Brillig`] bytecode.
pub(super) fn zero_out_brillig_outputs(
initial_witness: &mut WitnessMap,
brillig: &Brillig,
outputs: &[BrilligOutputs],
) -> Result<(), OpcodeResolutionError> {
for output in &brillig.outputs {
for output in outputs {
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, FieldElement::zero(), initial_witness)?;
Expand All @@ -63,6 +63,7 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
Ok(())
}

// TODO: Delete this old method once `Brillig` is deleted
/// Constructs a solver for a Brillig block given the bytecode and initial
/// witness.
pub(crate) fn new(
Expand All @@ -72,13 +73,45 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
bb_solver: &'b B,
acir_index: usize,
) -> Result<Self, OpcodeResolutionError> {
let vm = Self::setup_brillig_vm(
initial_witness,
memory,
&brillig.inputs,
&brillig.bytecode,
bb_solver,
)?;
Ok(Self { vm, acir_index })
}

/// Constructs a solver for a Brillig block given the bytecode and initial
/// witness.
pub(crate) fn new_call(
initial_witness: &WitnessMap,
memory: &HashMap<BlockId, MemoryOpSolver>,
inputs: &'b [BrilligInputs],
brillig_bytecode: &'b [BrilligOpcode],
bb_solver: &'b B,
acir_index: usize,
) -> Result<Self, OpcodeResolutionError> {
let vm =
Self::setup_brillig_vm(initial_witness, memory, inputs, brillig_bytecode, bb_solver)?;
Ok(Self { vm, acir_index })
}

fn setup_brillig_vm(
initial_witness: &WitnessMap,
memory: &HashMap<BlockId, MemoryOpSolver>,
inputs: &[BrilligInputs],
brillig_bytecode: &'b [BrilligOpcode],
bb_solver: &'b B,
) -> Result<VM<'b, B>, OpcodeResolutionError> {
// Set input values
let mut calldata: Vec<FieldElement> = Vec::new();
// Each input represents an expression or array of expressions to evaluate.
// Iterate over each input and evaluate the expression(s) associated with it.
// Push the results into memory.
// If a certain expression is not solvable, we stall the ACVM and do not proceed with Brillig VM execution.
for input in &brillig.inputs {
for input in inputs {
match input {
BrilligInputs::Single(expr) => match get_value(expr, initial_witness) {
Ok(value) => calldata.push(value),
Expand Down Expand Up @@ -118,8 +151,8 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {

// Instantiate a Brillig VM given the solved calldata
// along with the Brillig bytecode.
let vm = VM::new(calldata, &brillig.bytecode, vec![], bb_solver);
Ok(Self { vm, acir_index })
let vm = VM::new(calldata, brillig_bytecode, vec![], bb_solver);
Ok(vm)
}

pub fn get_memory(&self) -> &[MemoryValue] {
Expand Down Expand Up @@ -204,13 +237,13 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
pub(crate) fn finalize(
self,
witness: &mut WitnessMap,
brillig: &Brillig,
outputs: &[BrilligOutputs],
) -> Result<(), OpcodeResolutionError> {
// Finish the Brillig execution by writing the outputs to the witness map
let vm_status = self.vm.get_status();
match vm_status {
VMStatus::Finished { return_data_offset, return_data_size } => {
self.write_brillig_outputs(witness, return_data_offset, return_data_size, brillig)?;
self.write_brillig_outputs(witness, return_data_offset, return_data_size, outputs)?;
Ok(())
}
_ => panic!("Brillig VM has not completed execution"),
Expand All @@ -222,12 +255,12 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
witness_map: &mut WitnessMap,
return_data_offset: usize,
return_data_size: usize,
brillig: &Brillig,
outputs: &[BrilligOutputs],
) -> Result<(), OpcodeResolutionError> {
// Write VM execution results into the witness map
let memory = self.vm.get_memory();
let mut current_ret_data_idx = return_data_offset;
for output in brillig.outputs.iter() {
for output in outputs.iter() {
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, memory[current_ret_data_idx].to_field(), witness_map)?;
Expand All @@ -242,6 +275,7 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
}
}
}

assert!(
current_ret_data_idx == return_data_offset + return_data_size,
"Brillig VM did not write the expected number of return values"
Expand Down
72 changes: 64 additions & 8 deletions noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::collections::HashMap;

use acir::{
brillig::ForeignCallResult,
circuit::{opcodes::BlockId, Opcode, OpcodeLocation},
circuit::{brillig::BrilligBytecode, opcodes::BlockId, Opcode, OpcodeLocation},
native_types::{Expression, Witness, WitnessMap},
BlackBoxFunc, FieldElement,
};
Expand Down Expand Up @@ -165,10 +165,18 @@ pub struct ACVM<'a, B: BlackBoxFunctionSolver> {
/// Represents the outputs of all ACIR calls during an ACVM process
/// List is appended onto by the caller upon reaching a [ACVMStatus::RequiresAcirCall]
acir_call_results: Vec<Vec<FieldElement>>,

// Each unconstrained function referenced in the program
unconstrained_functions: &'a [BrilligBytecode],
}

impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
pub fn new(backend: &'a B, opcodes: &'a [Opcode], initial_witness: WitnessMap) -> Self {
pub fn new(
backend: &'a B,
opcodes: &'a [Opcode],
initial_witness: WitnessMap,
unconstrained_functions: &'a [BrilligBytecode],
) -> Self {
let status = if opcodes.is_empty() { ACVMStatus::Solved } else { ACVMStatus::InProgress };
ACVM {
status,
Expand All @@ -181,6 +189,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
brillig_solver: None,
acir_call_counter: 0,
acir_call_results: Vec::default(),
unconstrained_functions,
}
}

Expand Down Expand Up @@ -324,9 +333,10 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call),
res => res.map(|_| ()),
},
Opcode::BrilligCall { .. } => {
todo!("implement brillig pointer handling");
}
Opcode::BrilligCall { .. } => match self.solve_brillig_call_opcode() {
Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call),
res => res.map(|_| ()),
},
Opcode::Call { .. } => match self.solve_call_opcode() {
Ok(Some(input_values)) => return self.wait_for_acir_call(input_values),
res => res.map(|_| ()),
Expand Down Expand Up @@ -381,7 +391,8 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {

let witness = &mut self.witness_map;
if is_predicate_false(witness, &brillig.predicate)? {
return BrilligSolver::<B>::zero_out_brillig_outputs(witness, brillig).map(|_| None);
return BrilligSolver::<B>::zero_out_brillig_outputs(witness, &brillig.outputs)
.map(|_| None);
}

// If we're resuming execution after resolving a foreign call then
Expand All @@ -407,7 +418,51 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
}
BrilligSolverStatus::Finished => {
// Write execution outputs
solver.finalize(witness, brillig)?;
solver.finalize(witness, &brillig.outputs)?;
Ok(None)
}
}
}

fn solve_brillig_call_opcode(
&mut self,
) -> Result<Option<ForeignCallWaitInfo>, OpcodeResolutionError> {
let Opcode::BrilligCall { id, inputs, outputs, predicate } =
&self.opcodes[self.instruction_pointer]
else {
unreachable!("Not executing a Brillig opcode");
};

let witness = &mut self.witness_map;
if is_predicate_false(witness, predicate)? {
return BrilligSolver::<B>::zero_out_brillig_outputs(witness, outputs).map(|_| None);
}

// If we're resuming execution after resolving a foreign call then
// there will be a cached `BrilligSolver` to avoid recomputation.
let mut solver: BrilligSolver<'_, B> = match self.brillig_solver.take() {
Some(solver) => solver,
None => BrilligSolver::new_call(
witness,
&self.block_solvers,
inputs,
&self.unconstrained_functions[*id as usize].bytecode,
self.backend,
self.instruction_pointer,
)?,
};
match solver.solve()? {
BrilligSolverStatus::ForeignCallWait(foreign_call) => {
// Cache the current state of the solver
self.brillig_solver = Some(solver);
Ok(Some(foreign_call))
}
BrilligSolverStatus::InProgress => {
unreachable!("Brillig solver still in progress")
}
BrilligSolverStatus::Finished => {
// Write execution outputs
solver.finalize(witness, outputs)?;
Ok(None)
}
}
Expand All @@ -425,7 +480,8 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
};

if should_skip {
let resolution = BrilligSolver::<B>::zero_out_brillig_outputs(witness, brillig);
let resolution =
BrilligSolver::<B>::zero_out_brillig_outputs(witness, &brillig.outputs);
return StepResult::Status(self.handle_opcode_resolution(resolution));
}

Expand Down
32 changes: 19 additions & 13 deletions noir/noir-repo/acvm-repo/acvm/tests/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ fn inversion_brillig_oracle_equivalence() {
(Witness(2), FieldElement::from(3u128)),
])
.into();

let mut acvm = ACVM::new(&StubbedBlackBoxSolver, &opcodes, witness_assignments);
let unconstrained_functions = vec![];
let mut acvm =
ACVM::new(&StubbedBlackBoxSolver, &opcodes, witness_assignments, &unconstrained_functions);
// use the partial witness generation solver with our acir program
let solver_status = acvm.solve();

Expand Down Expand Up @@ -241,8 +242,9 @@ fn double_inversion_brillig_oracle() {
(Witness(9), FieldElement::from(10u128)),
])
.into();

let mut acvm = ACVM::new(&StubbedBlackBoxSolver, &opcodes, witness_assignments);
let unconstrained_functions = vec![];
let mut acvm =
ACVM::new(&StubbedBlackBoxSolver, &opcodes, witness_assignments, &unconstrained_functions);

// use the partial witness generation solver with our acir program
let solver_status = acvm.solve();
Expand Down Expand Up @@ -370,8 +372,9 @@ fn oracle_dependent_execution() {

let witness_assignments =
BTreeMap::from([(w_x, FieldElement::from(2u128)), (w_y, FieldElement::from(2u128))]).into();

let mut acvm = ACVM::new(&StubbedBlackBoxSolver, &opcodes, witness_assignments);
let unconstrained_functions = vec![];
let mut acvm =
ACVM::new(&StubbedBlackBoxSolver, &opcodes, witness_assignments, &unconstrained_functions);

// use the partial witness generation solver with our acir program
let solver_status = acvm.solve();
Expand Down Expand Up @@ -474,8 +477,9 @@ fn brillig_oracle_predicate() {
(Witness(2), FieldElement::from(3u128)),
])
.into();

let mut acvm = ACVM::new(&StubbedBlackBoxSolver, &opcodes, witness_assignments);
let unconstrained_functions = vec![];
let mut acvm =
ACVM::new(&StubbedBlackBoxSolver, &opcodes, witness_assignments, &unconstrained_functions);
let solver_status = acvm.solve();
assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");

Expand Down Expand Up @@ -509,7 +513,8 @@ fn unsatisfied_opcode_resolved() {
values.insert(d, FieldElement::from(2_i128));

let opcodes = vec![Opcode::AssertZero(opcode_a)];
let mut acvm = ACVM::new(&StubbedBlackBoxSolver, &opcodes, values);
let unconstrained_functions = vec![];
let mut acvm = ACVM::new(&StubbedBlackBoxSolver, &opcodes, values, &unconstrained_functions);
let solver_status = acvm.solve();
assert_eq!(
solver_status,
Expand Down Expand Up @@ -591,8 +596,8 @@ fn unsatisfied_opcode_resolved_brillig() {
values.insert(w_result, FieldElement::from(0_i128));

let opcodes = vec![brillig_opcode, Opcode::AssertZero(opcode_a)];

let mut acvm = ACVM::new(&StubbedBlackBoxSolver, &opcodes, values);
let unconstrained_functions = vec![];
let mut acvm = ACVM::new(&StubbedBlackBoxSolver, &opcodes, values, &unconstrained_functions);
let solver_status = acvm.solve();
assert_eq!(
solver_status,
Expand Down Expand Up @@ -635,8 +640,9 @@ fn memory_operations() {
});

let opcodes = vec![init, read_op, expression];

let mut acvm = ACVM::new(&StubbedBlackBoxSolver, &opcodes, initial_witness);
let unconstrained_functions = vec![];
let mut acvm =
ACVM::new(&StubbedBlackBoxSolver, &opcodes, initial_witness, &unconstrained_functions);
let solver_status = acvm.solve();
assert_eq!(solver_status, ACVMStatus::Solved);
let witness_map = acvm.finalize();
Expand Down
Loading

0 comments on commit a7b9d20

Please sign in to comment.