From ff5a309e90e7800b614d39b74ed26e30b63158fd Mon Sep 17 00:00:00 2001 From: fcarreiro Date: Tue, 11 Jun 2024 19:24:29 +0000 Subject: [PATCH] Support specifying column index in public input declaration. --- asm_to_pil/src/vm_to_constrained.rs | 10 +++--- ast/src/analyzed/display.rs | 3 ++ ast/src/analyzed/mod.rs | 2 ++ ast/src/analyzed/visitor.rs | 2 ++ ast/src/parsed/display.rs | 9 ++++- ast/src/parsed/mod.rs | 5 ++- ast/src/parsed/visitor.rs | 2 ++ bberg/src/bberg_codegen.rs | 2 +- bberg/src/permutation_builder.rs | 2 +- bberg/src/relation_builder.rs | 8 +++-- bberg/src/verifier_builder.rs | 20 +++++------ bberg/src/vm_builder.rs | 45 +++++++++++++------------ executor/src/constant_evaluator/mod.rs | 1 + parser/src/powdr.lalrpop | 8 ++--- pil_analyzer/src/statement_processor.rs | 34 ++++++++++--------- pilopt/src/lib.rs | 1 + 16 files changed, 91 insertions(+), 63 deletions(-) diff --git a/asm_to_pil/src/vm_to_constrained.rs b/asm_to_pil/src/vm_to_constrained.rs index c5f837e45..eb7e4d32c 100644 --- a/asm_to_pil/src/vm_to_constrained.rs +++ b/asm_to_pil/src/vm_to_constrained.rs @@ -313,7 +313,7 @@ impl ASMPILConverter { ty, }, ); - self.pil.push(witness_column(start, name, None, false)); + self.pil.push(witness_column(start, name, None, None)); } fn handle_instruction_def( @@ -846,7 +846,7 @@ impl ASMPILConverter { ), ) }); - witness_column(0, free_value, prover_query, false) + witness_column(0, free_value, prover_query, None) }) .collect::>(); self.pil.extend(free_value_pil); @@ -877,7 +877,7 @@ impl ASMPILConverter { /// Creates a pair of witness and fixed column and matches them in the lookup. fn create_witness_fixed_pair(&mut self, start: usize, name: &str) { let fixed_name = format!("p_{name}"); - self.pil.push(witness_column(start, name, None, false)); + self.pil.push(witness_column(start, name, None, None)); self.line_lookup .push((name.to_string(), fixed_name.clone())); self.rom_constant_names.push(fixed_name); @@ -1082,7 +1082,7 @@ fn witness_column, T>( start: usize, name: S, def: Option>, - is_public: bool, + public_info: Option, ) -> PilStatement { PilStatement::PolynomialCommitDeclaration( start, @@ -1091,7 +1091,7 @@ fn witness_column, T>( array_size: None, }], def, - is_public, + public_info, ) } diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index c5194b499..7d9232415 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -96,6 +96,9 @@ impl Display for Analyzed { impl Display for FunctionValueDefinition { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self { + FunctionValueDefinition::Number(n) => { + write!(f, "{}", n) + } FunctionValueDefinition::Array(items) => { write!(f, " = {}", items.iter().format(" + ")) } diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 4fda5bbfb..3ac47b241 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -290,6 +290,7 @@ impl Analyzed { .flat_map(|e| e.pattern.iter_mut()) .for_each(|e| e.post_visit_expressions_mut(f)), Some(FunctionValueDefinition::Expression(e)) => e.post_visit_expressions_mut(f), + Some(FunctionValueDefinition::Number(_)) => {} None => {} }); } @@ -446,6 +447,7 @@ pub enum FunctionValueDefinition { Array(Vec>), Query(Expression), Expression(Expression), + Number(usize), } /// An array of elements that might be repeated. diff --git a/ast/src/analyzed/visitor.rs b/ast/src/analyzed/visitor.rs index dc17e79f0..04f3faa50 100644 --- a/ast/src/analyzed/visitor.rs +++ b/ast/src/analyzed/visitor.rs @@ -93,6 +93,7 @@ impl ExpressionVisitable> for FunctionValueDefinition { .iter_mut() .flat_map(|a| a.pattern.iter_mut()) .try_for_each(move |item| item.visit_expressions_mut(f, o)), + FunctionValueDefinition::Number(_) => ControlFlow::Continue(()), } } @@ -108,6 +109,7 @@ impl ExpressionVisitable> for FunctionValueDefinition { .iter() .flat_map(|a| a.pattern().iter()) .try_for_each(move |item| item.visit_expressions(f, o)), + FunctionValueDefinition::Number(_) => ControlFlow::Continue(()), } } } diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index 8aa8f803d..dddfbbca6 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -379,7 +379,11 @@ impl Display for PilStatement { write!( f, "pol commit {}{}{};", - if *public { "public " } else { " " }, + if let Some(n) = public { + format!("public : {n}") + } else { + " ".to_string() + }, names.iter().format(", "), value.as_ref().map(|v| format!("{v}")).unwrap_or_default(), ) @@ -431,6 +435,9 @@ impl Display for ArrayExpression { impl Display for FunctionDefinition { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self { + FunctionDefinition::Number(n) => { + write!(f, "{n}") + } FunctionDefinition::Array(array_expression) => { write!(f, " = {array_expression}") } diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 671ba865e..894cc8a45 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -37,7 +37,7 @@ pub enum PilStatement { usize, Vec>, Option>, - /*public=*/ bool, + Option, ), PolynomialIdentity(usize, Option, Expression), PlookupIdentity( @@ -82,6 +82,7 @@ pub enum Expression { Reference(Ref), PublicReference(String), Number(T), + // LiteralNumber(usize), String(String), Tuple(Vec>), LambdaExpression(LambdaExpression), @@ -270,6 +271,8 @@ pub enum FunctionDefinition { Query(Vec, Expression), /// Generic expression Expression(Expression), + /// Constant for public inputs + Number(usize), } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] diff --git a/ast/src/parsed/visitor.rs b/ast/src/parsed/visitor.rs index 52811b162..7e032a01e 100644 --- a/ast/src/parsed/visitor.rs +++ b/ast/src/parsed/visitor.rs @@ -304,6 +304,7 @@ impl ExpressionVisitable> for FunctionDefinition { FunctionDefinition::Query(_, e) => e.visit_expressions_mut(f, o), FunctionDefinition::Array(ae) => ae.visit_expressions_mut(f, o), FunctionDefinition::Expression(e) => e.visit_expressions_mut(f, o), + FunctionDefinition::Number(_) => ControlFlow::Continue(()), } } @@ -315,6 +316,7 @@ impl ExpressionVisitable> for FunctionDefinition { FunctionDefinition::Query(_, e) => e.visit_expressions(f, o), FunctionDefinition::Array(ae) => ae.visit_expressions(f, o), FunctionDefinition::Expression(e) => e.visit_expressions(f, o), + FunctionDefinition::Number(_) => ControlFlow::Continue(()), } } } diff --git a/bberg/src/bberg_codegen.rs b/bberg/src/bberg_codegen.rs index 0e4a09a9f..20813d645 100644 --- a/bberg/src/bberg_codegen.rs +++ b/bberg/src/bberg_codegen.rs @@ -22,7 +22,7 @@ impl BBergCodegen { } pub fn new_from_setup(_input: &mut impl io::Read) -> Result { - println!("warning bberg: new_from_setup not implemented"); + log::warn!("warning bberg: new_from_setup not implemented"); Ok(Self {}) } diff --git a/bberg/src/permutation_builder.rs b/bberg/src/permutation_builder.rs index 6cc00db96..caf9d7a23 100644 --- a/bberg/src/permutation_builder.rs +++ b/bberg/src/permutation_builder.rs @@ -120,7 +120,7 @@ fn permutation_settings_includes() -> &'static str { } fn create_permutation_settings_file(permutation: &Permutation) -> String { - println!("Permutation: {:?}", permutation); + log::trace!("Permutation: {:?}", permutation); let columns_per_set = permutation.left.cols.len(); // TODO(md): In the future we will need to condense off the back of this - combining those with the same inverse column let permutation_name = permutation diff --git a/bberg/src/relation_builder.rs b/bberg/src/relation_builder.rs index 90a8ad03d..a20dff053 100644 --- a/bberg/src/relation_builder.rs +++ b/bberg/src/relation_builder.rs @@ -306,7 +306,7 @@ fn create_identity( if let Some(expr) = &expression.selector { let x = craft_expression(expr, collected_cols, collected_public_identities); - println!("{:?}", x); + log::trace!("expression {:?}", x); Some(x) } else { None @@ -464,8 +464,10 @@ pub(crate) fn create_identities( // Print a warning to the user about usage of public identities if !collected_public_identities.is_empty() { - println!("Public Identities are not supported yet in codegen, however some were collected"); - println!("Public Identities: {:?}", collected_public_identities); + log::warn!( + "Public Identities are not supported yet in codegen, however some were collected" + ); + log::warn!("Public Identities: {:?}", collected_public_identities); } let mut collected_cols: Vec = collected_cols.drain().collect(); diff --git a/bberg/src/verifier_builder.rs b/bberg/src/verifier_builder.rs index 904dd29ad..710b9cafa 100644 --- a/bberg/src/verifier_builder.rs +++ b/bberg/src/verifier_builder.rs @@ -9,10 +9,10 @@ pub trait VerifierBuilder { name: &str, witness: &[String], inverses: &[String], - public_cols: &[String], + public_cols: &[(String, usize)], ); - fn create_verifier_hpp(&mut self, name: &str, public_cols: &[String]); + fn create_verifier_hpp(&mut self, name: &str, public_cols: &[(String, usize)]); } impl VerifierBuilder for BBFiles { @@ -21,7 +21,7 @@ impl VerifierBuilder for BBFiles { name: &str, witness: &[String], inverses: &[String], - public_cols: &[String], + public_cols: &[(String, usize)], ) { let include_str = includes_cpp(&snake_case(name)); @@ -52,22 +52,22 @@ impl VerifierBuilder for BBFiles { format!("bool {name}Verifier::verify_proof(const HonkProof& proof)") }; - let public_inputs_column_transformation = |public_inputs_column_name: &String, i: usize| { - format!( + let public_inputs_column_transformation = + |public_inputs_column_name: &String, idx: usize| { + format!( " - FF {public_inputs_column_name}_evaluation = evaluate_public_input_column(public_inputs[{i}], circuit_size, multivariate_challenge); + FF {public_inputs_column_name}_evaluation = evaluate_public_input_column(public_inputs[{idx}], circuit_size, multivariate_challenge); if ({public_inputs_column_name}_evaluation != claimed_evaluations.{public_inputs_column_name}) {{ return false; }} " ) - }; + }; let (public_inputs_check, evaluate_public_inputs) = if has_public_input_columns { let inputs_check = public_cols .iter() - .enumerate() - .map(|(i, col_name)| public_inputs_column_transformation(col_name, i)) + .map(|(col_name, idx)| public_inputs_column_transformation(col_name, *idx)) .collect::(); let evaluate_public_inputs = format!( @@ -207,7 +207,7 @@ impl VerifierBuilder for BBFiles { ); } - fn create_verifier_hpp(&mut self, name: &str, public_cols: &[String]) { + fn create_verifier_hpp(&mut self, name: &str, public_cols: &[(String, usize)]) { let include_str = include_hpp(&snake_case(name)); // If there are public input columns, then the generated verifier must take them in as an argument for the verify_proof diff --git a/bberg/src/vm_builder.rs b/bberg/src/vm_builder.rs index 8aee9d02d..7cedc23af 100644 --- a/bberg/src/vm_builder.rs +++ b/bberg/src/vm_builder.rs @@ -1,5 +1,6 @@ use ast::analyzed::Analyzed; +use ast::analyzed::FunctionValueDefinition; use number::FieldElement; use crate::circuit_builder::CircuitBuilder; @@ -29,8 +30,6 @@ struct ColumnGroups { fixed: Vec, /// witness or commit columns in pil -> will be found in proof witness: Vec, - // public input columns, evaluations will be calculated within the verifier - public: Vec, /// witness or commit columns in pil, with out the inverse columns witnesses_without_inverses: Vec, /// fixed + witness columns without lookup inverses @@ -58,6 +57,20 @@ pub(crate) fn analyzed_to_cpp( witness: &[(String, Vec)], name: Option, ) { + // Extract public inputs information. + let mut public_inputs: Vec<(String, usize)> = analyzed + .definitions + .iter() + .filter_map(|(name, def)| { + if let (_, Some(FunctionValueDefinition::Number(idx))) = def { + Some((sanitize_name(name), *idx)) + } else { + None + } + }) + .collect(); + public_inputs.sort_by(|a, b| a.1.cmp(&b.1)); + // Sort fixed and witness to ensure consistent ordering let fixed = &sort_cols(fixed); let witness = &sort_cols(witness); @@ -91,7 +104,6 @@ pub(crate) fn analyzed_to_cpp( let ColumnGroups { fixed, witness, - public, witnesses_without_inverses, all_cols, all_cols_without_inverses, @@ -135,8 +147,13 @@ pub(crate) fn analyzed_to_cpp( bb_files.create_composer_hpp(file_name); // ----------------------- Create the Verifier files ----------------------- - bb_files.create_verifier_cpp(file_name, &witnesses_without_inverses, &inverses, &public); - bb_files.create_verifier_hpp(file_name, &public); + bb_files.create_verifier_cpp( + file_name, + &witnesses_without_inverses, + &inverses, + &public_inputs, + ); + bb_files.create_verifier_hpp(file_name, &public_inputs); // ----------------------- Create the Prover files ----------------------- bb_files.create_prover_cpp(file_name, &witnesses_without_inverses, &inverses); @@ -160,6 +177,8 @@ fn get_all_col_names( permutations: &[Permutation], lookups: &[Lookup], ) -> ColumnGroups { + log::info!("Getting all column names"); + // Transformations let sanitize = |(name, _): &(String, Vec)| sanitize_name(name).to_owned(); let append_shift = |name: &String| format!("{}_shift", *name); @@ -171,8 +190,6 @@ fn get_all_col_names( // Gather sanitized column names let fixed_names = collect_col(fixed, sanitize); let witness_names = collect_col(witness, sanitize); - let (witness_names, public_input_column_names) = extract_public_input_columns(witness_names); - let inverses = flatten(&[perm_inverses, lookup_inverses]); let witnesses_without_inverses = flatten(&[witness_names.clone(), lookup_counts.clone()]); let witnesses_with_inverses = flatten(&[witness_names, inverses.clone(), lookup_counts]); @@ -196,7 +213,6 @@ fn get_all_col_names( ColumnGroups { fixed: fixed_names, witness: witnesses_with_inverses, - public: public_input_column_names, all_cols_without_inverses, witnesses_without_inverses, all_cols, @@ -207,16 +223,3 @@ fn get_all_col_names( inverses, } } - -/// Extract public input columns -/// The compiler automatically suffixes the public input columns with "__is_public" -/// This function removes the suffix and collects the columns into their own container -pub fn extract_public_input_columns(witness_columns: Vec) -> (Vec, Vec) { - let witness_names: Vec = witness_columns.clone(); - let public_input_column_names: Vec = witness_columns - .into_iter() - .filter(|name| name.ends_with("__is_public")) - .collect(); - - (witness_names, public_input_column_names) -} diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index aa8045515..03bbdd194 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -36,6 +36,7 @@ fn generate_values( }; // TODO we should maybe pre-compute some symbols here. match body { + FunctionValueDefinition::Number(n) => vec![T::from(*n as u64)], FunctionValueDefinition::Expression(e) => (0..degree) .into_par_iter() .map(|i| { diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index a32d71818..bc6c6f5a8 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -1,7 +1,7 @@ use std::str::FromStr; use ast::parsed::{*, asm::*}; use number::{AbstractNumberType, FieldElement}; -use num_traits::Num; +use num_traits::{Num, ToPrimitive}; grammar where T: FieldElement; @@ -122,10 +122,10 @@ ArrayLiteralTerm: ArrayExpression = { } PolynomialCommitDeclaration: PilStatement = { - <@L> PolCol CommitWitness => PilStatement::PolynomialCommitDeclaration(<>, None, false), - <@L> PolCol "public" => PilStatement::PolynomialCommitDeclaration(<>, None, true), + <@L> PolCol CommitWitness => PilStatement::PolynomialCommitDeclaration(<>, None, None), + PolCol "public" "(" ")" => PilStatement::PolynomialCommitDeclaration(start, vec![name], None, Some(n.to_usize().unwrap())), PolCol CommitWitness "(" ")" "query" - => PilStatement::PolynomialCommitDeclaration(start, vec![name], Some(FunctionDefinition::Query(param, value)), false) + => PilStatement::PolynomialCommitDeclaration(start, vec![name], Some(FunctionDefinition::Query(param, value)), None) } PolynomialIdentity: PilStatement = { diff --git a/pil_analyzer/src/statement_processor.rs b/pil_analyzer/src/statement_processor.rs index 2e7e2a22d..d1bc5a22c 100644 --- a/pil_analyzer/src/statement_processor.rs +++ b/pil_analyzer/src/statement_processor.rs @@ -110,12 +110,7 @@ where self.handle_public_declaration(start, name, polynomial, array_index, index) } PilStatement::PolynomialConstantDeclaration(start, polynomials) => self - .handle_polynomial_declarations( - start, - polynomials, - PolynomialType::Constant, - false, - ), + .handle_polynomial_declarations(start, polynomials, PolynomialType::Constant, None), PilStatement::PolynomialConstantDefinition(start, name, definition) => self .handle_symbol_definition( start, @@ -124,13 +119,14 @@ where SymbolKind::Poly(PolynomialType::Constant), Some(definition), ), - PilStatement::PolynomialCommitDeclaration(start, polynomials, None, is_public) => self - .handle_polynomial_declarations( + PilStatement::PolynomialCommitDeclaration(start, polynomials, None, public_info) => { + self.handle_polynomial_declarations( start, polynomials, PolynomialType::Committed, - is_public, - ), + public_info, + ) + } PilStatement::PolynomialCommitDeclaration( start, mut polynomials, @@ -280,16 +276,21 @@ where start: usize, polynomials: Vec>, polynomial_type: PolynomialType, - is_public: bool, + public_info: Option, ) -> Vec> { + if public_info.is_some() { + assert!(polynomials.len() == 1); + } polynomials .into_iter() .flat_map(|PolynomialName { name, array_size }| { - // hack(https://github.com/AztecProtocol/aztec-packages/issues/6359): add an is_public modifier to the end of a committed polynomial - let name = if is_public { - format!("{name}__is_public") + let value = if let Some(idx) = public_info { + // let formatted = format!("{name}__public_input_{idx}"); + // println!("Formatted: {formatted}"); + // formatted + Some(FunctionDefinition::Number(idx)) } else { - name + None }; self.handle_symbol_definition( @@ -297,7 +298,7 @@ where name, array_size, SymbolKind::Poly(polynomial_type), - None, + value, ) }) .collect() @@ -330,6 +331,7 @@ where }; let value = value.map(|v| match v { + FunctionDefinition::Number(n) => FunctionValueDefinition::Number(n), FunctionDefinition::Expression(expr) => { assert!(!have_array_size); assert!(symbol_kind != SymbolKind::Poly(PolynomialType::Committed)); diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 090e4f9b4..6432ecca6 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -80,6 +80,7 @@ fn constant_value(function: &FunctionValueDefinition) -> Opt } FunctionValueDefinition::Query(_) => None, FunctionValueDefinition::Expression(_) => None, + FunctionValueDefinition::Number(_) => None, } }