Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NCC-ADN] Prevent ignoring proving tasks #2032

Closed
wants to merge 15 commits into from
Closed
7 changes: 4 additions & 3 deletions synthesizer/benches/kary_merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ fn batch_prove(c: &mut Criterion) {
// Bench the proof construction.
for num_assignments in &[1, 2, 4, 8] {
// Construct the assignments.
let assignments =
[(proving_key.clone(), (0..*num_assignments).map(|_| assignment.clone()).collect::<Vec<_>>())];
let assignments = (0..*num_assignments).map(|_| assignment.clone()).collect::<Vec<_>>();
let keys_to_assignments = [(&proving_key, assignments.as_ref())].into_iter();

c.bench_function(&format!("KaryMerkleTree batch prove {num_assignments} assignments"), |b| {
b.iter(|| {
let _proof = ProvingKey::prove_batch("ProveKaryMerkleTree", &assignments, &mut rng).unwrap();
let _proof =
ProvingKey::prove_batch("ProveKaryMerkleTree", keys_to_assignments.clone(), &mut rng).unwrap();
})
});
}
Expand Down
105 changes: 105 additions & 0 deletions synthesizer/process/src/tests/test_execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2483,3 +2483,108 @@ function {function_name}:
assert_ne!(execution_1.peek().unwrap().id(), execution_2.peek().unwrap().id());
assert_ne!(execution_1.to_execution_id().unwrap(), execution_2.to_execution_id().unwrap());
}

#[test]
fn test_duplicate_function_call() {
// Initialize a new program.
let (string, program0) = Program::<CurrentNetwork>::parse(
r"
program zero.aleo;

function b:
input r0 as u8.private;
input r1 as u8.private;
add r0 r1 into r2;
output r2 as u8.private;",
)
.unwrap();
assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");

// Construct the process.
let mut process = crate::test_helpers::sample_process(&program0);

// Initialize another program.
let (string, program1) = Program::<CurrentNetwork>::parse(
r"
import zero.aleo;

program one.aleo;

function a:
input r0 as u8.private;
input r1 as u8.private;
call zero.aleo/b r0 r1 into r2;
call zero.aleo/b r0 r1 into r3;
output r2 as u8.private;
output r3 as u8.private;",
)
.unwrap();
assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");

// Add the program to the process.
process.add_program(&program1).unwrap();

// Initialize the RNG.
let rng = &mut TestRng::default();

// Initialize the caller.
let caller_private_key = PrivateKey::<CurrentNetwork>::new(rng).unwrap();

// Declare the function name.
let function_name = Identifier::from_str("a").unwrap();

// Declare the input value.
let r0 = Value::<CurrentNetwork>::from_str("1u8").unwrap();
let r1 = Value::<CurrentNetwork>::from_str("2u8").unwrap();

// Authorize the function call.
let authorization = process
.authorize::<CurrentAleo, _>(&caller_private_key, program1.id(), function_name, [r0, r1].iter(), rng)
.unwrap();
assert_eq!(authorization.len(), 3);
println!("\nAuthorize\n{:#?}\n\n", authorization.to_vec_deque());

let output = Value::<CurrentNetwork>::from_str("3u8").unwrap();

// Compute the output value.
let response = process.evaluate::<CurrentAleo>(authorization.replicate()).unwrap();
let candidate = response.outputs();
assert_eq!(2, candidate.len());
assert_eq!(output, candidate[0]);
assert_eq!(output, candidate[1]);

// Check again to make sure we didn't modify the authorization after calling `evaluate`.
assert_eq!(authorization.len(), 3);

// Execute the request.
let (response, mut trace) = process.execute::<CurrentAleo, _>(authorization, rng).unwrap();
let candidate = response.outputs();
assert_eq!(2, candidate.len());
assert_eq!(output, candidate[0]);
assert_eq!(output, candidate[1]);

// Construct the expected transition order.
let expected_order = [
(program0.id(), Identifier::<Testnet3>::from_str("b").unwrap()),
(program0.id(), Identifier::<Testnet3>::from_str("b").unwrap()),
(program1.id(), Identifier::from_str("a").unwrap()),
];

// Check the expected transition order.
for (transition, (expected_program_id, expected_function_name)) in
trace.transitions().iter().zip_eq(expected_order.iter())
{
assert_eq!(transition.program_id(), *expected_program_id);
assert_eq!(transition.function_name(), expected_function_name);
}

// Initialize a new block store.
let block_store = BlockStore::<CurrentNetwork, BlockMemory<_>>::open(None).unwrap();
// Prepare the trace.
trace.prepare(Query::from(block_store)).unwrap();
// Prove the execution.
let execution = trace.prove_execution::<CurrentAleo, _>("one", rng).unwrap();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the duplicate function call?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intention was above at call zero.aleo/b r0 r1 into r2;


// Verify the execution.
process.verify_execution(&execution).unwrap();
}
56 changes: 30 additions & 26 deletions synthesizer/process/src/trace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,20 @@ mod inclusion;
pub use inclusion::*;

