Skip to content

Commit

Permalink
Merge pull request #69 from AztecProtocol/fc/public-inputs-with-index
Browse files Browse the repository at this point in the history
Support specifying column index in public input declaration
  • Loading branch information
IlyasRidhuan authored Jun 12, 2024
2 parents c439294 + ff5a309 commit 5ddbe75
Show file tree
Hide file tree
Showing 16 changed files with 91 additions and 63 deletions.
10 changes: 5 additions & 5 deletions asm_to_pil/src/vm_to_constrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ impl<T: FieldElement> ASMPILConverter<T> {
ty,
},
);
self.pil.push(witness_column(start, name, None, false));
self.pil.push(witness_column(start, name, None, None));
}

fn handle_instruction_def(
Expand Down Expand Up @@ -846,7 +846,7 @@ impl<T: FieldElement> ASMPILConverter<T> {
),
)
});
witness_column(0, free_value, prover_query, false)
witness_column(0, free_value, prover_query, None)
})
.collect::<Vec<_>>();
self.pil.extend(free_value_pil);
Expand Down Expand Up @@ -877,7 +877,7 @@ impl<T: FieldElement> ASMPILConverter<T> {
/// Creates a pair of witness and fixed column and matches them in the lookup.
fn create_witness_fixed_pair(&mut self, start: usize, name: &str) {
let fixed_name = format!("p_{name}");
self.pil.push(witness_column(start, name, None, false));
self.pil.push(witness_column(start, name, None, None));
self.line_lookup
.push((name.to_string(), fixed_name.clone()));
self.rom_constant_names.push(fixed_name);
Expand Down Expand Up @@ -1082,7 +1082,7 @@ fn witness_column<S: Into<String>, T>(
start: usize,
name: S,
def: Option<FunctionDefinition<T>>,
is_public: bool,
public_info: Option<usize>,
) -> PilStatement<T> {
PilStatement::PolynomialCommitDeclaration(
start,
Expand All @@ -1091,7 +1091,7 @@ fn witness_column<S: Into<String>, T>(
array_size: None,
}],
def,
is_public,
public_info,
)
}

Expand Down
3 changes: 3 additions & 0 deletions ast/src/analyzed/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ impl<T: Display> Display for Analyzed<T> {
impl<T: Display> Display for FunctionValueDefinition<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {
FunctionValueDefinition::Number(n) => {
write!(f, "{}", n)
}
FunctionValueDefinition::Array(items) => {
write!(f, " = {}", items.iter().format(" + "))
}
Expand Down
2 changes: 2 additions & 0 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ impl<T> Analyzed<T> {
.flat_map(|e| e.pattern.iter_mut())
.for_each(|e| e.post_visit_expressions_mut(f)),
Some(FunctionValueDefinition::Expression(e)) => e.post_visit_expressions_mut(f),
Some(FunctionValueDefinition::Number(_)) => {}
None => {}
});
}
Expand Down Expand Up @@ -446,6 +447,7 @@ pub enum FunctionValueDefinition<T> {
Array(Vec<RepeatedArray<T>>),
Query(Expression<T>),
Expression(Expression<T>),
Number(usize),
}

/// An array of elements that might be repeated.
Expand Down
2 changes: 2 additions & 0 deletions ast/src/analyzed/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionValueDefinition<T> {
.iter_mut()
.flat_map(|a| a.pattern.iter_mut())
.try_for_each(move |item| item.visit_expressions_mut(f, o)),
FunctionValueDefinition::Number(_) => ControlFlow::Continue(()),
}
}

