diff --git a/acvm-repo/acvm/src/compiler/mod.rs b/acvm-repo/acvm/src/compiler/mod.rs index 26963eee58d..5b090e4603f 100644 --- a/acvm-repo/acvm/src/compiler/mod.rs +++ b/acvm-repo/acvm/src/compiler/mod.rs @@ -1,12 +1,7 @@ use acir::{ - circuit::{ - brillig::BrilligOutputs, directives::Directive, opcodes::UnsupportedMemoryOpcode, Circuit, - Opcode, OpcodeLocation, - }, - native_types::{Expression, Witness}, - BlackBoxFunc, FieldElement, + circuit::{opcodes::UnsupportedMemoryOpcode, Circuit, Opcode, OpcodeLocation}, + BlackBoxFunc, }; -use indexmap::IndexMap; use thiserror::Error; use crate::Language; @@ -15,8 +10,9 @@ use crate::Language; mod optimizers; mod transformers; -use optimizers::{GeneralOptimizer, RangeOptimizer}; -use transformers::{CSatTransformer, FallbackTransformer, R1CSTransformer}; +pub use optimizers::optimize; +pub use transformers::transform; +use transformers::transform_internal; #[derive(PartialEq, Eq, Debug, Error)] pub enum CompileError { @@ -77,204 +73,7 @@ pub fn compile( np_language: Language, is_opcode_supported: impl Fn(&Opcode) -> bool, ) -> Result<(Circuit, AcirTransformationMap), CompileError> { - // Instantiate the optimizer. - // Currently the optimizer and reducer are one in the same - // for CSAT + let (acir, AcirTransformationMap { acir_opcode_positions }) = optimize(acir); - // Track original acir opcode positions throughout the transformation passes of the compilation - // by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert) - let acir_opcode_positions = acir.opcodes.iter().enumerate().map(|(i, _)| i).collect(); - - // Fallback transformer pass - let (acir, acir_opcode_positions) = - FallbackTransformer::transform(acir, is_opcode_supported, acir_opcode_positions)?; - - // General optimizer pass - let mut opcodes: Vec = Vec::new(); - for opcode in acir.opcodes { - match opcode { - Opcode::Arithmetic(arith_expr) => { - opcodes.push(Opcode::Arithmetic(GeneralOptimizer::optimize(arith_expr))); - } - other_opcode => opcodes.push(other_opcode), - }; - } - let acir = Circuit { opcodes, ..acir }; - - // Range optimization pass - let range_optimizer = RangeOptimizer::new(acir); - let (mut acir, acir_opcode_positions) = - range_optimizer.replace_redundant_ranges(acir_opcode_positions); - - let mut transformer = match &np_language { - crate::Language::R1CS => { - let transformation_map = AcirTransformationMap { acir_opcode_positions }; - acir.assert_messages = - transform_assert_messages(acir.assert_messages, &transformation_map); - let transformer = R1CSTransformer::new(acir); - return Ok((transformer.transform(), transformation_map)); - } - crate::Language::PLONKCSat { width } => { - let mut csat = CSatTransformer::new(*width); - for value in acir.circuit_arguments() { - csat.mark_solvable(value); - } - csat - } - }; - - // TODO: the code below is only for CSAT transformer - // TODO it may be possible to refactor it in a way that we do not need to return early from the r1cs - // TODO or at the very least, we could put all of it inside of CSatOptimizer pass - - let mut new_acir_opcode_positions: Vec = Vec::with_capacity(acir_opcode_positions.len()); - // Optimize the arithmetic gates by reducing them into the correct width and - // creating intermediate variables when necessary - let mut transformed_opcodes = Vec::new(); - - let mut next_witness_index = acir.current_witness_index + 1; - // maps a normalized expression to the intermediate variable which represents the expression, along with its 'norm' - // the 'norm' is simply the value of the first non zero coefficient in the expression, taken from the linear terms, or quadratic terms if there is none. - let mut intermediate_variables: IndexMap = IndexMap::new(); - for (index, opcode) in acir.opcodes.iter().enumerate() { - match opcode { - Opcode::Arithmetic(arith_expr) => { - let len = intermediate_variables.len(); - - let arith_expr = transformer.transform( - arith_expr.clone(), - &mut intermediate_variables, - &mut next_witness_index, - ); - - // Update next_witness counter - next_witness_index += (intermediate_variables.len() - len) as u32; - let mut new_opcodes = Vec::new(); - for (g, (norm, w)) in intermediate_variables.iter().skip(len) { - // de-normalize - let mut intermediate_opcode = g * *norm; - // constrain the intermediate opcode to the intermediate variable - intermediate_opcode.linear_combinations.push((-FieldElement::one(), *w)); - intermediate_opcode.sort(); - new_opcodes.push(intermediate_opcode); - } - new_opcodes.push(arith_expr); - for opcode in new_opcodes { - new_acir_opcode_positions.push(acir_opcode_positions[index]); - transformed_opcodes.push(Opcode::Arithmetic(opcode)); - } - } - Opcode::BlackBoxFuncCall(func) => { - match func { - acir::circuit::opcodes::BlackBoxFuncCall::AND { output, .. } - | acir::circuit::opcodes::BlackBoxFuncCall::XOR { output, .. } => { - transformer.mark_solvable(*output); - } - acir::circuit::opcodes::BlackBoxFuncCall::RANGE { .. } => (), - acir::circuit::opcodes::BlackBoxFuncCall::SHA256 { outputs, .. } - | acir::circuit::opcodes::BlackBoxFuncCall::Keccak256 { outputs, .. } - | acir::circuit::opcodes::BlackBoxFuncCall::Keccak256VariableLength { - outputs, - .. - } - | acir::circuit::opcodes::BlackBoxFuncCall::RecursiveAggregation { - output_aggregation_object: outputs, - .. - } - | acir::circuit::opcodes::BlackBoxFuncCall::Blake2s { outputs, .. } => { - for witness in outputs { - transformer.mark_solvable(*witness); - } - } - acir::circuit::opcodes::BlackBoxFuncCall::FixedBaseScalarMul { - outputs, - .. - } - | acir::circuit::opcodes::BlackBoxFuncCall::Pedersen { outputs, .. } => { - transformer.mark_solvable(outputs.0); - transformer.mark_solvable(outputs.1); - } - acir::circuit::opcodes::BlackBoxFuncCall::HashToField128Security { - output, - .. - } - | acir::circuit::opcodes::BlackBoxFuncCall::EcdsaSecp256k1 { output, .. } - | acir::circuit::opcodes::BlackBoxFuncCall::EcdsaSecp256r1 { output, .. } - | acir::circuit::opcodes::BlackBoxFuncCall::SchnorrVerify { output, .. } => { - transformer.mark_solvable(*output); - } - } - - new_acir_opcode_positions.push(acir_opcode_positions[index]); - transformed_opcodes.push(opcode.clone()); - } - Opcode::Directive(directive) => { - match directive { - Directive::Quotient(quotient_directive) => { - transformer.mark_solvable(quotient_directive.q); - transformer.mark_solvable(quotient_directive.r); - } - Directive::ToLeRadix { b, .. } => { - for witness in b { - transformer.mark_solvable(*witness); - } - } - Directive::PermutationSort { bits, .. } => { - for witness in bits { - transformer.mark_solvable(*witness); - } - } - } - new_acir_opcode_positions.push(acir_opcode_positions[index]); - transformed_opcodes.push(opcode.clone()); - } - Opcode::MemoryInit { .. } => { - // `MemoryInit` does not write values to the `WitnessMap` - new_acir_opcode_positions.push(acir_opcode_positions[index]); - transformed_opcodes.push(opcode.clone()); - } - Opcode::MemoryOp { op, .. } => { - for (_, witness1, witness2) in &op.value.mul_terms { - transformer.mark_solvable(*witness1); - transformer.mark_solvable(*witness2); - } - for (_, witness) in &op.value.linear_combinations { - transformer.mark_solvable(*witness); - } - new_acir_opcode_positions.push(acir_opcode_positions[index]); - transformed_opcodes.push(opcode.clone()); - } - Opcode::Brillig(brillig) => { - for output in &brillig.outputs { - match output { - BrilligOutputs::Simple(w) => transformer.mark_solvable(*w), - BrilligOutputs::Array(v) => { - for witness in v { - transformer.mark_solvable(*witness); - } - } - } - } - new_acir_opcode_positions.push(acir_opcode_positions[index]); - transformed_opcodes.push(opcode.clone()); - } - } - } - - let current_witness_index = next_witness_index - 1; - - let transformation_map = - AcirTransformationMap { acir_opcode_positions: new_acir_opcode_positions }; - - let acir = Circuit { - current_witness_index, - opcodes: transformed_opcodes, - // The optimizer does not add new public inputs - private_parameters: acir.private_parameters, - public_parameters: acir.public_parameters, - return_values: acir.return_values, - assert_messages: transform_assert_messages(acir.assert_messages, &transformation_map), - }; - - Ok((acir, transformation_map)) + transform_internal(acir, np_language, is_opcode_supported, acir_opcode_positions) } diff --git a/acvm-repo/acvm/src/compiler/optimizers/mod.rs b/acvm-repo/acvm/src/compiler/optimizers/mod.rs index cde7bdd2064..f714a682c3e 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/mod.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/mod.rs @@ -1,5 +1,37 @@ +use acir::circuit::{Circuit, Opcode}; + mod general; mod redundant_range; pub(crate) use general::GeneralOptimizer; pub(crate) use redundant_range::RangeOptimizer; + +use super::AcirTransformationMap; + +/// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] independent optimizations to a [`Circuit`]. +pub fn optimize(acir: Circuit) -> (Circuit, AcirTransformationMap) { + // General optimizer pass + let mut opcodes: Vec = Vec::new(); + for opcode in acir.opcodes { + match opcode { + Opcode::Arithmetic(arith_expr) => { + opcodes.push(Opcode::Arithmetic(GeneralOptimizer::optimize(arith_expr))); + } + other_opcode => opcodes.push(other_opcode), + }; + } + let acir = Circuit { opcodes, ..acir }; + + // Track original acir opcode positions throughout the transformation passes of the compilation + // by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert) + let acir_opcode_positions = acir.opcodes.iter().enumerate().map(|(i, _)| i).collect(); + + // Range optimization pass + let range_optimizer = RangeOptimizer::new(acir); + let (acir, acir_opcode_positions) = + range_optimizer.replace_redundant_ranges(acir_opcode_positions); + + let transformation_map = AcirTransformationMap { acir_opcode_positions }; + + (acir, transformation_map) +} diff --git a/acvm-repo/acvm/src/compiler/transformers/mod.rs b/acvm-repo/acvm/src/compiler/transformers/mod.rs index 89e17ca68d0..b909bc54662 100644 --- a/acvm-repo/acvm/src/compiler/transformers/mod.rs +++ b/acvm-repo/acvm/src/compiler/transformers/mod.rs @@ -1,3 +1,12 @@ +use acir::{ + circuit::{brillig::BrilligOutputs, directives::Directive, Circuit, Opcode}, + native_types::{Expression, Witness}, + FieldElement, +}; +use indexmap::IndexMap; + +use crate::Language; + mod csat; mod fallback; mod r1cs; @@ -5,3 +14,204 @@ mod r1cs; pub(crate) use csat::CSatTransformer; pub(crate) use fallback::FallbackTransformer; pub(crate) use r1cs::R1CSTransformer; + +use super::{transform_assert_messages, AcirTransformationMap, CompileError}; + +/// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] specific optimizations to a [`Circuit`]. +pub fn transform( + acir: Circuit, + np_language: Language, + is_opcode_supported: impl Fn(&Opcode) -> bool, +) -> Result<(Circuit, AcirTransformationMap), CompileError> { + // Track original acir opcode positions throughout the transformation passes of the compilation + // by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert) + let acir_opcode_positions = acir.opcodes.iter().enumerate().map(|(i, _)| i).collect(); + + transform_internal(acir, np_language, is_opcode_supported, acir_opcode_positions) +} + +/// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] specific optimizations to a [`Circuit`]. +/// +/// Accepts an injected `acir_opcode_positions` to allow transformations to be applied directly after optimizations. +pub(super) fn transform_internal( + acir: Circuit, + np_language: Language, + is_opcode_supported: impl Fn(&Opcode) -> bool, + acir_opcode_positions: Vec, +) -> Result<(Circuit, AcirTransformationMap), CompileError> { + // Fallback transformer pass + let (mut acir, acir_opcode_positions) = + FallbackTransformer::transform(acir, is_opcode_supported, acir_opcode_positions)?; + + let mut transformer = match &np_language { + crate::Language::R1CS => { + let transformation_map = AcirTransformationMap { acir_opcode_positions }; + acir.assert_messages = + transform_assert_messages(acir.assert_messages, &transformation_map); + let transformer = R1CSTransformer::new(acir); + return Ok((transformer.transform(), transformation_map)); + } + crate::Language::PLONKCSat { width } => { + let mut csat = CSatTransformer::new(*width); + for value in acir.circuit_arguments() { + csat.mark_solvable(value); + } + csat + } + }; + + // TODO: the code below is only for CSAT transformer + // TODO it may be possible to refactor it in a way that we do not need to return early from the r1cs + // TODO or at the very least, we could put all of it inside of CSatOptimizer pass + + let mut new_acir_opcode_positions: Vec = Vec::with_capacity(acir_opcode_positions.len()); + // Optimize the arithmetic gates by reducing them into the correct width and + // creating intermediate variables when necessary + let mut transformed_opcodes = Vec::new(); + + let mut next_witness_index = acir.current_witness_index + 1; + // maps a normalized expression to the intermediate variable which represents the expression, along with its 'norm' + // the 'norm' is simply the value of the first non zero coefficient in the expression, taken from the linear terms, or quadratic terms if there is none. + let mut intermediate_variables: IndexMap = IndexMap::new(); + for (index, opcode) in acir.opcodes.iter().enumerate() { + match opcode { + Opcode::Arithmetic(arith_expr) => { + let len = intermediate_variables.len(); + + let arith_expr = transformer.transform( + arith_expr.clone(), + &mut intermediate_variables, + &mut next_witness_index, + ); + + // Update next_witness counter + next_witness_index += (intermediate_variables.len() - len) as u32; + let mut new_opcodes = Vec::new(); + for (g, (norm, w)) in intermediate_variables.iter().skip(len) { + // de-normalize + let mut intermediate_opcode = g * *norm; + // constrain the intermediate opcode to the intermediate variable + intermediate_opcode.linear_combinations.push((-FieldElement::one(), *w)); + intermediate_opcode.sort(); + new_opcodes.push(intermediate_opcode); + } + new_opcodes.push(arith_expr); + for opcode in new_opcodes { + new_acir_opcode_positions.push(acir_opcode_positions[index]); + transformed_opcodes.push(Opcode::Arithmetic(opcode)); + } + } + Opcode::BlackBoxFuncCall(func) => { + match func { + acir::circuit::opcodes::BlackBoxFuncCall::AND { output, .. } + | acir::circuit::opcodes::BlackBoxFuncCall::XOR { output, .. } => { + transformer.mark_solvable(*output); + } + acir::circuit::opcodes::BlackBoxFuncCall::RANGE { .. } => (), + acir::circuit::opcodes::BlackBoxFuncCall::SHA256 { outputs, .. } + | acir::circuit::opcodes::BlackBoxFuncCall::Keccak256 { outputs, .. } + | acir::circuit::opcodes::BlackBoxFuncCall::Keccak256VariableLength { + outputs, + .. + } + | acir::circuit::opcodes::BlackBoxFuncCall::RecursiveAggregation { + output_aggregation_object: outputs, + .. + } + | acir::circuit::opcodes::BlackBoxFuncCall::Blake2s { outputs, .. } => { + for witness in outputs { + transformer.mark_solvable(*witness); + } + } + acir::circuit::opcodes::BlackBoxFuncCall::FixedBaseScalarMul { + outputs, + .. + } + | acir::circuit::opcodes::BlackBoxFuncCall::Pedersen { outputs, .. } => { + transformer.mark_solvable(outputs.0); + transformer.mark_solvable(outputs.1); + } + acir::circuit::opcodes::BlackBoxFuncCall::HashToField128Security { + output, + .. + } + | acir::circuit::opcodes::BlackBoxFuncCall::EcdsaSecp256k1 { output, .. } + | acir::circuit::opcodes::BlackBoxFuncCall::EcdsaSecp256r1 { output, .. } + | acir::circuit::opcodes::BlackBoxFuncCall::SchnorrVerify { output, .. } => { + transformer.mark_solvable(*output); + } + } + + new_acir_opcode_positions.push(acir_opcode_positions[index]); + transformed_opcodes.push(opcode.clone()); + } + Opcode::Directive(directive) => { + match directive { + Directive::Quotient(quotient_directive) => { + transformer.mark_solvable(quotient_directive.q); + transformer.mark_solvable(quotient_directive.r); + } + Directive::ToLeRadix { b, .. } => { + for witness in b { + transformer.mark_solvable(*witness); + } + } + Directive::PermutationSort { bits, .. } => { + for witness in bits { + transformer.mark_solvable(*witness); + } + } + } + new_acir_opcode_positions.push(acir_opcode_positions[index]); + transformed_opcodes.push(opcode.clone()); + } + Opcode::MemoryInit { .. } => { + // `MemoryInit` does not write values to the `WitnessMap` + new_acir_opcode_positions.push(acir_opcode_positions[index]); + transformed_opcodes.push(opcode.clone()); + } + Opcode::MemoryOp { op, .. } => { + for (_, witness1, witness2) in &op.value.mul_terms { + transformer.mark_solvable(*witness1); + transformer.mark_solvable(*witness2); + } + for (_, witness) in &op.value.linear_combinations { + transformer.mark_solvable(*witness); + } + new_acir_opcode_positions.push(acir_opcode_positions[index]); + transformed_opcodes.push(opcode.clone()); + } + Opcode::Brillig(brillig) => { + for output in &brillig.outputs { + match output { + BrilligOutputs::Simple(w) => transformer.mark_solvable(*w), + BrilligOutputs::Array(v) => { + for witness in v { + transformer.mark_solvable(*witness); + } + } + } + } + new_acir_opcode_positions.push(acir_opcode_positions[index]); + transformed_opcodes.push(opcode.clone()); + } + } + } + + let current_witness_index = next_witness_index - 1; + + let transformation_map = + AcirTransformationMap { acir_opcode_positions: new_acir_opcode_positions }; + + let acir = Circuit { + current_witness_index, + opcodes: transformed_opcodes, + // The optimizer does not add new public inputs + private_parameters: acir.private_parameters, + public_parameters: acir.public_parameters, + return_values: acir.return_values, + assert_messages: transform_assert_messages(acir.assert_messages, &transformation_map), + }; + + Ok((acir, transformation_map)) +} diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index c9e9d95f4da..47e5f0492a5 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -113,9 +113,13 @@ pub fn create_circuit( .map(|(index, locations)| (index, locations.into_iter().collect())) .collect(); - let debug_info = DebugInfo::new(locations); + let mut debug_info = DebugInfo::new(locations); - Ok((circuit, debug_info, abi)) + // Perform any ACIR-level optimizations + let (optimized_ciruit, transformation_map) = acvm::compiler::optimize(circuit); + debug_info.update_acir(transformation_map); + + Ok((optimized_ciruit, debug_info, abi)) } // This is just a convenience object to bundle the ssa with `print_ssa_passes` for debug printing.