use circuit::Assignment;
use console::{
network::prelude::*,
program::{InputID, Locator},
};
use console::{network::prelude::*, program::InputID};
use ledger_block::{Execution, Fee, Transition};
use ledger_query::QueryTrait;
use synthesizer_snark::{Proof, ProvingKey, VerifyingKey};

use once_cell::sync::OnceCell;
use std::collections::HashMap;
use std::collections::BTreeMap;

#[derive(Clone, Debug, Default)]
pub struct Trace<N: Network> {
/// The list of transitions.
transitions: Vec<Transition<N>>,
/// A map of locators to (proving key, assignments) pairs.
transition_tasks: HashMap<Locator<N>, (ProvingKey<N>, Vec<Assignment<N::Field>>)>,
transition_tasks: BTreeMap<ProvingKey<N>, Vec<Assignment<N::Field>>>,
/// A tracker for all inclusion tasks.
inclusion_tasks: Inclusion<N>,
/// A list of call metrics.
Expand All @@ -52,7 +49,7 @@ impl<N: Network> Trace<N> {
pub fn new() -> Self {
Self {
transitions: Vec::new(),
transition_tasks: HashMap::new(),
transition_tasks: BTreeMap::new(),
inclusion_tasks: Inclusion::new(),
inclusion_assignments: OnceCell::new(),
global_state_root: OnceCell::new(),
Expand Down Expand Up @@ -87,10 +84,8 @@ impl<N: Network> Trace<N> {
// Insert the transition into the inclusion tasks.
self.inclusion_tasks.insert_transition(input_ids, transition)?;

// Construct the locator.
let locator = Locator::new(*transition.program_id(), *transition.function_name());
// Insert the assignment (and proving key if the entry does not exist), for the specified locator.
self.transition_tasks.entry(locator).or_insert((proving_key, vec![])).1.push(assignment);
// Insert the assignment (and proving key if the entry does not exist).
self.transition_tasks.entry(proving_key).or_default().push(assignment);
// Insert the transition into the list.
self.transitions.push(transition.clone());
// Insert the call metrics into the list.
Expand Down Expand Up @@ -165,11 +160,9 @@ impl<N: Network> Trace<N> {
// Retrieve the global state root.
let global_state_root =
self.global_state_root.get().ok_or_else(|| anyhow!("Global state root has not been set"))?;
// Construct the proving tasks.
let proving_tasks = self.transition_tasks.values().cloned().collect();
// Compute the proof.
let (global_state_root, proof) =
Self::prove_batch::<A, R>(locator, proving_tasks, inclusion_assignments, *global_state_root, rng)?;
Self::prove_batch::<A, R>(locator, &self.transition_tasks, inclusion_assignments, *global_state_root, rng)?;
// Return the execution.
Execution::from(self.transitions.iter().cloned(), global_state_root, Some(proof))
}
Expand All @@ -193,12 +186,10 @@ impl<N: Network> Trace<N> {
self.global_state_root.get().ok_or_else(|| anyhow!("Global state root has not been set"))?;
// Retrieve the fee transition.
let fee_transition = &self.transitions[0];
// Construct the proving tasks.
let proving_tasks = self.transition_tasks.values().cloned().collect();
// Compute the proof.
let (global_state_root, proof) = Self::prove_batch::<A, R>(
"credits.aleo/fee (private or public)",
proving_tasks,
&self.transition_tasks,
inclusion_assignments,
*global_state_root,
rng,
Expand Down Expand Up @@ -258,7 +249,7 @@ impl<N: Network> Trace<N> {
/// Returns the global state root and proof for the given assignments.
fn prove_batch<A: circuit::Aleo<Network = N>, R: Rng + CryptoRng>(
locator: &str,
mut proving_tasks: Vec<(ProvingKey<N>, Vec<Assignment<N::Field>>)>,
proving_tasks: &BTreeMap<ProvingKey<N>, Vec<Assignment<N::Field>>>,
inclusion_assignments: &[InclusionAssignment<N>],
global_state_root: N::StateRoot,
rng: &mut R,
Expand All @@ -282,15 +273,27 @@ impl<N: Network> Trace<N> {
batch_inclusions.push(assignment.to_circuit_assignment::<A>()?);
}

if !batch_inclusions.is_empty() {
// Fetch the inclusion proving key.
let proving_key = ProvingKey::<N>::new(N::inclusion_proving_key().clone());
// Insert the inclusion proving key and assignments.
proving_tasks.push((proving_key, batch_inclusions));
// Fetch the inclusion proving key.
let inclusion_proving_key = ProvingKey::<N>::new(N::inclusion_proving_key().clone());

Comment on lines +276 to +278
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Fetch the inclusion proving key.
let inclusion_proving_key = ProvingKey::<N>::new(N::inclusion_proving_key().clone());
// Fetch the inclusion proving key.
let inclusion_proving_key = ProvingKey::<N>::new(N::inclusion_proving_key().clone());

This should be moved back into the if statement. Otherwise, we are needlessly loading this proving key.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm defining it here to make sure it has the appropriate lifetime:

  • we are inserting it in line 284 by reference.
  • Which I'm doing because the proving keys are inserted by reference in line 265.
  • Which I'm doing because &self.transition_tasks is directly passed by reference in line 165.

Before this PR we made copies of the transition_tasks, which was fine for the provingkeys (which were Arcs) but not fine for the Assignments. Maybe there is a slightly cleaner way but it does not seem worth the time to refactor.

// Proving tasks should not contain inclusion_proving_key if we still have batch_inclusions to add
if !batch_inclusions.is_empty() && proving_tasks.contains_key(&inclusion_proving_key) {
bail!("The inclusion proving key (and its instances) have already been inserted");
}

// Collect optional inclusion iter
let inclusion_iter =
(!batch_inclusions.is_empty()).then_some((&inclusion_proving_key, batch_inclusions.as_slice()));

// Collect references to keys and assignments
let proving_tasks = proving_tasks.iter().map(|(p, a)| (p, a.as_slice()));

// Optionally insert the inclusion proving key and assignments.
let proving_tasks = proving_tasks.chain(inclusion_iter.into_iter());

// Compute the proof.
let proof = ProvingKey::prove_batch(locator, &proving_tasks, rng)?;
let proof = ProvingKey::prove_batch(locator, proving_tasks, rng)?;

// Return the global state root and proof.
Ok((global_state_root, proof))
}
Expand All @@ -315,8 +318,9 @@ impl<N: Network> Trace<N> {
}
// Verify the proof.
match VerifyingKey::verify_batch(locator, verifier_inputs, proof) {
true => Ok(()),
false => bail!("Failed to verify proof"),
Ok(true) => Ok(()),
Ok(false) => bail!("Failed to verify proof"),
Err(e) => bail!("Verifier failed - {e}"),
}
}
}
33 changes: 27 additions & 6 deletions synthesizer/snark/src/proving_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,20 @@ impl<N: Network> ProvingKey<N> {

/// Returns a proof for the given batch of proving keys and assignments.
#[allow(clippy::type_complexity)]
pub fn prove_batch<R: Rng + CryptoRng>(
pub fn prove_batch<'a, R: Rng + CryptoRng>(
locator: &str,
assignments: &[(ProvingKey<N>, Vec<circuit::Assignment<N::Field>>)],
assignments: impl Iterator<Item = (&'a ProvingKey<N>, &'a [circuit::Assignment<N::Field>])>,
rng: &mut R,
) -> Result<Proof<N>> {
#[cfg(feature = "aleo-cli")]
let timer = std::time::Instant::now();

// Prepare the instances.
let instances: BTreeMap<_, _> = assignments
.iter()
.map(|(proving_key, assignments)| (proving_key.deref(), assignments.as_slice()))
.collect();
let mut instances = BTreeMap::default();
for (proving_key, assignments) in assignments {
let previous_entry = instances.insert(proving_key.deref(), assignments);
ensure!(previous_entry.is_none(), "prove_batch found duplicate proving keys");
}

// Retrieve the proving parameters.
let universal_prover = N::varuna_universal_prover();
Expand All @@ -91,3 +92,23 @@ impl<N: Network> Deref for ProvingKey<N> {
&self.proving_key
}
}

impl<N: Network> Eq for ProvingKey<N> {}

impl<N: Network> PartialEq for ProvingKey<N> {
fn eq(&self, other: &Self) -> bool {
self.deref() == other.deref()
}
}
ljedrz marked this conversation as resolved.
Show resolved Hide resolved

impl<N: Network> Ord for ProvingKey<N> {
fn cmp(&self, other: &Self) -> Ordering {
self.deref().cmp(other.deref())
ljedrz marked this conversation as resolved.
Show resolved Hide resolved
}
}

impl<N: Network> PartialOrd for ProvingKey<N> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
Comment on lines +98 to +114
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't these be auto-derived?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

error[E0277]: the trait bound N: Ord is not satisfied
--> synthesizer/process/src/trace/mod.rs:89:31
|
89 | self.transition_tasks.entry(proving_key).or_default().push(assignment);
| ^^^^^ the trait Ord is not implemented for N
|
= note: required for ProvingKey<N> to implement Ord

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's very odd, I'll give it a look tomorrow as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would do the job: https://github.com/AleoHQ/snarkVM/pull/2231

24 changes: 19 additions & 5 deletions synthesizer/snark/src/verifying_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,19 @@ impl<N: Network> VerifyingKey<N> {

/// Returns `true` if the batch proof is valid for the given public inputs.
#[allow(clippy::type_complexity)]
pub fn verify_batch(locator: &str, inputs: Vec<(VerifyingKey<N>, Vec<Vec<N::Field>>)>, proof: &Proof<N>) -> bool {
pub fn verify_batch(
locator: &str,
inputs: Vec<(VerifyingKey<N>, Vec<Vec<N::Field>>)>,
proof: &Proof<N>,
) -> Result<bool> {
#[cfg(feature = "aleo-cli")]
let timer = std::time::Instant::now();

// Convert the instances.
let expected_len = inputs.len();
let keys_to_inputs: BTreeMap<_, _> =
inputs.iter().map(|(verifying_key, inputs)| (verifying_key.deref(), inputs.as_slice())).collect();
ensure!(keys_to_inputs.len() == expected_len, "Found duplicate verifying keys");

// Retrieve the verification parameters.
let universal_verifier = N::varuna_universal_verifier();
Expand All @@ -76,14 +82,22 @@ impl<N: Network> VerifyingKey<N> {
// Verify the batch proof.
match Varuna::<N>::verify_batch(universal_verifier, fiat_shamir, &keys_to_inputs, proof) {
Ok(is_valid) => {
#[cfg(feature = "aleo-cli")]
println!("{}", format!(" • Verified '{locator}' (in {} ms)", timer.elapsed().as_millis()).dimmed());
is_valid
if is_valid {
#[cfg(feature = "aleo-cli")]
println!("{}", format!(" • Verified '{locator}' (in {} ms)", timer.elapsed().as_millis()).dimmed());
} else {
#[cfg(feature = "aleo-cli")]
println!(
"{}",
format!(" • Verification failed '{locator}' (in {} ms)", timer.elapsed().as_millis()).dimmed()
);
}
Ok(is_valid)
}
Err(error) => {
#[cfg(feature = "aleo-cli")]
println!("{}", format!(" • Verifier failed: {error}").dimmed());
false
bail!(error)
}
}
}
Expand Down