Expand All @@ -108,6 +109,7 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionValueDefinition<T> {
.iter()
.flat_map(|a| a.pattern().iter())
.try_for_each(move |item| item.visit_expressions(f, o)),
FunctionValueDefinition::Number(_) => ControlFlow::Continue(()),
}
}
}
9 changes: 8 additions & 1 deletion ast/src/parsed/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,11 @@ impl<T: Display> Display for PilStatement<T> {
write!(
f,
"pol commit {}{}{};",
if *public { "public " } else { " " },
if let Some(n) = public {
format!("public : {n}")
} else {
" ".to_string()
},
names.iter().format(", "),
value.as_ref().map(|v| format!("{v}")).unwrap_or_default(),
)
Expand Down Expand Up @@ -431,6 +435,9 @@ impl<T: Display> Display for ArrayExpression<T> {
impl<T: Display> Display for FunctionDefinition<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {
FunctionDefinition::Number(n) => {
write!(f, "{n}")
}
FunctionDefinition::Array(array_expression) => {
write!(f, " = {array_expression}")
}
Expand Down
5 changes: 4 additions & 1 deletion ast/src/parsed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub enum PilStatement<T> {
usize,
Vec<PolynomialName<T>>,
Option<FunctionDefinition<T>>,
/*public=*/ bool,
Option<usize>,
),
PolynomialIdentity(usize, Option<String>, Expression<T>),
PlookupIdentity(
Expand Down Expand Up @@ -82,6 +82,7 @@ pub enum Expression<T, Ref = NamespacedPolynomialReference> {
Reference(Ref),
PublicReference(String),
Number(T),
// LiteralNumber(usize),
String(String),
Tuple(Vec<Expression<T, Ref>>),
LambdaExpression(LambdaExpression<T, Ref>),
Expand Down Expand Up @@ -270,6 +271,8 @@ pub enum FunctionDefinition<T> {
Query(Vec<String>, Expression<T>),
/// Generic expression
Expression(Expression<T>),
/// Constant for public inputs
Number(usize),
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
Expand Down
2 changes: 2 additions & 0 deletions ast/src/parsed/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionDefinition<T> {
FunctionDefinition::Query(_, e) => e.visit_expressions_mut(f, o),
FunctionDefinition::Array(ae) => ae.visit_expressions_mut(f, o),
FunctionDefinition::Expression(e) => e.visit_expressions_mut(f, o),
FunctionDefinition::Number(_) => ControlFlow::Continue(()),
}
}

Expand All @@ -315,6 +316,7 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionDefinition<T> {
FunctionDefinition::Query(_, e) => e.visit_expressions(f, o),
FunctionDefinition::Array(ae) => ae.visit_expressions(f, o),
FunctionDefinition::Expression(e) => e.visit_expressions(f, o),
FunctionDefinition::Number(_) => ControlFlow::Continue(()),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion bberg/src/bberg_codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl BBergCodegen {
}

pub fn new_from_setup(_input: &mut impl io::Read) -> Result<Self, io::Error> {
println!("warning bberg: new_from_setup not implemented");
log::warn!("warning bberg: new_from_setup not implemented");
Ok(Self {})
}

Expand Down
2 changes: 1 addition & 1 deletion bberg/src/permutation_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ fn permutation_settings_includes() -> &'static str {
}

fn create_permutation_settings_file(permutation: &Permutation) -> String {
println!("Permutation: {:?}", permutation);
log::trace!("Permutation: {:?}", permutation);
let columns_per_set = permutation.left.cols.len();
// TODO(md): In the future we will need to condense off the back of this - combining those with the same inverse column
let permutation_name = permutation
Expand Down
8 changes: 5 additions & 3 deletions bberg/src/relation_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ fn create_identity<T: FieldElement>(

if let Some(expr) = &expression.selector {
let x = craft_expression(expr, collected_cols, collected_public_identities);
println!("{:?}", x);
log::trace!("expression {:?}", x);
Some(x)
} else {
None
Expand Down Expand Up @@ -464,8 +464,10 @@ pub(crate) fn create_identities<F: FieldElement>(

// Print a warning to the user about usage of public identities
if !collected_public_identities.is_empty() {
println!("Public Identities are not supported yet in codegen, however some were collected");
println!("Public Identities: {:?}", collected_public_identities);
log::warn!(
"Public Identities are not supported yet in codegen, however some were collected"
);
log::warn!("Public Identities: {:?}", collected_public_identities);
}

let mut collected_cols: Vec<String> = collected_cols.drain().collect();
Expand Down
20 changes: 10 additions & 10 deletions bberg/src/verifier_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ pub trait VerifierBuilder {
name: &str,
witness: &[String],
inverses: &[String],
public_cols: &[String],
public_cols: &[(String, usize)],
);

fn create_verifier_hpp(&mut self, name: &str, public_cols: &[String]);
fn create_verifier_hpp(&mut self, name: &str, public_cols: &[(String, usize)]);
}

impl VerifierBuilder for BBFiles {
Expand All @@ -21,7 +21,7 @@ impl VerifierBuilder for BBFiles {
name: &str,
witness: &[String],
inverses: &[String],
public_cols: &[String],
public_cols: &[(String, usize)],
) {
let include_str = includes_cpp(&snake_case(name));

Expand Down Expand Up @@ -52,22 +52,22 @@ impl VerifierBuilder for BBFiles {
format!("bool {name}Verifier::verify_proof(const HonkProof& proof)")
};

let public_inputs_column_transformation = |public_inputs_column_name: &String, i: usize| {
format!(
let public_inputs_column_transformation =
|public_inputs_column_name: &String, idx: usize| {
format!(
"
FF {public_inputs_column_name}_evaluation = evaluate_public_input_column(public_inputs[{i}], circuit_size, multivariate_challenge);
FF {public_inputs_column_name}_evaluation = evaluate_public_input_column(public_inputs[{idx}], circuit_size, multivariate_challenge);
if ({public_inputs_column_name}_evaluation != claimed_evaluations.{public_inputs_column_name}) {{
return false;
}}
"
)
};
};

let (public_inputs_check, evaluate_public_inputs) = if has_public_input_columns {
let inputs_check = public_cols
.iter()
.enumerate()
.map(|(i, col_name)| public_inputs_column_transformation(col_name, i))
.map(|(col_name, idx)| public_inputs_column_transformation(col_name, *idx))
.collect::<String>();

let evaluate_public_inputs = format!(
Expand Down Expand Up @@ -207,7 +207,7 @@ impl VerifierBuilder for BBFiles {
);
}

fn create_verifier_hpp(&mut self, name: &str, public_cols: &[String]) {
fn create_verifier_hpp(&mut self, name: &str, public_cols: &[(String, usize)]) {
let include_str = include_hpp(&snake_case(name));

// If there are public input columns, then the generated verifier must take them in as an argument for the verify_proof
Expand Down
45 changes: 24 additions & 21 deletions bberg/src/vm_builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use ast::analyzed::Analyzed;

use ast::analyzed::FunctionValueDefinition;
use number::FieldElement;

use crate::circuit_builder::CircuitBuilder;
Expand Down Expand Up @@ -29,8 +30,6 @@ struct ColumnGroups {
fixed: Vec<String>,
/// witness or commit columns in pil -> will be found in proof
witness: Vec<String>,
// public input columns, evaluations will be calculated within the verifier
public: Vec<String>,
/// witness or commit columns in pil, with out the inverse columns
witnesses_without_inverses: Vec<String>,
/// fixed + witness columns without lookup inverses
Expand Down Expand Up @@ -58,6 +57,20 @@ pub(crate) fn analyzed_to_cpp<F: FieldElement>(
witness: &[(String, Vec<F>)],
name: Option<String>,
) {
// Extract public inputs information.
let mut public_inputs: Vec<(String, usize)> = analyzed
.definitions
.iter()
.filter_map(|(name, def)| {
if let (_, Some(FunctionValueDefinition::Number(idx))) = def {
Some((sanitize_name(name), *idx))
} else {
None
}
})
.collect();
public_inputs.sort_by(|a, b| a.1.cmp(&b.1));

// Sort fixed and witness to ensure consistent ordering
let fixed = &sort_cols(fixed);
let witness = &sort_cols(witness);
Expand Down Expand Up @@ -91,7 +104,6 @@ pub(crate) fn analyzed_to_cpp<F: FieldElement>(
let ColumnGroups {
fixed,
witness,
public,
witnesses_without_inverses,
all_cols,
all_cols_without_inverses,
Expand Down Expand Up @@ -135,8 +147,13 @@ pub(crate) fn analyzed_to_cpp<F: FieldElement>(
bb_files.create_composer_hpp(file_name);

// ----------------------- Create the Verifier files -----------------------
bb_files.create_verifier_cpp(file_name, &witnesses_without_inverses, &inverses, &public);
bb_files.create_verifier_hpp(file_name, &public);
bb_files.create_verifier_cpp(
file_name,
&witnesses_without_inverses,
&inverses,
&public_inputs,
);
bb_files.create_verifier_hpp(file_name, &public_inputs);

// ----------------------- Create the Prover files -----------------------
bb_files.create_prover_cpp(file_name, &witnesses_without_inverses, &inverses);
Expand All @@ -160,6 +177,8 @@ fn get_all_col_names<F: FieldElement>(
permutations: &[Permutation],
lookups: &[Lookup],
) -> ColumnGroups {
log::info!("Getting all column names");

// Transformations
let sanitize = |(name, _): &(String, Vec<F>)| sanitize_name(name).to_owned();
let append_shift = |name: &String| format!("{}_shift", *name);
Expand All @@ -171,8 +190,6 @@ fn get_all_col_names<F: FieldElement>(
// Gather sanitized column names
let fixed_names = collect_col(fixed, sanitize);
let witness_names = collect_col(witness, sanitize);
let (witness_names, public_input_column_names) = extract_public_input_columns(witness_names);

let inverses = flatten(&[perm_inverses, lookup_inverses]);
let witnesses_without_inverses = flatten(&[witness_names.clone(), lookup_counts.clone()]);
let witnesses_with_inverses = flatten(&[witness_names, inverses.clone(), lookup_counts]);
Expand All @@ -196,7 +213,6 @@ fn get_all_col_names<F: FieldElement>(
ColumnGroups {
fixed: fixed_names,
witness: witnesses_with_inverses,
public: public_input_column_names,
all_cols_without_inverses,
witnesses_without_inverses,
all_cols,
Expand All @@ -207,16 +223,3 @@ fn get_all_col_names<F: FieldElement>(
inverses,
}
}

/// Extract public input columns
/// The compiler automatically suffixes the public input columns with "__is_public"
/// This function removes the suffix and collects the columns into their own container
pub fn extract_public_input_columns(witness_columns: Vec<String>) -> (Vec<String>, Vec<String>) {
let witness_names: Vec<String> = witness_columns.clone();
let public_input_column_names: Vec<String> = witness_columns
.into_iter()
.filter(|name| name.ends_with("__is_public"))
.collect();

(witness_names, public_input_column_names)
}
1 change: 1 addition & 0 deletions executor/src/constant_evaluator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ fn generate_values<T: FieldElement>(
};
// TODO we should maybe pre-compute some symbols here.
match body {
FunctionValueDefinition::Number(n) => vec![T::from(*n as u64)],
FunctionValueDefinition::Expression(e) => (0..degree)
.into_par_iter()
.map(|i| {
Expand Down
8 changes: 4 additions & 4 deletions parser/src/powdr.lalrpop
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::str::FromStr;
use ast::parsed::{*, asm::*};
use number::{AbstractNumberType, FieldElement};
use num_traits::Num;
use num_traits::{Num, ToPrimitive};

grammar<T> where T: FieldElement;

Expand Down Expand Up @@ -122,10 +122,10 @@ ArrayLiteralTerm: ArrayExpression<T> = {
}

PolynomialCommitDeclaration: PilStatement<T> = {
<@L> PolCol CommitWitness <PolynomialNameList> => PilStatement::PolynomialCommitDeclaration(<>, None, false),
<@L> PolCol "public" <PolynomialNameList> => PilStatement::PolynomialCommitDeclaration(<>, None, true),
<@L> PolCol CommitWitness <PolynomialNameList> => PilStatement::PolynomialCommitDeclaration(<>, None, None),
<start:@L> PolCol "public" "(" <n:Integer> ")" <name:PolynomialName> => PilStatement::PolynomialCommitDeclaration(start, vec![name], None, Some(n.to_usize().unwrap())),
<start:@L> PolCol CommitWitness <name:PolynomialName> "(" <param:ParameterList> ")" "query" <value:Expression>
=> PilStatement::PolynomialCommitDeclaration(start, vec![name], Some(FunctionDefinition::Query(param, value)), false)
=> PilStatement::PolynomialCommitDeclaration(start, vec![name], Some(FunctionDefinition::Query(param, value)), None)
}

PolynomialIdentity: PilStatement<T> = {
Expand Down
Loading

0 comments on commit 5ddbe75

Please sign in to comment.