From 55b6ba28938a8d89a4255607a61243cf13391665 Mon Sep 17 00:00:00 2001 From: Facundo Date: Wed, 28 Aug 2024 15:56:51 +0100 Subject: [PATCH] fix(bb-prover): create structure for AVM vk (#8233) Apologies for duplicating code! I tried putting a generic on the "base" classes, but (1) generics don't play well with static methods (e.g., fromBuffer) and (2) you still need to pass the value for the VK size (on top of the type). I think most of this duplication can be avoided if you just accept some type unsafety and save things as `Fr[]` instead of tuples with size. PS: There might be still work to do to align the "num public inputs" etc indices, and the vk hash. --- .../src/barretenberg/vm/aztec_constants.hpp | 1 + .../crates/types/src/constants.nr | 6 + .../bb-prover/src/avm_proving.test.ts | 17 +-- .../bb-prover/src/prover/bb_prover.ts | 17 +-- .../bb-prover/src/test/test_circuit_prover.ts | 7 +- .../verification_key/verification_key_data.ts | 24 +++- .../src/interfaces/proving-job.ts | 7 +- .../src/interfaces/server_circuit_prover.ts | 8 +- yarn-project/circuits.js/src/constants.gen.ts | 1 + .../src/structs/verification_key.ts | 119 +++++++++++++++++- .../src/prover-agent/memory-proving-queue.ts | 8 +- 11 files changed, 182 insertions(+), 33 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/vm/aztec_constants.hpp b/barretenberg/cpp/src/barretenberg/vm/aztec_constants.hpp index 09119cec699..58c800ec1f5 100644 --- a/barretenberg/cpp/src/barretenberg/vm/aztec_constants.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/aztec_constants.hpp @@ -33,6 +33,7 @@ #define HEADER_LENGTH 24 #define PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH 691 #define PUBLIC_CONTEXT_INPUTS_LENGTH 42 +#define AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS 66 #define SENDER_SELECTOR 0 #define ADDRESS_SELECTOR 1 #define STORAGE_ADDRESS_SELECTOR 1 diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr b/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr index ebb231a4895..2367626358b 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/constants.nr @@ -290,6 +290,12 @@ global NESTED_RECURSIVE_PROOF_LENGTH = 439; global TUBE_PROOF_LENGTH = RECURSIVE_PROOF_LENGTH; // in the future these can differ global VERIFICATION_KEY_LENGTH_IN_FIELDS = 128; +// VK is composed of +// - circuit size encoded as a fr field element (32 bytes) +// - num of inputs encoded as a fr field element (32 bytes) +// - 16 affine elements (curve base field fq) encoded as fr elements takes (16 * 4 * 32 bytes) +// 16 above refers to the constant AvmFlavor::NUM_PRECOMPUTED_ENTITIES +global AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS = 2 + 16 * 4; /** * Enumerate the hash_indices which are used for pedersen hashing. diff --git a/yarn-project/bb-prover/src/avm_proving.test.ts b/yarn-project/bb-prover/src/avm_proving.test.ts index 28a5122618c..0ce98cc420a 100644 --- a/yarn-project/bb-prover/src/avm_proving.test.ts +++ b/yarn-project/bb-prover/src/avm_proving.test.ts @@ -1,5 +1,6 @@ import { AvmCircuitInputs, + AvmVerificationKeyData, AztecAddress, ContractStorageRead, ContractStorageUpdateRequest, @@ -49,7 +50,7 @@ import path from 'path'; import { PublicSideEffectTrace } from '../../simulator/src/public/side_effect_trace.js'; import { SerializableContractInstance } from '../../types/src/contracts/contract_instance.js'; import { type BBSuccess, BB_RESULT, generateAvmProof, verifyAvmProof } from './bb/execute.js'; -import { extractVkData } from './verification_key/verification_key_data.js'; +import { extractAvmVkData } from './verification_key/verification_key_data.js'; const TIMEOUT = 60_000; const TIMESTAMP = new Fr(99833); @@ -279,18 +280,10 @@ const proveAndVerifyAvmTestContract = async ( const proofRes = await generateAvmProof(bbPath, bbWorkingDirectory, avmCircuitInputs, logger); expect(proofRes.status).toEqual(BB_RESULT.SUCCESS); - // Then we test VK extraction. + // Then we test VK extraction and serialization. const succeededRes = proofRes as BBSuccess; - const verificationKey = await extractVkData(succeededRes.vkPath!); - - // VK is composed of - // - circuit size encoded as a fr field element (32 bytes) - // - num of inputs encoded as a fr field element (32 bytes) - // - 16 affine elements (curve base field fq) encoded as fr elements takes (16 * 4 * 32 bytes) - // 16 above refers to the constant AvmFlavor::NUM_PRECOMPUTED_ENTITIES - // Total number of bytes = 2112 - const NUM_PRECOMPUTED_ENTITIES = 16; - expect(verificationKey.keyAsBytes).toHaveLength(NUM_PRECOMPUTED_ENTITIES * 4 * 32 + 2 * 32); + const vkData = await extractAvmVkData(succeededRes.vkPath!); + AvmVerificationKeyData.fromBuffer(vkData.toBuffer()); // Then we verify. const rawVkPath = path.join(succeededRes.vkPath!, 'vk'); diff --git a/yarn-project/bb-prover/src/prover/bb_prover.ts b/yarn-project/bb-prover/src/prover/bb_prover.ts index dcd09631a13..2c38db4869e 100644 --- a/yarn-project/bb-prover/src/prover/bb_prover.ts +++ b/yarn-project/bb-prover/src/prover/bb_prover.ts @@ -1,6 +1,6 @@ /* eslint-disable require-await */ import { - type ProofAndVerificationKey, + type AvmProofAndVerificationKey, type PublicInputsAndRecursiveProof, type PublicKernelNonTailRequest, type PublicKernelTailRequest, @@ -11,6 +11,7 @@ import { type CircuitProvingStats, type CircuitWitnessGenerationStats } from '@a import { AGGREGATION_OBJECT_LENGTH, type AvmCircuitInputs, + type AvmVerificationKeyData, type BaseOrMergeRollupPublicInputs, type BaseParityInputs, type BaseRollupInputs, @@ -94,7 +95,7 @@ import type { ACVMConfig, BBConfig } from '../config.js'; import { ProverInstrumentation } from '../instrumentation.js'; import { PublicKernelArtifactMapping } from '../mappings/mappings.js'; import { mapProtocolArtifactNameToCircuitName } from '../stats.js'; -import { extractVkData } from '../verification_key/verification_key_data.js'; +import { extractAvmVkData, extractVkData } from '../verification_key/verification_key_data.js'; const logger = createDebugLogger('aztec:bb-prover'); @@ -202,7 +203,7 @@ export class BBNativeRollupProver implements ServerCircuitProver { @trackSpan('BBNativeRollupProver.getAvmProof', inputs => ({ [Attributes.APP_CIRCUIT_NAME]: inputs.functionName, })) - public async getAvmProof(inputs: AvmCircuitInputs): Promise { + public async getAvmProof(inputs: AvmCircuitInputs): Promise { const proofAndVk = await this.createAvmProof(inputs); await this.verifyAvmProof(proofAndVk.proof, proofAndVk.verificationKey); return proofAndVk; @@ -626,14 +627,14 @@ export class BBNativeRollupProver implements ServerCircuitProver { return provingResult; } - private async createAvmProof(input: AvmCircuitInputs): Promise { - const operation = async (bbWorkingDirectory: string): Promise => { + private async createAvmProof(input: AvmCircuitInputs): Promise { + const operation = async (bbWorkingDirectory: string): Promise => { const provingResult = await this.generateAvmProofWithBB(input, bbWorkingDirectory); const rawProof = await fs.readFile(provingResult.proofPath!); // TODO(https://github.com/AztecProtocol/aztec-packages/issues/6773): this VK data format is wrong. // In particular, the number of public inputs, etc will be wrong. - const verificationKey = await extractVkData(provingResult.vkPath!); + const verificationKey = await extractAvmVkData(provingResult.vkPath!); const proof = new Proof(rawProof, verificationKey.numPublicInputs); const circuitType = 'avm-circuit' as const; @@ -765,7 +766,7 @@ export class BBNativeRollupProver implements ServerCircuitProver { return await this.verifyWithKey(verificationKey, proof); } - public async verifyAvmProof(proof: Proof, verificationKey: VerificationKeyData) { + public async verifyAvmProof(proof: Proof, verificationKey: AvmVerificationKeyData) { return await this.verifyWithKeyInternal(proof, verificationKey, verifyAvmProof); } @@ -775,7 +776,7 @@ export class BBNativeRollupProver implements ServerCircuitProver { private async verifyWithKeyInternal( proof: Proof, - verificationKey: VerificationKeyData, + verificationKey: { keyAsBytes: Buffer }, verificationFunction: VerificationFunction, ) { const operation = async (bbWorkingDirectory: string) => { diff --git a/yarn-project/bb-prover/src/test/test_circuit_prover.ts b/yarn-project/bb-prover/src/test/test_circuit_prover.ts index ec8ed4ea331..63a5d7cf1ba 100644 --- a/yarn-project/bb-prover/src/test/test_circuit_prover.ts +++ b/yarn-project/bb-prover/src/test/test_circuit_prover.ts @@ -1,5 +1,5 @@ import { - type ProofAndVerificationKey, + type AvmProofAndVerificationKey, type PublicInputsAndRecursiveProof, type PublicKernelNonTailRequest, type PublicKernelTailRequest, @@ -8,6 +8,7 @@ import { } from '@aztec/circuit-types'; import { type AvmCircuitInputs, + AvmVerificationKeyData, type BaseOrMergeRollupPublicInputs, type BaseParityInputs, type BaseRollupInputs, @@ -475,12 +476,12 @@ export class TestCircuitProver implements ServerCircuitProver { ); } - public async getAvmProof(_inputs: AvmCircuitInputs): Promise { + public async getAvmProof(_inputs: AvmCircuitInputs): Promise { // We can't simulate the AVM because we don't have enough context to do so (e.g., DBs). // We just return an empty proof and VK data. this.logger.debug('Skipping AVM simulation in TestCircuitProver.'); await this.delay(); - return { proof: makeEmptyProof(), verificationKey: VerificationKeyData.makeFake() }; + return { proof: makeEmptyProof(), verificationKey: AvmVerificationKeyData.makeFake() }; } private async delay(): Promise { diff --git a/yarn-project/bb-prover/src/verification_key/verification_key_data.ts b/yarn-project/bb-prover/src/verification_key/verification_key_data.ts index b2fe18cc5e1..b5f4bacb1fa 100644 --- a/yarn-project/bb-prover/src/verification_key/verification_key_data.ts +++ b/yarn-project/bb-prover/src/verification_key/verification_key_data.ts @@ -1,11 +1,15 @@ import { + AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS, + AvmVerificationKeyAsFields, + AvmVerificationKeyData, Fr, - type VERIFICATION_KEY_LENGTH_IN_FIELDS, + VERIFICATION_KEY_LENGTH_IN_FIELDS, VerificationKeyAsFields, VerificationKeyData, } from '@aztec/circuits.js'; import { type Tuple } from '@aztec/foundation/serialize'; +import { strict as assert } from 'assert'; import * as fs from 'fs/promises'; import * as path from 'path'; @@ -25,7 +29,25 @@ export async function extractVkData(vkDirectoryPath: string): Promise, vkHash); const vk = new VerificationKeyData(vkAsFields, rawBinary); return vk; } + +// TODO: This was adapted from the above function. A refactor might be needed. +export async function extractAvmVkData(vkDirectoryPath: string): Promise { + const [rawFields, rawBinary] = await Promise.all([ + fs.readFile(path.join(vkDirectoryPath, VK_FIELDS_FILENAME), { encoding: 'utf-8' }), + fs.readFile(path.join(vkDirectoryPath, VK_FILENAME)), + ]); + const fieldsJson = JSON.parse(rawFields); + const fields = fieldsJson.map(Fr.fromString); + // The first item is the hash, this is not part of the actual VK + // TODO: is the above actually the case? + const vkHash = fields[0]; + assert(fields.length === AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS, 'Invalid AVM verification key length'); + const vkAsFields = new AvmVerificationKeyAsFields(fields, vkHash); + const vk = new AvmVerificationKeyData(vkAsFields, rawBinary); + return vk; +} diff --git a/yarn-project/circuit-types/src/interfaces/proving-job.ts b/yarn-project/circuit-types/src/interfaces/proving-job.ts index e496875bcf5..33127e4c535 100644 --- a/yarn-project/circuit-types/src/interfaces/proving-job.ts +++ b/yarn-project/circuit-types/src/interfaces/proving-job.ts @@ -1,5 +1,6 @@ import { type AvmCircuitInputs, + type AvmVerificationKeyData, type BaseOrMergeRollupPublicInputs, type BaseParityInputs, type BaseRollupInputs, @@ -25,9 +26,9 @@ import { import type { PublicKernelNonTailRequest, PublicKernelTailRequest } from '../tx/processed_tx.js'; -export type ProofAndVerificationKey = { +export type AvmProofAndVerificationKey = { proof: Proof; - verificationKey: VerificationKeyData; + verificationKey: AvmVerificationKeyData; }; export type PublicInputsAndRecursiveProof = { @@ -133,7 +134,7 @@ export type ProvingRequest = export type ProvingRequestPublicInputs = { [ProvingRequestType.PRIVATE_KERNEL_EMPTY]: PublicInputsAndRecursiveProof; - [ProvingRequestType.PUBLIC_VM]: ProofAndVerificationKey; + [ProvingRequestType.PUBLIC_VM]: AvmProofAndVerificationKey; [ProvingRequestType.PUBLIC_KERNEL_NON_TAIL]: PublicInputsAndRecursiveProof; [ProvingRequestType.PUBLIC_KERNEL_TAIL]: PublicInputsAndRecursiveProof; diff --git a/yarn-project/circuit-types/src/interfaces/server_circuit_prover.ts b/yarn-project/circuit-types/src/interfaces/server_circuit_prover.ts index 60fc09fb921..802cc5ea983 100644 --- a/yarn-project/circuit-types/src/interfaces/server_circuit_prover.ts +++ b/yarn-project/circuit-types/src/interfaces/server_circuit_prover.ts @@ -1,5 +1,5 @@ import { - type ProofAndVerificationKey, + type AvmProofAndVerificationKey, type PublicInputsAndRecursiveProof, type PublicInputsAndTubeProof, type PublicKernelNonTailRequest, @@ -149,7 +149,11 @@ export interface ServerCircuitProver { * Create a proof for the AVM circuit. * @param inputs - Inputs to the AVM circuit. */ - getAvmProof(inputs: AvmCircuitInputs, signal?: AbortSignal, epochNumber?: number): Promise; + getAvmProof( + inputs: AvmCircuitInputs, + signal?: AbortSignal, + epochNumber?: number, + ): Promise; } /** diff --git a/yarn-project/circuits.js/src/constants.gen.ts b/yarn-project/circuits.js/src/constants.gen.ts index 306aca91406..1b25acd8207 100644 --- a/yarn-project/circuits.js/src/constants.gen.ts +++ b/yarn-project/circuits.js/src/constants.gen.ts @@ -202,6 +202,7 @@ export const RECURSIVE_PROOF_LENGTH = 439; export const NESTED_RECURSIVE_PROOF_LENGTH = 439; export const TUBE_PROOF_LENGTH = 439; export const VERIFICATION_KEY_LENGTH_IN_FIELDS = 128; +export const AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS = 66; export const SENDER_SELECTOR = 0; export const ADDRESS_SELECTOR = 1; export const STORAGE_ADDRESS_SELECTOR = 1; diff --git a/yarn-project/circuits.js/src/structs/verification_key.ts b/yarn-project/circuits.js/src/structs/verification_key.ts index 2baffbf8039..331cc50ba10 100644 --- a/yarn-project/circuits.js/src/structs/verification_key.ts +++ b/yarn-project/circuits.js/src/structs/verification_key.ts @@ -3,7 +3,9 @@ import { times } from '@aztec/foundation/collection'; import { Fq, Fr } from '@aztec/foundation/fields'; import { BufferReader, type Tuple, serializeToBuffer } from '@aztec/foundation/serialize'; -import { VERIFICATION_KEY_LENGTH_IN_FIELDS } from '../constants.gen.js'; +import { strict as assert } from 'assert'; + +import { AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS, VERIFICATION_KEY_LENGTH_IN_FIELDS } from '../constants.gen.js'; import { CircuitType } from './shared.js'; /** @@ -94,7 +96,7 @@ export class VerificationKeyAsFields { } public get isRecursive() { - return this.key[CIRCUIT_RECURSIVE_INDEX] == Fr.ONE; + return this.key[CIRCUIT_RECURSIVE_INDEX].equals(Fr.ONE); } /** @@ -135,6 +137,71 @@ export class VerificationKeyAsFields { } } +/** + * Provides a 'fields' representation of the AVM's verification key + */ +// TODO: This is a copy of the above, a refactor might be needed. +export class AvmVerificationKeyAsFields { + constructor(public key: Fr[], public hash: Fr) { + assert(this.key.length === AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS, 'Invalid AVM key length'); + } + + public get numPublicInputs() { + return Number(this.key[CIRCUIT_PUBLIC_INPUTS_INDEX]); + } + + public get circuitSize() { + return Number(this.key[CIRCUIT_SIZE_INDEX]); + } + + public get isRecursive() { + return this.key[CIRCUIT_RECURSIVE_INDEX].equals(Fr.ONE); + } + + /** + * Serialize as a buffer. + * @returns The buffer. + */ + toBuffer() { + return serializeToBuffer(this.key, this.hash); + } + toFields() { + return [...this.key, this.hash]; + } + + /** + * Deserializes from a buffer or reader, corresponding to a write in cpp. + * @param buffer - Buffer to read from. + * @returns The AvmVerificationKeyAsFields. + */ + static fromBuffer(buffer: Buffer | BufferReader): AvmVerificationKeyAsFields { + const reader = BufferReader.asReader(buffer); + return new AvmVerificationKeyAsFields( + reader.readArray(AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS, Fr), + reader.readObject(Fr), + ); + } + + /** + * Builds a fake verification key that should be accepted by circuits. + * @returns A fake verification key. + */ + static makeFake(seed = 1): AvmVerificationKeyAsFields { + return new AvmVerificationKeyAsFields( + makeTuple(AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS, Fr.random, seed), + Fr.random(), + ); + } + + /** + * Builds an 'empty' verification key + * @returns An 'empty' verification key + */ + static makeEmpty(): AvmVerificationKeyAsFields { + return new AvmVerificationKeyAsFields(makeTuple(AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS, Fr.zero), Fr.zero()); + } +} + export class VerificationKey { constructor( /** @@ -257,3 +324,51 @@ export class VerificationKeyData { return VerificationKeyData.fromBuffer(this.toBuffer()); } } + +export class AvmVerificationKeyData { + constructor(public readonly keyAsFields: AvmVerificationKeyAsFields, public readonly keyAsBytes: Buffer) {} + + public get numPublicInputs() { + return this.keyAsFields.numPublicInputs; + } + + public get circuitSize() { + return this.keyAsFields.circuitSize; + } + + public get isRecursive() { + return this.keyAsFields.isRecursive; + } + + static makeFake(): AvmVerificationKeyData { + return new AvmVerificationKeyData(AvmVerificationKeyAsFields.makeFake(), VerificationKey.makeFake().toBuffer()); + } + + /** + * Serialize as a buffer. + * @returns The buffer. + */ + toBuffer() { + return serializeToBuffer(this.keyAsFields, this.keyAsBytes.length, this.keyAsBytes); + } + + toString() { + return this.toBuffer().toString('hex'); + } + + static fromBuffer(buffer: Buffer | BufferReader): AvmVerificationKeyData { + const reader = BufferReader.asReader(buffer); + const verificationKeyAsFields = reader.readObject(AvmVerificationKeyAsFields); + const length = reader.readNumber(); + const bytes = reader.readBytes(length); + return new AvmVerificationKeyData(verificationKeyAsFields, bytes); + } + + static fromString(str: string): AvmVerificationKeyData { + return AvmVerificationKeyData.fromBuffer(Buffer.from(str, 'hex')); + } + + public clone() { + return AvmVerificationKeyData.fromBuffer(this.toBuffer()); + } +} diff --git a/yarn-project/prover-client/src/prover-agent/memory-proving-queue.ts b/yarn-project/prover-client/src/prover-agent/memory-proving-queue.ts index be11c990a14..7dc8e10012a 100644 --- a/yarn-project/prover-client/src/prover-agent/memory-proving-queue.ts +++ b/yarn-project/prover-client/src/prover-agent/memory-proving-queue.ts @@ -1,5 +1,5 @@ import { - type ProofAndVerificationKey, + type AvmProofAndVerificationKey, type ProvingJob, type ProvingJobSource, type ProvingRequest, @@ -408,7 +408,11 @@ export class MemoryProvingQueue implements ServerCircuitProver, ProvingJobSource /** * Creates an AVM proof. */ - getAvmProof(inputs: AvmCircuitInputs, signal?: AbortSignal, epochNumber?: number): Promise { + getAvmProof( + inputs: AvmCircuitInputs, + signal?: AbortSignal, + epochNumber?: number, + ): Promise { return this.enqueue({ type: ProvingRequestType.PUBLIC_VM, inputs }, signal, epochNumber); }