feat(raiko): proof aggregation #347

Open. Wants to merge 11 commits into base: main.
1 change: 1 addition & 0 deletions Cargo.lock


57 changes: 56 additions & 1 deletion core/src/interfaces.rs
@@ -3,7 +3,9 @@ use alloy_primitives::{Address, B256};
use clap::{Args, ValueEnum};
use raiko_lib::{
consts::VerifierType,
input::{BlobProofType, GuestInput, GuestOutput},
input::{
AggregationGuestInput, AggregationGuestOutput, BlobProofType, GuestInput, GuestOutput,
},
primitives::eip4844::{calc_kzg_proof, commitment_to_version_hash, kzg_proof_to_bytes},
prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverError},
};
@@ -221,6 +223,47 @@ impl ProofType {
Ok(proof)
}

/// Run the prover driver for proof aggregation, depending on the proof type.
pub async fn aggregate_proofs(
&self,
input: AggregationGuestInput,
output: &AggregationGuestOutput,
config: &Value,
store: Option<&mut dyn IdWrite>,
) -> RaikoResult<Proof> {
let proof = match self {
ProofType::Native => NativeProver::aggregate(input.clone(), output, config, store)
.await
.map_err(<ProverError as Into<RaikoError>>::into),
ProofType::Sp1 => {
#[cfg(feature = "sp1")]
return sp1_driver::Sp1Prover::aggregate(input.clone(), output, config, store)
.await
.map_err(|e| e.into());
#[cfg(not(feature = "sp1"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
ProofType::Risc0 => {
#[cfg(feature = "risc0")]
return risc0_driver::Risc0Prover::aggregate(input.clone(), output, config, store)
.await
.map_err(|e| e.into());
#[cfg(not(feature = "risc0"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
ProofType::Sgx => {
#[cfg(feature = "sgx")]
return sgx_prover::SgxProver::aggregate(input.clone(), output, config, store)
.await
.map_err(|e| e.into());
#[cfg(not(feature = "sgx"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
}?;

Ok(proof)
}

pub async fn cancel_proof(
&self,
proof_key: ProofKey,
@@ -408,3 +451,15 @@ impl TryFrom<ProofRequestOpt> for ProofRequest {
})
}
}

#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize)]
/// A request to aggregate multiple proofs.
pub struct AggregationRequest {
/// All the proofs to verify
pub proofs: Vec<Proof>,
/// The proof type.
pub proof_type: ProofType,
/// Additional prover params.
pub prover_args: HashMap<String, Value>,
}
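
For illustration only, host code could assemble such a request roughly as follows; the helper name, the choice of the native prover, and the empty prover_args entry are assumptions for this sketch, not part of the diff.

use std::collections::HashMap;

use serde_json::{json, Value};

// Hypothetical sketch: wrap previously generated proofs into an aggregation
// request for the native prover. The prover_args contents are placeholders.
fn build_aggregation_request(proofs: Vec<Proof>) -> AggregationRequest {
    let mut prover_args: HashMap<String, Value> = HashMap::new();
    prover_args.insert("native".to_string(), json!({}));

    AggregationRequest {
        proofs,
        proof_type: ProofType::Native,
        prover_args,
    }
}
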
68 changes: 57 additions & 11 deletions core/src/lib.rs
@@ -221,8 +221,9 @@ mod tests {
use clap::ValueEnum;
use raiko_lib::{
consts::{Network, SupportedChainSpecs},
input::BlobProofType,
input::{AggregationGuestInput, AggregationGuestOutput, BlobProofType},
primitives::B256,
prover::Proof,
};
use serde_json::{json, Value};
use std::{collections::HashMap, env};
@@ -237,7 +238,7 @@
ci == "1"
}

fn test_proof_params() -> HashMap<String, Value> {
fn test_proof_params(enable_aggregation: bool) -> HashMap<String, Value> {
let mut prover_args = HashMap::new();
prover_args.insert(
"native".to_string(),
@@ -251,7 +252,7 @@
"sp1".to_string(),
json! {
{
"recursion": "core",
"recursion": if enable_aggregation { "compressed" } else { "plonk" },
"prover": "mock",
"verify": true
}
@@ -273,8 +274,8 @@
json! {
{
"instance_id": 121,
"setup": true,
"bootstrap": true,
"setup": enable_aggregation,
"bootstrap": enable_aggregation,
"prove": true,
}
},
@@ -286,7 +287,7 @@
l1_chain_spec: ChainSpec,
taiko_chain_spec: ChainSpec,
proof_request: ProofRequest,
) {
) -> Proof {
let provider =
RpcBlockDataProvider::new(&taiko_chain_spec.rpc, proof_request.block_number - 1)
.expect("Could not create RpcBlockDataProvider");
@@ -296,10 +297,10 @@
.await
.expect("input generation failed");
let output = raiko.get_output(&input).expect("output generation failed");
let _proof = raiko
raiko
.prove(input, &output, None)
.await
.expect("proof generation failed");
.expect("proof generation failed")
}

#[tokio::test(flavor = "multi_thread")]
@@ -325,7 +326,7 @@
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(),
prover_args: test_proof_params(false),
};
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
}
@@ -362,7 +363,7 @@
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(),
prover_args: test_proof_params(false),
};
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
}
@@ -394,9 +395,54 @@
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(),
prover_args: test_proof_params(false),
};
prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;
}
}

#[tokio::test(flavor = "multi_thread")]
async fn test_prove_block_taiko_a7_aggregated() {
let proof_type = get_proof_type_from_env();
let l1_network = Network::Holesky.to_string();
let network = Network::TaikoA7.to_string();
// Give the CI a simpler block to test because it doesn't have enough memory.
// Unfortunately that also means that KZG is not fully verified by CI.
let block_number = if is_ci() { 105987 } else { 101368 };
let taiko_chain_spec = SupportedChainSpecs::default()
.get_chain_spec(&network)
.unwrap();
let l1_chain_spec = SupportedChainSpecs::default()
.get_chain_spec(&l1_network)
.unwrap();

let proof_request = ProofRequest {
block_number,
network,
graffiti: B256::ZERO,
prover: Address::ZERO,
l1_network,
proof_type,
blob_proof_type: BlobProofType::ProofOfEquivalence,
prover_args: test_proof_params(true),
};
let proof = prove_block(l1_chain_spec, taiko_chain_spec, proof_request).await;

let input = AggregationGuestInput {
proofs: vec![proof.clone(), proof],
};

let output = AggregationGuestOutput { hash: B256::ZERO };

let aggregated_proof = proof_type
.aggregate_proofs(
input,
&output,
&serde_json::to_value(&test_proof_params(false)).unwrap(),
None,
)
.await
.expect("proof aggregation failed");
println!("aggregated proof: {:?}", aggregated_proof);
}
}
15 changes: 12 additions & 3 deletions core/src/prover.rs
@@ -58,13 +58,22 @@ impl Prover for NativeProver {
}

Ok(Proof {
proof: None,
quote: None,
kzg_proof: None,
..Default::default()
})
}

async fn cancel(_proof_key: ProofKey, _read: Box<&mut dyn IdStore>) -> ProverResult<()> {
Ok(())
}

async fn aggregate(
_input: raiko_lib::input::AggregationGuestInput,
_output: &raiko_lib::input::AggregationGuestOutput,
_config: &ProverConfig,
_store: Option<&mut dyn IdWrite>,
) -> ProverResult<Proof> {
// The native prover performs no real aggregation; return an empty default proof.
Ok(Proof {
..Default::default()
})
}
}
2 changes: 2 additions & 0 deletions host/src/server/api/v2/mod.rs
@@ -151,6 +151,8 @@ pub fn create_router() -> Router<ProverState> {
// Only add the concurrency limit to the proof route. We want to still be able to call
// healthchecks and metrics to have insight into the system.
.nest("/proof", proof::create_router())
// TODO: Separate task or try to get it into /proof somehow? Probably separate
.nest("/aggregate", proof::create_router())
.nest("/health", v1::health::create_router())
.nest("/metrics", v1::metrics::create_router())
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", docs.clone()))
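
As a rough idea of how a client might exercise the new route once it accepts aggregation payloads, a sketch follows; the /v2 prefix, the use of reqwest (with its json feature), and the untyped request/response bodies are assumptions, since the handler currently reuses the proof router per the TODO above.

use serde_json::Value;

// Hypothetical client call; the path and payload shape are assumptions.
async fn post_aggregation(base_url: &str, body: &Value) -> reqwest::Result<Value> {
    reqwest::Client::new()
        .post(format!("{base_url}/v2/aggregate"))
        .json(body)
        .send()
        .await?
        .json::<Value>()
        .await
}
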
40 changes: 39 additions & 1 deletion lib/src/input.rs
@@ -13,7 +13,9 @@ use serde_with::serde_as;

#[cfg(not(feature = "std"))]
use crate::no_std::*;
use crate::{consts::ChainSpec, primitives::mpt::MptNode, utils::zlib_compress_data};
use crate::{
consts::ChainSpec, primitives::mpt::MptNode, prover::Proof, utils::zlib_compress_data,
};

/// Represents the state of an account's storage.
/// The storage trie together with the used storage slots allow us to reconstruct all the
@@ -42,6 +44,42 @@ pub struct GuestInput {
pub taiko: TaikoGuestInput,
}

/// External aggregation input.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AggregationGuestInput {
/// All block proofs to prove
pub proofs: Vec<Proof>,
}

/// The raw proof data necessary to verify a proof
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct RawProof {
/// The actual proof
pub proof: Vec<u8>,
/// The resulting hash
pub input: B256,
}

/// External aggregation input in raw byte form.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct RawAggregationGuestInput {
/// All block proofs to prove
pub proofs: Vec<RawProof>,
}

/// External aggregation output.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AggregationGuestOutput {
/// The resulting hash
pub hash: B256,
}

#[derive(Clone, Serialize, Deserialize)]
pub struct ZkAggregationGuestInput {
pub image_id: [u32; 8],
pub block_inputs: Vec<B256>,
}

impl From<(Block, Header, ChainSpec, TaikoGuestInput)> for GuestInput {
fn from(
(block, parent_header, chain_spec, taiko): (Block, Header, ChainSpec, TaikoGuestInput),
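
The diff does not show how a typed Proof is lowered into RawProof form for a guest; purely as an assumption-laden sketch (hex-encoded proof string with a 0x prefix, hex crate available), the mapping could look like this.

// Hypothetical conversion; returns None if any proof lacks bytes or an input hash.
fn to_raw(input: AggregationGuestInput) -> Option<RawAggregationGuestInput> {
    let proofs = input
        .proofs
        .into_iter()
        .map(|p| {
            let bytes = hex::decode(p.proof?.trim_start_matches("0x")).ok()?;
            Some(RawProof {
                proof: bytes,
                input: p.input?,
            })
        })
        .collect::<Option<Vec<_>>>()?;
    Some(RawAggregationGuestInput { proofs })
}
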
21 changes: 21 additions & 0 deletions lib/src/protocol_instance.rs
@@ -214,6 +214,27 @@ fn bytes_to_bytes32(input: &[u8]) -> [u8; 32] {
bytes
}

pub fn words_to_bytes_le(words: &[u32; 8]) -> [u8; 32] {
let mut bytes = [0u8; 32];
for i in 0..8 {
let word_bytes = words[i].to_le_bytes();
bytes[i * 4..(i + 1) * 4].copy_from_slice(&word_bytes);
}
bytes
}

pub fn aggregation_output_combine(public_inputs: Vec<B256>) -> Vec<u8> {
let mut output = Vec::with_capacity(public_inputs.len() * 32);
for public_input in public_inputs.iter() {
output.extend_from_slice(&public_input.0);
}
output
}

pub fn aggregation_output(program: B256, public_inputs: Vec<B256>) -> Vec<u8> {
aggregation_output_combine([vec![program], public_inputs].concat())
}

#[cfg(test)]
mod tests {
use alloy_primitives::{address, b256};
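
As a quick illustration of the layout these helpers produce (program hash first, then each 32-byte public input, with u32 words serialized little-endian), a test along these lines could sit next to the existing ones; the function name and byte values are placeholders.

#[test]
fn aggregation_output_layout_sketch() {
    use alloy_primitives::B256;

    let program = B256::repeat_byte(0x01);
    let inputs = vec![B256::repeat_byte(0x02), B256::repeat_byte(0x03)];
    let out = aggregation_output(program, inputs.clone());

    // 32 bytes for the program hash plus 32 per public input.
    assert_eq!(out.len(), 32 * (1 + inputs.len()));
    assert_eq!(&out[..32], program.as_slice());
    assert_eq!(&out[32..64], inputs[0].as_slice());

    // The lowest u32 word comes first, each word serialized little-endian.
    assert_eq!(words_to_bytes_le(&[1, 0, 0, 0, 0, 0, 0, 0])[0], 1);
}
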
15 changes: 13 additions & 2 deletions lib/src/prover.rs
@@ -2,7 +2,7 @@ use reth_primitives::{ChainId, B256};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;

use crate::input::{GuestInput, GuestOutput};
use crate::input::{AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput};

#[derive(thiserror::Error, Debug)]
pub enum ProverError {
@@ -26,13 +26,17 @@ pub type ProverResult<T, E = ProverError> = core::result::Result<T, E>;
pub type ProverConfig = serde_json::Value;
pub type ProofKey = (ChainId, B256, u8);

#[derive(Debug, Serialize, ToSchema, Deserialize, Default)]
#[derive(Clone, Debug, Serialize, ToSchema, Deserialize, Default)]
/// The response body of a proof request.
pub struct Proof {
/// The proof, either TEE or ZK.
pub proof: Option<String>,
/// The public input
pub input: Option<B256>,
/// The TEE quote.
pub quote: Option<String>,
/// The proof UUID.
pub uuid: Option<String>,
/// The kzg proof.
pub kzg_proof: Option<String>,
}
@@ -58,5 +62,12 @@ pub trait Prover {
store: Option<&mut dyn IdWrite>,
) -> ProverResult<Proof>;

async fn aggregate(
input: AggregationGuestInput,
output: &AggregationGuestOutput,
config: &ProverConfig,
store: Option<&mut dyn IdWrite>,
) -> ProverResult<Proof>;

async fn cancel(proof_key: ProofKey, read: Box<&mut dyn IdStore>) -> ProverResult<()>;
}
5 changes: 4 additions & 1 deletion provers/risc0/builder/src/main.rs
@@ -5,7 +5,10 @@ use std::path::PathBuf;

fn main() {
let pipeline = Risc0Pipeline::new("provers/risc0/guest", "release");
pipeline.bins(&["risc0-guest"], "provers/risc0/driver/src/methods");
pipeline.bins(
&["risc0-guest", "risc0-aggregation"],
"provers/risc0/driver/src/methods",
);
#[cfg(feature = "test")]
pipeline.tests(&["risc0-guest"], "provers/risc0/driver/src/methods");
#[cfg(feature = "bench")]