Skip to content

Commit

Permalink
Merge 9e1c171 into d6122eb
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench authored May 31, 2024
2 parents d6122eb + 9e1c171 commit f10b38e
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 68 deletions.
13 changes: 3 additions & 10 deletions compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ impl AcirContext {
}

/// Converts an [`AcirVar`] to a [`Witness`]
fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
pub(crate) fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
let expression = self.var_to_expression(var)?;
let witness = if let Some(constant) = expression.to_const() {
// Check if a witness has been assigned this value already, if so reuse it.
Expand Down Expand Up @@ -1027,15 +1027,6 @@ impl AcirContext {
Ok(remainder)
}

/// Converts the `AcirVar` to a `Witness` if it hasn't been already, and appends it to the
/// `GeneratedAcir`'s return witnesses.
pub(crate) fn return_var(&mut self, acir_var: AcirVar) -> Result<(), InternalError> {
let return_var = self.get_or_create_witness_var(acir_var)?;
let witness = self.var_to_witness(return_var)?;
self.acir_ir.push_return_witness(witness);
Ok(())
}

/// Constrains the `AcirVar` variable to be of type `NumericType`.
pub(crate) fn range_constrain_var(
&mut self,
Expand Down Expand Up @@ -1538,9 +1529,11 @@ impl AcirContext {
pub(crate) fn finish(
mut self,
inputs: Vec<Witness>,
return_values: Vec<Witness>,
warnings: Vec<SsaReport>,
) -> GeneratedAcir {
self.acir_ir.input_witnesses = inputs;
self.acir_ir.return_witnesses = return_values;
self.acir_ir.warnings = warnings;
self.acir_ir
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ pub(crate) struct GeneratedAcir {
opcodes: Vec<AcirOpcode<FieldElement>>,

/// All witness indices that comprise the final return value of the program
///
/// Note: This may contain repeated indices, which is necessary for later mapping into the
/// abi's return type.
pub(crate) return_witnesses: Vec<Witness>,

/// All witness indices which are inputs to the main function
Expand Down Expand Up @@ -164,11 +161,6 @@ impl GeneratedAcir {

fresh_witness
}

/// Adds a witness index to the program's return witnesses.
pub(crate) fn push_return_witness(&mut self, witness: Witness) {
self.return_witnesses.push(witness);
}
}

impl GeneratedAcir {
Expand Down
122 changes: 72 additions & 50 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ use acvm::acir::circuit::brillig::BrilligBytecode;
use acvm::acir::circuit::{AssertionPayload, ErrorSelector, OpcodeLocation};
use acvm::acir::native_types::Witness;
use acvm::acir::BlackBoxFunc;
use acvm::{
acir::AcirField,
acir::{circuit::opcodes::BlockId, native_types::Expression},
FieldElement,
};
use acvm::{acir::circuit::opcodes::BlockId, acir::AcirField, FieldElement};
use fxhash::FxHashMap as HashMap;
use im::Vector;
use iter_extended::{try_vecmap, vecmap};
Expand Down Expand Up @@ -330,38 +326,10 @@ impl Ssa {
bytecode: brillig.byte_code,
});

let runtime_types = self.functions.values().map(|function| function.runtime());
for (acir, runtime_type) in acirs.iter_mut().zip(runtime_types) {
if matches!(runtime_type, RuntimeType::Acir(_)) {
generate_distinct_return_witnesses(acir);
}
}

Ok((acirs, brillig, self.error_selector_to_type))
}
}

fn generate_distinct_return_witnesses(acir: &mut GeneratedAcir) {
// Create a witness for each return witness we have to guarantee that the return witnesses match the standard
// layout for serializing those types as if they were being passed as inputs.
//
// This is required for recursion as otherwise in situations where we cannot make use of the program's ABI
// (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're
// working with rather than following the standard ABI encoding rules.
//
// TODO: We're being conservative here by generating a new witness for every expression.
// This means that we're likely to get a number of constraints which are just renumbering witnesses.
// This can be tackled by:
// - Tracking the last assigned public input witness and only renumbering a witness if it is below this value.
// - Modifying existing constraints to rearrange their outputs so they are suitable
// - See: https://github.com/noir-lang/noir/pull/4467
let distinct_return_witness = vecmap(acir.return_witnesses.clone(), |return_witness| {
acir.create_witness_for_expression(&Expression::from(return_witness))
});

acir.return_witnesses = distinct_return_witness;
}

impl<'a> Context<'a> {
fn new(shared_context: &'a mut SharedContext) -> Context<'a> {
let mut acir_context = AcirContext::default();
Expand Down Expand Up @@ -422,15 +390,45 @@ impl<'a> Context<'a> {
let dfg = &main_func.dfg;
let entry_block = &dfg[main_func.entry_block()];
let input_witness = self.convert_ssa_block_params(entry_block.parameters(), dfg)?;
let num_return_witnesses =
self.get_num_return_witnesses(entry_block.unwrap_terminator(), dfg);

// Create a witness for each return witness we have to guarantee that the return witnesses match the standard
// layout for serializing those types as if they were being passed as inputs.
//
// This is required for recursion as otherwise in situations where we cannot make use of the program's ABI
// (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're
// working with rather than following the standard ABI encoding rules.
//
// We allocate these witnesses now before performing ACIR gen for the rest of the program as the location of
// the function's return values can then be determined through knowledge of its ABI alone.
let return_witness_vars =
vecmap(0..num_return_witnesses, |_| self.acir_context.add_variable());

let return_witnesses = vecmap(&return_witness_vars, |return_var| {
let expr = self.acir_context.var_to_expression(*return_var).unwrap();
expr.to_witness().expect("return vars should be witnesses")
});

self.data_bus = dfg.data_bus.to_owned();
let mut warnings = Vec::new();
for instruction_id in entry_block.instructions() {
warnings.extend(self.convert_ssa_instruction(*instruction_id, dfg, ssa, brillig)?);
}

warnings.extend(self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?);
Ok(self.acir_context.finish(input_witness, warnings))
let (return_vars, return_warnings) =
self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?;

// TODO: This is a naive method of assigning the return values to their witnesses as
// we're likely to get a number of constraints which are asserting one witness to be equal to another.
//
// We should search through the program and relabel these witnesses so we can remove this constraint.
for (witness_var, return_var) in return_witness_vars.iter().zip(return_vars) {
self.acir_context.assert_eq_var(*witness_var, return_var, None)?;
}

warnings.extend(return_warnings);
Ok(self.acir_context.finish(input_witness, return_witnesses, warnings))
}

fn convert_brillig_main(
Expand Down Expand Up @@ -468,17 +466,13 @@ impl<'a> Context<'a> {
)?;
self.shared_context.insert_generated_brillig(main_func.id(), arguments, 0, code);

let output_vars: Vec<_> = output_values
let return_witnesses: Vec<Witness> = output_values
.iter()
.flat_map(|value| value.clone().flatten())
.map(|value| value.0)
.collect();
.map(|(value, _)| self.acir_context.var_to_witness(value))
.collect::<Result<_, _>>()?;

for acir_var in output_vars {
self.acir_context.return_var(acir_var)?;
}

let generated_acir = self.acir_context.finish(witness_inputs, Vec::new());
let generated_acir = self.acir_context.finish(witness_inputs, return_witnesses, Vec::new());

assert_eq!(
generated_acir.opcodes().len(),
Expand Down Expand Up @@ -1724,12 +1718,39 @@ impl<'a> Context<'a> {
self.define_result(dfg, instruction, AcirValue::Var(result, typ));
}

/// Converts an SSA terminator's return values into their ACIR representations
fn get_num_return_witnesses(
&mut self,
terminator: &TerminatorInstruction,
dfg: &DataFlowGraph,
) -> usize {
let return_values = match terminator {
TerminatorInstruction::Return { return_values, .. } => return_values,
// TODO(https://github.com/noir-lang/noir/issues/4616): Enable recursion on foldable/non-inlined ACIR functions
_ => unreachable!("ICE: Program must have a singular return"),
};

return_values.iter().fold(0, |acc, value_id| {
let is_databus = self
.data_bus
.return_data
.map_or(false, |return_databus| dfg[*value_id] == dfg[return_databus]);

if is_databus {
// We do not return value for the data bus.
acc
} else {
acc + dfg.type_of_value(*value_id).flattened_size()
}
})
}

/// Converts an SSA terminator's return values into their ACIR representations
fn convert_ssa_return(
&mut self,
terminator: &TerminatorInstruction,
dfg: &DataFlowGraph,
) -> Result<Vec<SsaReport>, RuntimeError> {
) -> Result<(Vec<AcirVar>, Vec<SsaReport>), RuntimeError> {
let (return_values, call_stack) = match terminator {
TerminatorInstruction::Return { return_values, call_stack } => {
(return_values, call_stack.clone())
Expand All @@ -1739,6 +1760,7 @@ impl<'a> Context<'a> {
};

let mut has_constant_return = false;
let mut return_vars: Vec<AcirVar> = Vec::new();
for value_id in return_values {
let is_databus = self
.data_bus
Expand All @@ -1759,7 +1781,7 @@ impl<'a> Context<'a> {
dfg,
)?;
} else {
self.acir_context.return_var(acir_var)?;
return_vars.push(acir_var);
}
}
}
Expand All @@ -1770,7 +1792,7 @@ impl<'a> Context<'a> {
Vec::new()
};

Ok(warnings)
Ok((return_vars, warnings))
}

/// Gets the cached `AcirVar` that was converted from the corresponding `ValueId`. If it does
Expand Down Expand Up @@ -3079,8 +3101,8 @@ mod test {
check_call_opcode(
&func_with_nested_call_opcodes[1],
2,
vec![Witness(2), Witness(1)],
vec![Witness(3)],
vec![Witness(3), Witness(1)],
vec![Witness(4)],
);
}

Expand All @@ -3100,13 +3122,13 @@ mod test {
for (expected_input, input) in expected_inputs.iter().zip(inputs) {
assert_eq!(
expected_input, input,
"Expected witness {expected_input:?} but got {input:?}"
"Expected input witness {expected_input:?} but got {input:?}"
);
}
for (expected_output, output) in expected_outputs.iter().zip(outputs) {
assert_eq!(
expected_output, output,
"Expected witness {expected_output:?} but got {output:?}"
"Expected output witness {expected_output:?} but got {output:?}"
);
}
}
Expand Down

0 comments on commit f10b38e

Please sign in to comment.