diff --git a/CHANGELOG.md b/CHANGELOG.md
index af053797..8b0a00d3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,20 @@
 # Changelog
 
+## [1.2.0](https://github.com/taikoxyz/raiko/compare/v1.1.0...v1.2.0) (2024-09-20)
+
+
+### Features
+
+* **raiko:** make raiko-zk docker image ([#374](https://github.com/taikoxyz/raiko/issues/374)) ([65ff9a4](https://github.com/taikoxyz/raiko/commit/65ff9a4935ac66f0c21785a0b8415313942bda82))
+* **raiko:** traversal to find inclusion block if none inclusion number is sent ([#377](https://github.com/taikoxyz/raiko/issues/377)) ([c2b0db5](https://github.com/taikoxyz/raiko/commit/c2b0db5a61e920840f9de083de8684a8375e51b3))
+* **sgx:** add wallet to provider builder when register instance ([#369](https://github.com/taikoxyz/raiko/issues/369)) ([a250edf](https://github.com/taikoxyz/raiko/commit/a250edf2ca42d5481ba92d97ca6ade5b46bb536c))
+
+
+### Bug Fixes
+
+* **raiko:** refine error return ([#378](https://github.com/taikoxyz/raiko/issues/378)) ([f4f818d](https://github.com/taikoxyz/raiko/commit/f4f818d43a33ba1caf95cb1db4160ba90824eb2d))
+* **script:** output build message and skip `pos` flag ([#367](https://github.com/taikoxyz/raiko/issues/367)) ([2c881dc](https://github.com/taikoxyz/raiko/commit/2c881dc22d5df553bffc24f8bbac6a86e2fd9688))
+
 ## [1.1.0](https://github.com/taikoxyz/raiko/compare/v1.0.0...v1.1.0) (2024-09-13)
 
diff --git a/core/src/interfaces.rs b/core/src/interfaces.rs
index 962baf1d..63ad4114 100644
--- a/core/src/interfaces.rs
+++ b/core/src/interfaces.rs
@@ -11,7 +11,7 @@ use raiko_lib::{
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use serde_with::{serde_as, DisplayFromStr};
-use std::{collections::HashMap, path::Path, str::FromStr};
+use std::{collections::HashMap, fmt::Display, path::Path, str::FromStr};
 use utoipa::ToSchema;
 
 #[derive(Debug, thiserror::Error, ToSchema)]
@@ -345,7 +345,7 @@ pub struct ProofRequestOpt {
     pub prover_args: ProverSpecificOpts,
 }
 
-#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, Args)]
+#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, Args, PartialEq, Eq, Hash)]
 pub struct ProverSpecificOpts {
     /// Native prover specific options.
     pub native: Option<Value>,
@@ -442,14 +442,122 @@ impl TryFrom<ProofRequestOpt> for ProofRequest {
     }
 }
 
-#[serde_as]
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema)]
+#[serde(default)]
 /// A request for proof aggregation of multiple proofs.
 pub struct AggregationRequest {
-    /// All the proofs to verify
+    /// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for.
+    pub block_numbers: Vec<(u64, Option<u64>)>,
+    /// The network to generate the proof for.
+    pub network: Option<String>,
+    /// The L1 network to generate the proof for.
+    pub l1_network: Option<String>,
+    /// Graffiti.
+    pub graffiti: Option<String>,
+    /// The protocol instance data.
+    pub prover: Option<String>,
+    /// The proof type.
+    pub proof_type: Option<String>,
+    /// Blob proof type.
+    pub blob_proof_type: Option<String>,
+    #[serde(flatten)]
+    /// Any additional prover params in JSON format.
+    pub prover_args: ProverSpecificOpts,
+}
+
+impl AggregationRequest {
+    /// Merge proof request options into aggregation request options.
+    pub fn merge(&mut self, opts: &ProofRequestOpt) -> RaikoResult<()> {
+        let this = serde_json::to_value(&self)?;
+        let mut opts = serde_json::to_value(opts)?;
+        merge(&mut opts, &this);
+        *self = serde_json::from_value(opts)?;
+        Ok(())
+    }
+}
+
+impl From<AggregationRequest> for Vec<ProofRequestOpt> {
+    fn from(value: AggregationRequest) -> Self {
+        value
+            .block_numbers
+            .iter()
+            .map(
+                |&(block_number, l1_inclusion_block_number)| ProofRequestOpt {
+                    block_number: Some(block_number),
+                    l1_inclusion_block_number,
+                    network: value.network.clone(),
+                    l1_network: value.l1_network.clone(),
+                    graffiti: value.graffiti.clone(),
+                    prover: value.prover.clone(),
+                    proof_type: value.proof_type.clone(),
+                    blob_proof_type: value.blob_proof_type.clone(),
+                    prover_args: value.prover_args.clone(),
+                },
+            )
+            .collect()
+    }
+}
+
+impl From<ProofRequestOpt> for AggregationRequest {
+    fn from(value: ProofRequestOpt) -> Self {
+        let block_numbers = if let Some(block_number) = value.block_number {
+            vec![(block_number, value.l1_inclusion_block_number)]
+        } else {
+            vec![]
+        };
+
+        Self {
+            block_numbers,
+            network: value.network,
+            l1_network: value.l1_network,
+            graffiti: value.graffiti,
+            prover: value.prover,
+            proof_type: value.proof_type,
+            blob_proof_type: value.blob_proof_type,
+            prover_args: value.prover_args,
+        }
+    }
+}
+
+#[derive(Default, Clone, Serialize, Deserialize, Debug, ToSchema, PartialEq, Eq, Hash)]
+#[serde(default)]
+/// A request for proof aggregation of multiple proofs.
+pub struct AggregationOnlyRequest {
+    /// The proofs to aggregate.
     pub proofs: Vec<Proof>,
     /// The proof type.
-    pub proof_type: ProofType,
-    /// Additional prover params.
-    pub prover_args: HashMap<String, Value>,
+    pub proof_type: Option<String>,
+    #[serde(flatten)]
+    /// Any additional prover params in JSON format.
+    pub prover_args: ProverSpecificOpts,
+}
+
+impl Display for AggregationOnlyRequest {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&format!(
+            "AggregationOnlyRequest {{ {:?}, {:?} }}",
+            self.proof_type, self.prover_args
+        ))
+    }
+}
+
+impl From<(AggregationRequest, Vec<Proof>)> for AggregationOnlyRequest {
+    fn from((request, proofs): (AggregationRequest, Vec<Proof>)) -> Self {
+        Self {
+            proofs,
+            proof_type: request.proof_type,
+            prover_args: request.prover_args,
+        }
+    }
+}
+
+impl AggregationOnlyRequest {
+    /// Merge proof request options into aggregation request options.
+    pub fn merge(&mut self, opts: &ProofRequestOpt) -> RaikoResult<()> {
+        let this = serde_json::to_value(&self)?;
+        let mut opts = serde_json::to_value(opts)?;
+        merge(&mut opts, &this);
+        *self = serde_json::from_value(opts)?;
+        Ok(())
+    }
+}
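Both `merge` methods above round-trip the structs through `serde_json::Value` and overlay the client's request onto the host defaults, so only fields the client actually set win. A minimal sketch of the overlay this relies on, assuming `raiko_core::merge` behaves like a recursive JSON union in which `null` never overwrites an existing value (the exact implementation may differ):

```rust
use serde_json::Value;

/// Recursively overlay `other` onto `value`: objects merge key-by-key,
/// any non-null scalar/array replaces the existing entry, null is skipped.
fn merge(value: &mut Value, other: &Value) {
    match (value, other) {
        (Value::Object(base), Value::Object(overlay)) => {
            for (k, v) in overlay {
                merge(base.entry(k.clone()).or_insert(Value::Null), v);
            }
        }
        (base, overlay) if !overlay.is_null() => *base = overlay.clone(),
        _ => (),
    }
}

fn main() {
    // Host defaults (config file / CLI) and an illustrative client request.
    let mut opts = serde_json::json!({ "proof_type": "native", "graffiti": null });
    let request = serde_json::json!({ "proof_type": null, "graffiti": "0xabc" });
    merge(&mut opts, &request);
    // Fields the client set win; its null fields fall back to the defaults.
    assert_eq!(opts["proof_type"], "native");
    assert_eq!(opts["graffiti"], "0xabc");
}
```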
diff --git a/core/src/lib.rs b/core/src/lib.rs
index 47c8d20c..db240382 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -455,8 +455,8 @@ mod tests {
             .unwrap();
 
         let proof_request = ProofRequest {
-            l1_inclusion_block_number: 0,
             block_number,
+            l1_inclusion_block_number: 0,
            network,
            graffiti: B256::ZERO,
            prover: Address::ZERO,
diff --git a/core/src/preflight/util.rs b/core/src/preflight/util.rs
index fd329f02..889134d9 100644
--- a/core/src/preflight/util.rs
+++ b/core/src/preflight/util.rs
@@ -1,6 +1,6 @@
-use alloy_primitives::{hex, Log, B256};
+use alloy_primitives::{hex, Log as LogStruct, B256};
 use alloy_provider::{Provider, ReqwestProvider};
-use alloy_rpc_types::{Filter, Header, Transaction as AlloyRpcTransaction};
+use alloy_rpc_types::{Filter, Header, Log, Transaction as AlloyRpcTransaction};
 use alloy_sol_types::{SolCall, SolEvent};
 use anyhow::{anyhow, bail, ensure, Result};
 use kzg::kzg_types::ZFr;
@@ -81,58 +81,67 @@ pub async fn prepare_taiko_chain_input(
        .first()
        .ok_or_else(|| RaikoError::Preflight("No anchor tx in the block".to_owned()))?;
 
+    // Get the anchor block number and state root.
     let fork = taiko_chain_spec.active_fork(block.number, block.timestamp)?;
-    info!("current taiko chain fork: {fork:?}");
-
-    let (l1_state_block_number, l1_inclusion_block_number) = match fork {
+    let (anchor_block_height, anchor_state_root) = match fork {
         SpecId::ONTAKE => {
             let anchor_call = decode_anchor_ontake(anchor_tx.input())?;
-            (
-                anchor_call._anchorBlockId,
-                l1_inclusion_block_number.unwrap_or(anchor_call._anchorBlockId + 1),
-            )
+            (anchor_call._anchorBlockId, anchor_call._anchorStateRoot)
         }
         _ => {
             let anchor_call = decode_anchor(anchor_tx.input())?;
-            (
-                anchor_call.l1BlockId,
-                l1_inclusion_block_number.unwrap_or(anchor_call.l1BlockId + 1),
-            )
+            (anchor_call.l1BlockId, anchor_call.l1StateRoot)
         }
     };
-    debug!(
-        "anchor L1 block id: {l1_state_block_number:?}, l1 inclusion block id: {l1_inclusion_block_number:?}"
-    );
 
     // Get the L1 block in which the L2 block was included so we can fetch the DA data.
     // Also get the L1 state block header so that we can prove the L1 state root.
     let provider_l1 = RpcBlockDataProvider::new(&l1_chain_spec.rpc, block_number)?;
+    info!("current taiko chain fork: {fork:?}");
+
+    let (l1_inclusion_block_number, proposal_tx, block_proposed) =
+        if let Some(l1_block_number) = l1_inclusion_block_number {
+            // Get the block proposal data from the given L1 inclusion block.
+            get_block_proposed_event_by_height(
+                provider_l1.provider(),
+                taiko_chain_spec.clone(),
+                l1_block_number,
+                block_number,
+                fork,
+            )
+            .await?
+        } else {
+            // Traverse the next 64 blocks after the anchor to find the proposal data.
+            get_block_proposed_event_by_traversal(
+                provider_l1.provider(),
+                taiko_chain_spec.clone(),
+                anchor_block_height,
+                block_number,
+                fork,
+            )
+            .await?
+        };
+
     let (l1_inclusion_header, l1_state_header) = get_headers(
        &provider_l1,
-        (l1_inclusion_block_number, l1_state_block_number),
+        (l1_inclusion_block_number, anchor_block_height),
     )
     .await?;
-
+    assert_eq!(anchor_state_root, l1_state_header.state_root);
     let l1_state_block_hash = l1_state_header.hash.ok_or_else(|| {
         RaikoError::Preflight("No L1 state block hash for the requested block".to_owned())
     })?;
-
-    debug!("l1_state_root_block hash: {l1_state_block_hash:?}");
-
     let l1_inclusion_block_hash = l1_inclusion_header.hash.ok_or_else(|| {
         RaikoError::Preflight("No L1 inclusion block hash for the requested block".to_owned())
     })?;
-
-    // Get the block proposal data
-    let (proposal_tx, block_proposed) = get_block_proposed_event(
-        provider_l1.provider(),
-        taiko_chain_spec.clone(),
+    info!(
+        "L1 inclusion block number: {:?}, hash: {:?}. L1 state block number: {:?}, hash: {:?}",
+        l1_inclusion_block_number,
         l1_inclusion_block_hash,
-        block_number,
-        fork,
-    )
-    .await?;
+        l1_state_header.number,
+        l1_state_block_hash
+    );
 
     // Fetch the tx data from either calldata or blobdata
     let (tx_data, blob_commitment, blob_proof) = if block_proposed.blob_used() {
@@ -225,31 +234,39 @@ pub async fn get_tx_data(
     Ok((blob, Some(commitment.to_vec()), blob_proof))
 }
 
+pub async fn filter_blockchain_event(
+    provider: &ReqwestProvider,
+    gen_block_event_filter: impl Fn() -> Filter,
+) -> Result<Vec<Log>> {
+    // Setup the filter to get the relevant events
+    let filter = gen_block_event_filter();
+    // Now fetch the events
+    Ok(provider.get_logs(&filter).await?)
+}
+
 pub async fn get_calldata_txlist_event(
     provider: &ReqwestProvider,
     chain_spec: ChainSpec,
     block_hash: B256,
     l2_block_number: u64,
 ) -> Result<(AlloyRpcTransaction, CalldataTxList)> {
     // Get the address that emitted the event
     let Some(l1_address) = chain_spec.l1_contract else {
         bail!("No L1 contract address in the chain spec");
     };
-    // Get the event signature (value can differ between chains)
-    let event_signature = CalldataTxList::SIGNATURE_HASH;
-    // Setup the filter to get the relevant events
-    let filter = Filter::new()
-        .address(l1_address)
-        .at_block_hash(block_hash)
-        .event_signature(event_signature);
-    // Now fetch the events
-    let logs = provider.get_logs(&filter).await?;
+    let logs = filter_blockchain_event(provider, || {
+        Filter::new()
+            .address(l1_address)
+            .at_block_hash(block_hash)
+            .event_signature(CalldataTxList::SIGNATURE_HASH)
+    })
+    .await?;
 
     // Run over the logs returned to find the matching event for the specified L2 block number
     // (there can be multiple blocks proposed in the same block and even same tx)
     for log in logs {
-        let Some(log_struct) = Log::new(
+        let Some(log_struct) = LogStruct::new(
            log.address(),
            log.topics().to_vec(),
            log.data().data.clone(),
@@ -273,13 +290,20 @@ pub async fn get_calldata_txlist_event(
     bail!("No BlockProposedV2 event found for block {l2_block_number}");
 }
 
-pub async fn get_block_proposed_event(
+pub enum EventFilterCondition {
+    #[allow(dead_code)]
+    Hash(B256),
+    Height(u64),
+    Range((u64, u64)),
+}
+
+pub async fn filter_block_proposed_event(
     provider: &ReqwestProvider,
     chain_spec: ChainSpec,
-    block_hash: B256,
+    filter_condition: EventFilterCondition,
     l2_block_number: u64,
     fork: SpecId,
-) -> Result<(AlloyRpcTransaction, BlockProposedFork)> {
+) -> Result<(u64, AlloyRpcTransaction, BlockProposedFork)> {
     // Get the address that emitted the event
     let Some(l1_address) = chain_spec.l1_contract else {
         bail!("No L1 contract address in the chain spec");
     };
@@ -291,24 +315,35 @@ pub async fn get_block_proposed_event(
         _ => BlockProposed::SIGNATURE_HASH,
     };
     // Setup the filter to get the relevant events
-    let filter = Filter::new()
-        .address(l1_address)
-        .at_block_hash(block_hash)
-        .event_signature(event_signature);
-    // Now fetch the events
-    let logs = provider.get_logs(&filter).await?;
+    let logs = filter_blockchain_event(provider, || match filter_condition {
+        EventFilterCondition::Hash(block_hash) => Filter::new()
+            .address(l1_address)
+            .at_block_hash(block_hash)
+            .event_signature(event_signature),
+        EventFilterCondition::Height(block_number) => Filter::new()
+            .address(l1_address)
+            .from_block(block_number)
+            .to_block(block_number + 1)
+            .event_signature(event_signature),
+        EventFilterCondition::Range((from_block_number, to_block_number)) => Filter::new()
+            .address(l1_address)
+            .from_block(from_block_number)
+            .to_block(to_block_number)
+            .event_signature(event_signature),
+    })
+    .await?;
 
     // Run over the logs returned to find the matching event for the specified L2 block number
     // (there can be multiple blocks proposed in the same block and even same tx)
     for log in logs {
-        let Some(log_struct) = Log::new(
+        let Some(log_struct) = LogStruct::new(
            log.address(),
            log.topics().to_vec(),
            log.data().data.clone(),
        ) else {
            bail!("Could not create log")
        };
-        let (block_id, data) = match fork {
+        let (block_id, block_propose_event) = match fork {
             SpecId::ONTAKE => {
                 let event = BlockProposedV2::decode_log(&log_struct, false)
                     .map_err(|_| RaikoError::Anyhow(anyhow!("Could not decode log")))?;
@@ -330,7 +365,7 @@ pub async fn get_block_proposed_event(
                .await
                .expect("couldn't query the propose tx")
                .expect("Could not find the propose tx");
-            return Ok((tx, data));
+            return Ok((log.block_number.unwrap(), tx, block_propose_event));
         }
     }
 
@@ -339,6 +374,57 @@ pub async fn get_block_proposed_event(
     ))
 }
 
+pub async fn _get_block_proposed_event_by_hash(
+    provider: &ReqwestProvider,
+    chain_spec: ChainSpec,
+    l1_inclusion_block_hash: B256,
+    l2_block_number: u64,
+    fork: SpecId,
+) -> Result<(u64, AlloyRpcTransaction, BlockProposedFork)> {
+    filter_block_proposed_event(
+        provider,
+        chain_spec,
+        EventFilterCondition::Hash(l1_inclusion_block_hash),
+        l2_block_number,
+        fork,
+    )
+    .await
+}
+
+pub async fn get_block_proposed_event_by_height(
+    provider: &ReqwestProvider,
+    chain_spec: ChainSpec,
+    l1_inclusion_block_number: u64,
+    l2_block_number: u64,
+    fork: SpecId,
+) -> Result<(u64, AlloyRpcTransaction, BlockProposedFork)> {
+    filter_block_proposed_event(
+        provider,
+        chain_spec,
+        EventFilterCondition::Height(l1_inclusion_block_number),
+        l2_block_number,
+        fork,
+    )
+    .await
+}
+
+pub async fn get_block_proposed_event_by_traversal(
+    provider: &ReqwestProvider,
+    chain_spec: ChainSpec,
+    l1_anchor_block_number: u64,
+    l2_block_number: u64,
+    fork: SpecId,
+) -> Result<(u64, AlloyRpcTransaction, BlockProposedFork)> {
+    filter_block_proposed_event(
+        provider,
+        chain_spec,
+        EventFilterCondition::Range((l1_anchor_block_number + 1, l1_anchor_block_number + 65)),
+        l2_block_number,
+        fork,
+    )
+    .await
+}
+
 pub async fn get_block_and_parent_data<BDP>(
     provider: &BDP,
     block_number: u64,
diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh
index 7ef85385..5fc5d82b 100755
--- a/docker/entrypoint.sh
+++ b/docker/entrypoint.sh
@@ -182,5 +182,5 @@ if [[ -n $ZK ]]; then
    update_raiko_sgx_instance_id $RAIKO_CONF_BASE_CONFIG
    update_docker_chain_specs $RAIKO_CONF_CHAIN_SPECS
 
-    RUST_LOG=debug /opt/raiko/bin/raiko-host "$@"
+    /opt/raiko/bin/raiko-host "$@"
 fi
\ No newline at end of file
diff --git a/host/src/interfaces.rs b/host/src/interfaces.rs
index 728d7710..330446ef 100644
--- a/host/src/interfaces.rs
+++ b/host/src/interfaces.rs
@@ -121,12 +121,12 @@ impl From<HostError> for TaskStatus {
            | HostError::JoinHandle(_)
            | HostError::InvalidAddress(_)
            | HostError::InvalidRequestConfig(_) => unreachable!(),
-            HostError::Conversion(_)
-            | HostError::Serde(_)
-            | HostError::Core(_)
-            | HostError::Anyhow(_)
-            | HostError::FeatureNotSupportedError(_)
-            | HostError::Io(_) => TaskStatus::UnspecifiedFailureReason,
+            HostError::Conversion(e) => TaskStatus::NonDbFailure(e),
+            HostError::Serde(e) => TaskStatus::NonDbFailure(e.to_string()),
+            HostError::Core(e) => TaskStatus::NonDbFailure(e.to_string()),
+            HostError::Anyhow(e) => TaskStatus::NonDbFailure(e.to_string()),
+            HostError::FeatureNotSupportedError(e) => TaskStatus::NonDbFailure(e.to_string()),
+            HostError::Io(e) => TaskStatus::NonDbFailure(e.to_string()),
            HostError::RPC(_) => TaskStatus::NetworkFailure,
            HostError::Guest(_) => TaskStatus::ProofFailure_Generic,
            HostError::TaskManager(_) => TaskStatus::SqlDbCorruption,
@@ -142,12 +142,12 @@ impl From<&HostError> for TaskStatus {
            | HostError::JoinHandle(_)
            | HostError::InvalidAddress(_)
            | HostError::InvalidRequestConfig(_) => unreachable!(),
-            HostError::Conversion(_)
-            | HostError::Serde(_)
-            | HostError::Core(_)
-            | HostError::Anyhow(_)
-            | HostError::FeatureNotSupportedError(_)
-            | HostError::Io(_) => TaskStatus::UnspecifiedFailureReason,
+            HostError::Conversion(e) => TaskStatus::NonDbFailure(e.to_owned()),
+            HostError::Serde(e) => TaskStatus::NonDbFailure(e.to_string()),
+            HostError::Core(e) => TaskStatus::NonDbFailure(e.to_string()),
+            HostError::Anyhow(e) => TaskStatus::NonDbFailure(e.to_string()),
+            HostError::FeatureNotSupportedError(e) => TaskStatus::NonDbFailure(e.to_string()),
+            HostError::Io(e) => TaskStatus::NonDbFailure(e.to_string()),
            HostError::RPC(_) => TaskStatus::NetworkFailure,
            HostError::Guest(_) => TaskStatus::ProofFailure_Generic,
            HostError::TaskManager(_) => TaskStatus::SqlDbCorruption,
diff --git a/host/src/lib.rs b/host/src/lib.rs
index a4df64dc..6927314b 100644
--- a/host/src/lib.rs
+++ b/host/src/lib.rs
@@ -4,7 +4,7 @@ use anyhow::Context;
 use cap::Cap;
 use clap::Parser;
 use raiko_core::{
-    interfaces::{ProofRequest, ProofRequestOpt},
+    interfaces::{AggregationOnlyRequest, ProofRequest, ProofRequestOpt},
     merge,
 };
 use raiko_lib::consts::SupportedChainSpecs;
@@ -152,6 +152,8 @@ pub struct ProverState {
 pub enum Message {
     Cancel(TaskDescriptor),
     Task(ProofRequest),
+    CancelAggregate(AggregationOnlyRequest),
+    Aggregate(AggregationOnlyRequest),
 }
 
 impl From<&ProofRequest> for Message {
@@ -166,6 +168,12 @@ impl From<&TaskDescriptor> for Message {
     }
 }
 
+impl From<AggregationOnlyRequest> for Message {
+    fn from(value: AggregationOnlyRequest) -> Self {
+        Self::Aggregate(value)
+    }
+}
+
 impl ProverState {
     pub fn init() -> HostResult<Self> {
         // Read the command line arguments;
diff --git a/host/src/proof.rs b/host/src/proof.rs
index 3fc65a32..215a5b4f 100644
--- a/host/src/proof.rs
+++ b/host/src/proof.rs
@@ -1,16 +1,19 @@
-use std::{collections::HashMap, sync::Arc};
+use std::{collections::HashMap, str::FromStr, sync::Arc};
 
+use anyhow::anyhow;
 use raiko_core::{
-    interfaces::{ProofRequest, RaikoError},
+    interfaces::{AggregationOnlyRequest, ProofRequest, ProofType, RaikoError},
     provider::{get_task_data, rpc::RpcBlockDataProvider},
     Raiko,
 };
 use raiko_lib::{
     consts::SupportedChainSpecs,
+    input::{AggregationGuestInput, AggregationGuestOutput},
     prover::{IdWrite, Proof},
     Measurement,
 };
 use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManager, TaskManagerWrapper, TaskStatus};
+use reth_primitives::B256;
 use tokio::{
     select,
     sync::{mpsc::Receiver, Mutex, OwnedSemaphorePermit, Semaphore},
@@ -33,6 +36,7 @@ pub struct ProofActor {
     opts: Opts,
     chain_specs: SupportedChainSpecs,
     tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
+    aggregate_tasks: Arc<Mutex<HashMap<AggregationOnlyRequest, CancellationToken>>>,
     receiver: Receiver<Message>,
 }
 
@@ -41,9 +45,14 @@ impl ProofActor {
         let tasks = Arc::new(Mutex::new(
             HashMap::<TaskDescriptor, CancellationToken>::new(),
         ));
+        let aggregate_tasks = Arc::new(Mutex::new(HashMap::<
+            AggregationOnlyRequest,
+            CancellationToken,
+        >::new()));
 
         Self {
             tasks,
+            aggregate_tasks,
            opts,
            chain_specs,
            receiver,
@@ -110,6 +119,74 @@ impl ProofActor {
                    info!("Task cancelled");
                }
                result = Self::handle_message(proof_request, key.clone(), &opts, &chain_specs) => {
+                    match result {
+                        Ok(status) => {
+                            info!("Host handling message: {status:?}");
+                        }
+                        Err(error) => {
+                            error!("Worker failed due to: {error:?}");
+                        }
+                    };
+                }
+            }
+            let mut tasks = tasks.lock().await;
+            tasks.remove(&key);
+        });
+    }
+
+    pub async fn cancel_aggregation_task(
+        &mut self,
+        request: AggregationOnlyRequest,
+    ) -> HostResult<()> {
+        let tasks_map = self.aggregate_tasks.lock().await;
+        let Some(task) = tasks_map.get(&request) else {
+            warn!("No task with those keys to cancel");
+            return Ok(());
+        };
+
+        // TODO:(petar) implement cancel_proof_aggregation
+        // let mut manager = get_task_manager(&self.opts.clone().into());
+        // let proof_type = ProofType::from_str(
+        //     request
+        //         .proof_type
+        //         .as_ref()
+        //         .ok_or_else(|| anyhow!("No proof type"))?,
+        // )?;
+        // proof_type
+        //     .cancel_proof_aggregation(request, Box::new(&mut manager))
+        //     .await
+        //     .or_else(|e| {
+        //         if e.to_string().contains("No data for query") {
+        //             warn!("Task already cancelled or not yet started!");
+        //             Ok(())
+        //         } else {
+        //             Err::<(), HostError>(e.into())
+        //         }
+        //     })?;
+        task.cancel();
+        Ok(())
+    }
+
+    pub async fn run_aggregate(
+        &mut self,
+        request: AggregationOnlyRequest,
+        _permit: OwnedSemaphorePermit,
+    ) {
+        let cancel_token = CancellationToken::new();
+
+        let mut tasks = self.aggregate_tasks.lock().await;
+        tasks.insert(request.clone(), cancel_token.clone());
+
+        let request_clone = request.clone();
+        let tasks = self.aggregate_tasks.clone();
+        let opts = self.opts.clone();
+
+        tokio::spawn(async move {
+            select! {
+                _ = cancel_token.cancelled() => {
+                    info!("Task cancelled");
+                }
+                result = Self::handle_aggregate(request_clone, &opts) => {
                    match result {
                        Ok(()) => {
                            info!("Host handling message");
@@ -121,7 +198,7 @@ impl ProofActor {
                }
            }
            let mut tasks = tasks.lock().await;
-            tasks.remove(&key);
+            tasks.remove(&request);
        });
    }
@@ -142,6 +219,18 @@ impl ProofActor {
                    .expect("Couldn't acquire permit");
                self.run_task(proof_request, permit).await;
            }
+            Message::CancelAggregate(request) => {
+                if let Err(error) = self.cancel_aggregation_task(request).await {
+                    error!("Failed to cancel task: {error}")
+                }
+            }
+            Message::Aggregate(request) => {
+                let permit = Arc::clone(&semaphore)
+                    .acquire_owned()
+                    .await
+                    .expect("Couldn't acquire permit");
+                self.run_aggregate(request, permit).await;
+            }
        }
    }
 
@@ -151,14 +240,14 @@ impl ProofActor {
        key: TaskDescriptor,
        opts: &Opts,
        chain_specs: &SupportedChainSpecs,
-    ) -> HostResult<()> {
+    ) -> HostResult<TaskStatus> {
        let mut manager = get_task_manager(&opts.clone().into());
 
        let status = manager.get_task_proving_status(&key).await?;
 
        if let Some(latest_status) = status.iter().last() {
            if !matches!(latest_status.0, TaskStatus::Registered) {
-                return Ok(());
+                return Ok(latest_status.0.clone());
            }
        }
 
@@ -176,9 +265,57 @@ impl ProofActor {
        };
 
        manager
-            .update_task_progress(key, status, proof.as_deref())
+            .update_task_progress(key, status.clone(), proof.as_deref())
+            .await
+            .map_err(HostError::from)?;
+        Ok(status)
+    }
+
+    pub async fn handle_aggregate(request: AggregationOnlyRequest, opts: &Opts) -> HostResult<()> {
+        let mut manager = get_task_manager(&opts.clone().into());
+
+        let status = manager
+            .get_aggregation_task_proving_status(&request)
+            .await?;
+
+        if let Some(latest_status) = status.iter().last() {
+            if !matches!(latest_status.0, TaskStatus::Registered) {
+                return Ok(());
+            }
+        }
+
+        manager
+            .update_aggregation_task_progress(&request, TaskStatus::WorkInProgress, None)
+            .await?;
+
+        let proof_type = ProofType::from_str(
+            request
+                .proof_type
+                .as_ref()
+                .ok_or_else(|| anyhow!("No proof type"))?,
+        )?;
+        let input = AggregationGuestInput {
+            proofs: request.clone().proofs,
+        };
+        let output = AggregationGuestOutput { hash: B256::ZERO };
+        let config = serde_json::to_value(request.clone().prover_args)?;
+        let mut manager = get_task_manager(&opts.clone().into());
+
+        let (status, proof) = match proof_type
+            .aggregate_proofs(input, &output, &config, Some(&mut manager))
            .await
-            .map_err(|e| e.into())
+        {
+            Err(error) => {
+                error!("{error}");
+                (HostError::from(error).into(), None)
+            }
+            Ok(proof) => (TaskStatus::Success, Some(serde_json::to_vec(&proof)?)),
+        };
+
+        manager
+            .update_aggregation_task_progress(&request, status, proof.as_deref())
+            .await?;
+
+        Ok(())
    }
 }
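Aggregation work is scheduled exactly like single-block proving: an HTTP handler pushes a `Message::Aggregate` into the actor's mpsc channel, and `ProofActor` picks it up under the shared semaphore. A condensed, self-contained sketch of that flow with stand-in types (the real channel lives in `ProverState` and carries `AggregationOnlyRequest` values):

```rust
use tokio::sync::mpsc;

// Stand-ins for the host types; see `impl From<AggregationOnlyRequest> for Message`.
#[derive(Debug, Clone)]
enum Message {
    Aggregate(String),
    CancelAggregate(String),
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<Message>(16);

    // Handler side: enqueue an aggregation task, then cancel it.
    tx.try_send(Message::Aggregate("batch-1".into())).unwrap();
    tx.try_send(Message::CancelAggregate("batch-1".into())).unwrap();

    // ProofActor side: drain the channel and dispatch.
    while let Ok(msg) = rx.try_recv() {
        match msg {
            Message::Aggregate(req) => println!("run_aggregate({req})"),
            Message::CancelAggregate(req) => println!("cancel_aggregation_task({req})"),
        }
    }
}
```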
diff --git a/host/src/server/api/mod.rs b/host/src/server/api/mod.rs
index 4aa8e098..45be92f1 100644
--- a/host/src/server/api/mod.rs
+++ b/host/src/server/api/mod.rs
@@ -18,6 +18,7 @@ use crate::ProverState;
 
 pub mod v1;
 pub mod v2;
+pub mod v3;
 
 pub fn create_router(concurrency_limit: usize, jwt_secret: Option<&str>) -> Router<ProverState> {
     let cors = CorsLayer::new()
@@ -37,11 +38,13 @@ pub fn create_router(concurrency_limit: usize, jwt_secret: Option<&str>) -> Router<ProverState> {
 
     let v1_api = v1::create_router(concurrency_limit);
     let v2_api = v2::create_router();
+    let v3_api = v3::create_router();
 
     let router = Router::new()
        .nest("/v1", v1_api)
-        .nest("/v2", v2_api.clone())
-        .merge(v2_api)
+        .nest("/v2", v2_api)
+        .nest("/v3", v3_api.clone())
+        .merge(v3_api)
        .layer(middleware)
        .layer(middleware::from_fn(check_max_body_size))
        .layer(trace)
@@ -58,7 +61,7 @@ pub fn create_router(concurrency_limit: usize, jwt_secret: Option<&str>) -> Router<ProverState> {
 }
 
 pub fn create_docs() -> utoipa::openapi::OpenApi {
-    v2::create_docs()
+    v3::create_docs()
 }
 
 async fn check_max_body_size(req: Request, next: Next) -> Response {
diff --git a/host/src/server/api/v2/mod.rs b/host/src/server/api/v2/mod.rs
index 7c32b4ff..f4fc046a 100644
--- a/host/src/server/api/v2/mod.rs
+++ b/host/src/server/api/v2/mod.rs
@@ -11,7 +11,7 @@ use crate::{
     ProverState,
 };
 
-mod proof;
+pub mod proof;
 
 #[derive(OpenApi)]
 #[openapi(
@@ -84,8 +84,14 @@ impl From<Vec<u8>> for Status {
 
 impl From<TaskStatus> for Status {
     fn from(status: TaskStatus) -> Self {
-        Self::Ok {
-            data: ProofResponse::Status { status },
+        match status {
+            TaskStatus::Success | TaskStatus::WorkInProgress | TaskStatus::Registered => Self::Ok {
+                data: ProofResponse::Status { status },
+            },
+            _ => Self::Error {
+                error: "task_failed".to_string(),
+                message: format!("Task failed with status: {status:?}"),
+            },
        }
    }
 }
diff --git a/host/src/server/api/v2/proof/mod.rs b/host/src/server/api/v2/proof/mod.rs
index ce089375..d57335cd 100644
--- a/host/src/server/api/v2/proof/mod.rs
+++ b/host/src/server/api/v2/proof/mod.rs
@@ -11,10 +11,10 @@ use crate::{
     Message, ProverState,
 };
 
-mod cancel;
-mod list;
-mod prune;
-mod report;
+pub mod cancel;
+pub mod list;
+pub mod prune;
+pub mod report;
 
 #[utoipa::path(post, path = "/proof",
     tag = "Proving",
@@ -98,7 +98,7 @@ async fn proof_handler(
            Ok(proof.into())
        }
        // For all other statuses just return the status.
-        status => Ok((*status).into()),
+        status => Ok(status.clone().into()),
    }
 }
diff --git a/host/src/server/api/v3/mod.rs b/host/src/server/api/v3/mod.rs
new file mode 100644
index 00000000..faf46b61
--- /dev/null
+++ b/host/src/server/api/v3/mod.rs
@@ -0,0 +1,172 @@
+use axum::{response::IntoResponse, Json, Router};
+use raiko_lib::prover::Proof;
+use raiko_tasks::TaskStatus;
+use serde::{Deserialize, Serialize};
+use utoipa::{OpenApi, ToSchema};
+use utoipa_scalar::{Scalar, Servable};
+use utoipa_swagger_ui::SwaggerUi;
+
+use crate::{
+    server::api::v1::{self, GuestOutputDoc},
+    ProverState,
+};
+
+mod proof;
+
+#[derive(OpenApi)]
+#[openapi(
+    info(
+        title = "Raiko Proverd Server API",
+        version = "3.0",
+        description = "Raiko Proverd Server API",
+        contact(
+            name = "API Support",
+            url = "https://community.taiko.xyz",
+            email = "info@taiko.xyz",
+        ),
+        license(
+            name = "MIT",
+            url = "https://github.com/taikoxyz/raiko/blob/main/LICENSE"
+        ),
+    ),
+    components(
+        schemas(
+            raiko_core::interfaces::ProofRequestOpt,
+            raiko_core::interfaces::ProverSpecificOpts,
+            crate::interfaces::HostError,
+            GuestOutputDoc,
+            ProofResponse,
+            TaskStatus,
+            CancelStatus,
+            PruneStatus,
+            Proof,
+            Status,
+        )
+    ),
+    tags(
+        (name = "Proving", description = "Routes that handle proving requests"),
+        (name = "Health", description = "Routes that report the server health status"),
+        (name = "Metrics", description = "Routes that give detailed insight into the server")
+    )
+)]
+/// The root API struct which is generated from the `OpenApi` derive macro.
+pub struct Docs;
+
+#[derive(Debug, Deserialize, Serialize, ToSchema)]
+#[serde(untagged)]
+pub enum ProofResponse {
+    Status {
+        /// The status of the submitted task.
+        status: TaskStatus,
+    },
+    Proof {
+        /// The proof.
+        proof: Proof,
+    },
+}
+
+#[derive(Debug, Deserialize, Serialize, ToSchema)]
+#[serde(tag = "status", rename_all = "lowercase")]
+pub enum Status {
+    Ok { data: ProofResponse },
+    Error { error: String, message: String },
+}
+
+impl From<Vec<u8>> for Status {
+    fn from(proof: Vec<u8>) -> Self {
+        Self::Ok {
+            data: ProofResponse::Proof {
+                proof: serde_json::from_slice(&proof).unwrap_or_default(),
+            },
+        }
+    }
+}
+
+impl From<Proof> for Status {
+    fn from(proof: Proof) -> Self {
+        Self::Ok {
+            data: ProofResponse::Proof { proof },
+        }
+    }
+}
+
+impl From<TaskStatus> for Status {
+    fn from(status: TaskStatus) -> Self {
+        match status {
+            TaskStatus::Success | TaskStatus::WorkInProgress | TaskStatus::Registered => Self::Ok {
+                data: ProofResponse::Status { status },
+            },
+            _ => Self::Error {
+                error: "task_failed".to_string(),
+                message: format!("Task failed with status: {status:?}"),
+            },
+        }
+    }
+}
+
+impl IntoResponse for Status {
+    fn into_response(self) -> axum::response::Response {
+        Json(serde_json::to_value(self).unwrap()).into_response()
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize, ToSchema)]
+#[serde(tag = "status", rename_all = "lowercase")]
+/// Status of cancellation request.
+/// Can be `ok` for a successful cancellation or `error` with message and error type for errors.
+pub enum CancelStatus {
+    /// Cancellation was successful.
+    Ok,
+    /// Cancellation failed.
+    Error { error: String, message: String },
+}
+
+impl IntoResponse for CancelStatus {
+    fn into_response(self) -> axum::response::Response {
+        Json(serde_json::to_value(self).unwrap()).into_response()
+    }
+}
+
+#[derive(Debug, Serialize, ToSchema, Deserialize)]
+#[serde(tag = "status", rename_all = "lowercase")]
+/// Status of prune request.
+/// Can be `ok` for a successful prune or `error` with message and error type for errors.
+pub enum PruneStatus {
+    /// Prune was successful.
+    Ok,
+    /// Prune failed.
+    Error { error: String, message: String },
+}
+
+impl IntoResponse for PruneStatus {
+    fn into_response(self) -> axum::response::Response {
+        Json(serde_json::to_value(self).unwrap()).into_response()
+    }
+}
+
+#[must_use]
+pub fn create_docs() -> utoipa::openapi::OpenApi {
+    [
+        v1::health::create_docs(),
+        v1::metrics::create_docs(),
+        proof::create_docs(),
+    ]
+    .into_iter()
+    .fold(Docs::openapi(), |mut doc, sub_doc| {
+        doc.merge(sub_doc);
+        doc
+    })
+}
+
+pub fn create_router() -> Router<ProverState> {
+    let docs = create_docs();
+
+    Router::new()
+        // Only add the concurrency limit to the proof route. We want to still be able to call
+        // healthchecks and metrics to have insight into the system.
+        .nest("/proof", proof::create_router())
+        .nest("/health", v1::health::create_router())
+        .nest("/metrics", v1::metrics::create_router())
+        .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", docs.clone()))
+        .merge(Scalar::with_url("/scalar", docs))
+}
diff --git a/host/src/server/api/v3/proof/aggregate.rs b/host/src/server/api/v3/proof/aggregate.rs
new file mode 100644
index 00000000..3bbffa00
--- /dev/null
+++ b/host/src/server/api/v3/proof/aggregate.rs
@@ -0,0 +1,114 @@
+use std::str::FromStr;
+
+use axum::{debug_handler, extract::State, routing::post, Json, Router};
+use raiko_core::interfaces::{AggregationOnlyRequest, ProofType};
+use raiko_tasks::{TaskManager, TaskStatus};
+use utoipa::OpenApi;
+
+use crate::{
+    interfaces::HostResult,
+    metrics::{inc_current_req, inc_guest_req_count, inc_host_req_count},
+    server::api::v3::Status,
+    Message, ProverState,
+};
+
+#[utoipa::path(post, path = "/proof/aggregate",
+    tag = "Proving",
+    request_body = AggregationOnlyRequest,
+    responses (
+        (status = 200, description = "Successfully submitted proof aggregation task, queried aggregation tasks in progress or retrieved aggregated proof.", body = Status)
+    )
+)]
+#[debug_handler(state = ProverState)]
+/// Submit a proof aggregation task with requested config, get task status or get proof value.
+///
+/// Accepts a proof request and creates a proving task with the specified guest prover.
+/// The guest provers currently available are:
+/// - native - constructs a block and checks for equality
+/// - sgx - uses the sgx environment to construct a block and produce proof of execution
+/// - sp1 - uses the sp1 prover
+/// - risc0 - uses the risc0 prover
+async fn aggregation_handler(
+    State(prover_state): State<ProverState>,
+    Json(mut aggregation_request): Json<AggregationOnlyRequest>,
+) -> HostResult<Status> {
+    inc_current_req();
+    // Override the existing proof request config from the config file and command line
+    // options with the request from the client.
+    aggregation_request.merge(&prover_state.request_config())?;
+
+    let proof_type = ProofType::from_str(
+        aggregation_request
+            .proof_type
+            .as_deref()
+            .unwrap_or_default(),
+    )?;
+    inc_host_req_count(0);
+    inc_guest_req_count(&proof_type, 0);
+
+    if aggregation_request.proofs.is_empty() {
+        return Err(anyhow::anyhow!("No proofs provided").into());
+    }
+
+    let mut manager = prover_state.task_manager();
+
+    let status = manager
+        .get_aggregation_task_proving_status(&aggregation_request)
+        .await?;
+
+    let Some((latest_status, ..)) = status.last() else {
+        // If there are no tasks with provided config, create a new one.
+        manager
+            .enqueue_aggregation_task(&aggregation_request)
+            .await?;
+
+        prover_state
+            .task_channel
+            .try_send(Message::from(aggregation_request.clone()))?;
+        return Ok(Status::from(TaskStatus::Registered));
+    };
+
+    match latest_status {
+        // If task has been cancelled add it to the queue again
+        TaskStatus::Cancelled
+        | TaskStatus::Cancelled_Aborted
+        | TaskStatus::Cancelled_NeverStarted
+        | TaskStatus::CancellationInProgress => {
+            manager
+                .update_aggregation_task_progress(
+                    &aggregation_request,
+                    TaskStatus::Registered,
+                    None,
+                )
+                .await?;
+
+            prover_state
+                .task_channel
+                .try_send(Message::from(aggregation_request))?;
+
+            Ok(Status::from(TaskStatus::Registered))
+        }
+        // If the task has succeeded, return the proof.
+        TaskStatus::Success => {
+            let proof = manager
+                .get_aggregation_task_proof(&aggregation_request)
+                .await?;
+
+            Ok(proof.into())
+        }
+        // For all other statuses just return the status.
+        status => Ok(status.clone().into()),
+    }
+}
+
+#[derive(OpenApi)]
+#[openapi(paths(aggregation_handler))]
+struct Docs;
+
+pub fn create_docs() -> utoipa::openapi::OpenApi {
+    Docs::openapi()
+}
+
+pub fn create_router() -> Router<ProverState> {
+    Router::new().route("/", post(aggregation_handler))
+}
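A request to `/v3/proof/aggregate` therefore carries already-generated proofs plus the prover selection; everything else is back-filled by `merge`. A hypothetical client-side sketch of the JSON body the handler expects (field names follow the code above; the hex values and the `risc0` options are placeholders):

```rust
// Illustrative only: build an AggregationOnlyRequest body.
fn main() {
    let body = serde_json::json!({
        "proofs": [
            { "proof": "0x…", "input": "0x…", "quote": null },
        ],
        "proof_type": "risc0",
        // ProverSpecificOpts is #[serde(flatten)]ed into the top level:
        "risc0": { "bonsai": true, "snark": true }
    });
    println!("POST /v3/proof/aggregate\n{}",
        serde_json::to_string_pretty(&body).unwrap());
}
```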
diff --git a/host/src/server/api/v3/proof/cancel.rs b/host/src/server/api/v3/proof/cancel.rs
new file mode 100644
index 00000000..6e721c71
--- /dev/null
+++ b/host/src/server/api/v3/proof/cancel.rs
@@ -0,0 +1,76 @@
+use axum::{debug_handler, extract::State, routing::post, Json, Router};
+use raiko_core::{
+    interfaces::{AggregationRequest, ProofRequest, ProofRequestOpt},
+    provider::get_task_data,
+};
+use raiko_tasks::{TaskDescriptor, TaskManager, TaskStatus};
+use utoipa::OpenApi;
+
+use crate::{interfaces::HostResult, server::api::v2::CancelStatus, Message, ProverState};
+
+#[utoipa::path(post, path = "/proof/cancel",
+    tag = "Proving",
+    request_body = AggregationRequest,
+    responses (
+        (status = 200, description = "Successfully cancelled proof task", body = CancelStatus)
+    )
+)]
+#[debug_handler(state = ProverState)]
+/// Cancel a proof task with requested config.
+///
+/// Accepts a proof request and cancels a proving task with the specified guest prover.
+/// The guest provers currently available are:
+/// - native - constructs a block and checks for equality
+/// - sgx - uses the sgx environment to construct a block and produce proof of execution
+/// - sp1 - uses the sp1 prover
+/// - risc0 - uses the risc0 prover
+async fn cancel_handler(
+    State(prover_state): State<ProverState>,
+    Json(mut aggregation_request): Json<AggregationRequest>,
+) -> HostResult<CancelStatus> {
+    // Override the existing proof request config from the config file and command line
+    // options with the request from the client.
+    aggregation_request.merge(&prover_state.request_config())?;
+
+    let proof_request_opts: Vec<ProofRequestOpt> = aggregation_request.into();
+
+    for opt in proof_request_opts {
+        let proof_request = ProofRequest::try_from(opt)?;
+
+        let (chain_id, block_hash) = get_task_data(
+            &proof_request.network,
+            proof_request.block_number,
+            &prover_state.chain_specs,
+        )
+        .await?;
+
+        let key = TaskDescriptor::from((
+            chain_id,
+            block_hash,
+            proof_request.proof_type,
+            proof_request.prover.to_string(),
+        ));
+
+        prover_state.task_channel.try_send(Message::from(&key))?;
+
+        let mut manager = prover_state.task_manager();
+
+        manager
+            .update_task_progress(key, TaskStatus::Cancelled, None)
+            .await?;
+    }
+
+    Ok(CancelStatus::Ok)
+}
+
+#[derive(OpenApi)]
+#[openapi(paths(cancel_handler))]
+struct Docs;
+
+pub fn create_docs() -> utoipa::openapi::OpenApi {
+    Docs::openapi()
+}
+
+pub fn create_router() -> Router<ProverState> {
+    Router::new().route("/", post(cancel_handler))
+}
diff --git a/host/src/server/api/v3/proof/mod.rs b/host/src/server/api/v3/proof/mod.rs
new file mode 100644
index 00000000..b8f5e35d
--- /dev/null
+++ b/host/src/server/api/v3/proof/mod.rs
@@ -0,0 +1,219 @@
+use axum::{debug_handler, extract::State, routing::post, Json, Router};
+use raiko_core::{
+    interfaces::{AggregationOnlyRequest, AggregationRequest, ProofRequest, ProofRequestOpt},
+    provider::get_task_data,
+};
+use raiko_tasks::{TaskDescriptor, TaskManager, TaskStatus};
+use utoipa::OpenApi;
+
+use crate::{
+    interfaces::HostResult,
+    metrics::{inc_current_req, inc_guest_req_count, inc_host_req_count},
+    server::api::{v2, v3::Status},
+    Message, ProverState,
+};
+use tracing::{debug, info};
+
+mod aggregate;
+mod cancel;
+
+#[utoipa::path(post, path = "/proof",
+    tag = "Proving",
+    request_body = AggregationRequest,
+    responses (
+        (status = 200, description = "Successfully submitted proof task, queried tasks in progress or retrieved proof.", body = Status)
+    )
+)]
+#[debug_handler(state = ProverState)]
+/// Submit a proof aggregation task with requested config, get task status or get proof value.
+///
+/// Accepts a proof request and creates a proving task with the specified guest prover.
+/// The guest provers currently available are:
+/// - native - constructs a block and checks for equality
+/// - sgx - uses the sgx environment to construct a block and produce proof of execution
+/// - sp1 - uses the sp1 prover
+/// - risc0 - uses the risc0 prover
+async fn proof_handler(
+    State(prover_state): State<ProverState>,
+    Json(mut aggregation_request): Json<AggregationRequest>,
+) -> HostResult<Status> {
+    inc_current_req();
+    // Override the existing proof request config from the config file and command line
+    // options with the request from the client.
+    aggregation_request.merge(&prover_state.request_config())?;
+
+    let mut tasks = Vec::with_capacity(aggregation_request.block_numbers.len());
+
+    let proof_request_opts: Vec<ProofRequestOpt> = aggregation_request.clone().into();
+
+    if proof_request_opts.is_empty() {
+        return Err(anyhow::anyhow!("No blocks for proving provided").into());
+    }
+
+    // Construct the actual proof request from the available configs.
+    for proof_request_opt in proof_request_opts {
+        let proof_request = ProofRequest::try_from(proof_request_opt)?;
+
+        inc_host_req_count(proof_request.block_number);
+        inc_guest_req_count(&proof_request.proof_type, proof_request.block_number);
+
+        let (chain_id, blockhash) = get_task_data(
+            &proof_request.network,
+            proof_request.block_number,
+            &prover_state.chain_specs,
+        )
+        .await?;
+
+        let key = TaskDescriptor::from((
+            chain_id,
+            blockhash,
+            proof_request.proof_type,
+            proof_request.prover.to_string(),
+        ));
+
+        tasks.push((key, proof_request));
+    }
+
+    let mut manager = prover_state.task_manager();
+
+    let mut is_registered = false;
+    let mut is_success = true;
+    let mut statuses = Vec::with_capacity(tasks.len());
+
+    for (key, req) in tasks.iter() {
+        let status = manager.get_task_proving_status(key).await?;
+
+        let Some((latest_status, ..)) = status.last() else {
+            // If there are no tasks with provided config, create a new one.
+            manager.enqueue_task(key).await?;
+
+            prover_state.task_channel.try_send(Message::from(req))?;
+            is_registered = true;
+            continue;
+        };
+
+        match latest_status {
+            // If task has been cancelled add it to the queue again
+            TaskStatus::Cancelled
+            | TaskStatus::Cancelled_Aborted
+            | TaskStatus::Cancelled_NeverStarted
+            | TaskStatus::CancellationInProgress => {
+                manager
+                    .update_task_progress(key.clone(), TaskStatus::Registered, None)
+                    .await?;
+
+                prover_state.task_channel.try_send(Message::from(req))?;
+
+                is_registered = true;
+                is_success = false;
+            }
+            // If the task has succeeded, return the proof.
+            TaskStatus::Success => {}
+            // For all other statuses just return the status.
+            status => {
+                statuses.push(status.clone());
+                is_registered = false;
+                is_success = false;
+            }
+        }
+    }
+
+    if is_registered {
+        Ok(TaskStatus::Registered.into())
+    } else if is_success {
+        info!("All tasks are successful, aggregating proofs");
+        let mut proofs = Vec::with_capacity(tasks.len());
+        for (task, req) in tasks {
+            let raw_proof = manager.get_task_proof(&task).await?;
+            let proof = serde_json::from_slice(&raw_proof)?;
+            debug!("req: {:?} gets proof: {:?}", req, proof);
+            proofs.push(proof);
+        }
+
+        let aggregation_request = AggregationOnlyRequest {
+            proofs,
+            proof_type: aggregation_request.proof_type,
+            prover_args: aggregation_request.prover_args,
+        };
+
+        let status = manager
+            .get_aggregation_task_proving_status(&aggregation_request)
+            .await?;
+
+        let Some((latest_status, ..)) = status.last() else {
+            // If there are no tasks with provided config, create a new one.
+            manager
+                .enqueue_aggregation_task(&aggregation_request)
+                .await?;
+
+            prover_state
+                .task_channel
+                .try_send(Message::from(aggregation_request.clone()))?;
+            return Ok(Status::from(TaskStatus::Registered));
+        };
+
+        match latest_status {
+            // If task has been cancelled add it to the queue again
+            TaskStatus::Cancelled
+            | TaskStatus::Cancelled_Aborted
+            | TaskStatus::Cancelled_NeverStarted
+            | TaskStatus::CancellationInProgress => {
+                manager
+                    .update_aggregation_task_progress(
+                        &aggregation_request,
+                        TaskStatus::Registered,
+                        None,
+                    )
+                    .await?;
+
+                prover_state
+                    .task_channel
+                    .try_send(Message::from(aggregation_request))?;
+
+                Ok(Status::from(TaskStatus::Registered))
+            }
+            // If the task has succeeded, return the proof.
+            TaskStatus::Success => {
+                let proof = manager
+                    .get_aggregation_task_proof(&aggregation_request)
+                    .await?;
+
+                Ok(proof.into())
+            }
+            // For all other statuses just return the status.
+            status => Ok(status.clone().into()),
+        }
+    } else {
+        let status = statuses.into_iter().collect::<TaskStatus>();
+        Ok(status.into())
+    }
+}
+
+#[derive(OpenApi)]
+#[openapi(paths(proof_handler))]
+struct Docs;
+
+pub fn create_docs() -> utoipa::openapi::OpenApi {
+    [
+        cancel::create_docs(),
+        aggregate::create_docs(),
+        v2::proof::report::create_docs(),
+        v2::proof::list::create_docs(),
+        v2::proof::prune::create_docs(),
+    ]
+    .into_iter()
+    .fold(Docs::openapi(), |mut docs, curr| {
+        docs.merge(curr);
+        docs
+    })
+}
+
+pub fn create_router() -> Router<ProverState> {
+    Router::new()
+        .route("/", post(proof_handler))
+        .nest("/cancel", cancel::create_router())
+        .nest("/aggregate", aggregate::create_router())
+        .nest("/report", v2::proof::report::create_router())
+        .nest("/list", v2::proof::list::create_router())
+        .nest("/prune", v2::proof::prune::create_router())
+}
diff --git a/lib/src/input.rs b/lib/src/input.rs
index 44a18ef6..bb9c9ed9 100644
--- a/lib/src/input.rs
+++ b/lib/src/input.rs
@@ -73,7 +73,7 @@ pub struct AggregationGuestOutput {
     pub hash: B256,
 }
 
-#[derive(Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ZkAggregationGuestInput {
     pub image_id: [u32; 8],
     pub block_inputs: Vec<B256>,
diff --git a/lib/src/protocol_instance.rs b/lib/src/protocol_instance.rs
index 786181cc..3f6271ef 100644
--- a/lib/src/protocol_instance.rs
+++ b/lib/src/protocol_instance.rs
@@ -18,7 +18,7 @@ use crate::{
     },
     CycleTracker,
 };
-use log::info;
+use log::{debug, info};
 use reth_evm_ethereum::taiko::ANCHOR_GAS_LIMIT;
 
 #[derive(Debug, Clone)]
@@ -275,6 +275,18 @@ impl ProtocolInstance {
     pub fn instance_hash(&self) -> B256 {
         // packages/protocol/contracts/verifiers/libs/LibPublicInput.sol
         // "VERIFY_PROOF", _chainId, _verifierContract, _tran, _newInstance, _prover, _metaHash
+        debug!(
+            "calculate instance_hash from:
+            chain_id: {:?}, verifier: {:?}, transition: {:?}, sgx_instance: {:?},
+            prover: {:?}, block_meta: {:?}, meta_hash: {:?}",
+            self.chain_id,
+            self.verifier_address,
+            self.transition.clone(),
+            self.sgx_instance,
+            self.prover,
+            self.block_metadata,
+            self.meta_hash(),
+        );
        let data = (
            "VERIFY_PROOF",
            self.chain_id,
@@ -324,6 +336,15 @@ pub fn words_to_bytes_le(words: &[u32; 8]) -> [u8; 32] {
     bytes
 }
 
+pub fn words_to_bytes_be(words: &[u32; 8]) -> [u8; 32] {
+    let mut bytes = [0u8; 32];
+    for i in 0..8 {
+        let word_bytes = words[i].to_be_bytes();
+        bytes[i * 4..(i + 1) * 4].copy_from_slice(&word_bytes);
+    }
+    bytes
+}
+
 pub fn aggregation_output_combine(public_inputs: Vec<B256>) -> Vec<u8> {
     let mut output = Vec::with_capacity(public_inputs.len() * 32);
     for public_input in public_inputs.iter() {
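`words_to_bytes_be` mirrors the existing `words_to_bytes_le`, differing only in per-word byte order; the eight words are laid out in index order in both cases. A quick worked example (the `_le` behavior is inferred from its name and the symmetric implementation):

```rust
fn words_to_bytes_be(words: &[u32; 8]) -> [u8; 32] {
    let mut bytes = [0u8; 32];
    for i in 0..8 {
        bytes[i * 4..(i + 1) * 4].copy_from_slice(&words[i].to_be_bytes());
    }
    bytes
}

fn main() {
    let mut words = [0u32; 8];
    words[0] = 0x1122_3344;
    // Big-endian: the most significant byte of each word comes first.
    assert_eq!(&words_to_bytes_be(&words)[..4], &[0x11, 0x22, 0x33, 0x44]);
    // The little-endian variant would yield [0x44, 0x33, 0x22, 0x11] here.
}
```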
diff --git a/lib/src/prover.rs b/lib/src/prover.rs
index 5a1c7669..08de0229 100644
--- a/lib/src/prover.rs
+++ b/lib/src/prover.rs
@@ -26,7 +26,7 @@ pub type ProverResult<T> = core::result::Result<T, ProverError>;
 pub type ProverConfig = serde_json::Value;
 pub type ProofKey = (ChainId, B256, u8);
 
-#[derive(Clone, Debug, Serialize, ToSchema, Deserialize, Default)]
+#[derive(Clone, Debug, Serialize, ToSchema, Deserialize, Default, PartialEq, Eq, Hash)]
 /// The response body of a proof request.
 pub struct Proof {
     /// The proof either TEE or ZK.
diff --git a/pipeline/src/executor.rs b/pipeline/src/executor.rs
index a46128a0..a5bb0a74 100644
--- a/pipeline/src/executor.rs
+++ b/pipeline/src/executor.rs
@@ -100,7 +100,8 @@ impl Executor {
             let elf = std::fs::read(&dest.join(&name.replace('_', "-")))?;
             let prover = CpuProver::new();
             let key_pair = prover.setup(&elf);
-            println!("sp1 elf vk is: {}", key_pair.1.bytes32());
+            println!("sp1 elf vk bn256 is: {}", key_pair.1.bytes32());
+            println!("sp1 elf vk hash_bytes is: {}", hex::encode(&key_pair.1.hash_bytes()));
         }
 
         Ok(())
diff --git a/provers/risc0/driver/Cargo.toml b/provers/risc0/driver/Cargo.toml
index a1f5e11e..3274acce 100644
--- a/provers/risc0/driver/Cargo.toml
+++ b/provers/risc0/driver/Cargo.toml
@@ -63,9 +63,9 @@ enable = [
     "serde_json",
     "hex",
     "reqwest",
-    "lazy_static"
+    "lazy_static",
 ]
 cuda = ["risc0-zkvm?/cuda"]
 metal = ["risc0-zkvm?/metal"]
 bench = []
-bonsai-auto-scaling = []
\ No newline at end of file
+bonsai-auto-scaling = []
diff --git a/provers/risc0/driver/src/bonsai.rs b/provers/risc0/driver/src/bonsai.rs
index 40ee4f32..ba53f22c 100644
--- a/provers/risc0/driver/src/bonsai.rs
+++ b/provers/risc0/driver/src/bonsai.rs
@@ -1,6 +1,6 @@
 use crate::{
     methods::risc0_guest::RISC0_GUEST_ID,
-    snarks::{stark2snark, verify_groth16_snark},
+    snarks::{stark2snark, verify_groth16_from_snark_receipt},
     Risc0Response,
 };
 use alloy_primitives::B256;
@@ -198,8 +198,8 @@ pub async fn maybe_prove<I: Serialize, O: Eq + Debug + DeserializeOwned>(
diff --git a/provers/risc0/driver/src/lib.rs b/provers/risc0/driver/src/lib.rs
--- a/provers/risc0/driver/src/lib.rs
+++ b/provers/risc0/driver/src/lib.rs
@@ … @@ impl From<Risc0Response> for Proof {
     fn from(value: Risc0Response) -> Self {
         Self {
             proof: Some(value.proof),
-            quote: None,
+            quote: Some(value.receipt),
             input: Some(value.input),
             uuid: Some(value.uuid),
             kzg_proof: None,
@@ -128,13 +131,15 @@ impl Prover for Risc0Prover {
 
     async fn aggregate(
         input: AggregationGuestInput,
-        output: &AggregationGuestOutput,
+        _output: &AggregationGuestOutput,
         config: &ProverConfig,
-        id_store: Option<&mut dyn IdWrite>,
+        _id_store: Option<&mut dyn IdWrite>,
     ) -> ProverResult<Proof> {
-        let mut id_store = id_store;
         let config = Risc0Param::deserialize(config.get("risc0").unwrap()).unwrap();
-        let proof_key = (0, output.hash.clone(), RISC0_PROVER_CODE);
+        assert!(
+            config.snark && config.bonsai,
+            "Aggregation must be in bonsai snark mode"
+        );
 
         // Extract the block proof receipts
         let assumptions: Vec<Receipt> = input
@@ -151,59 +156,45 @@ impl Prover for Risc0Prover {
            .iter()
            .map(|proof| proof.input.unwrap())
            .collect::<Vec<_>>();
-        // For bonsai
-        let assumptions_uuids: Vec<String> = input
-            .proofs
-            .iter()
-            .map(|proof| proof.uuid.clone().unwrap())
-            .collect::<Vec<_>>();
-
         let input = ZkAggregationGuestInput {
            image_id: RISC0_GUEST_ID,
            block_inputs,
        };
-
-        debug!("elf code length: {}", RISC0_AGGREGATION_ELF.len());
-        let encoded_input = to_vec(&input).expect("Could not serialize proving input!");
-
-        let result = maybe_prove::<ZkAggregationGuestInput, B256>(
-            &config,
-            encoded_input,
-            RISC0_AGGREGATION_ELF,
-            &output.hash,
-            (assumptions, assumptions_uuids),
-            proof_key,
-            &mut id_store,
-        )
-        .await;
-
-        let receipt = result.clone().unwrap().1.clone();
-        let uuid = result.clone().unwrap().0;
-
-        let proof_gen_result = if result.is_some() {
-            if config.snark && config.bonsai {
-                let (stark_uuid, stark_receipt) = result.clone().unwrap();
-                bonsai::bonsai_stark_to_snark(stark_uuid, stark_receipt, output.hash)
-                    .await
-                    .map(|r0_response| r0_response.into())
-                    .map_err(|e| ProverError::GuestError(e.to_string()))
-            } else {
-                warn!("proof is not in snark mode, please check.");
-                let (_, stark_receipt) = result.clone().unwrap();
-                Ok(Risc0Response {
-                    proof: stark_receipt.journal.encode_hex_with_prefix(),
-                    receipt: serde_json::to_string(&receipt).unwrap(),
-                    uuid,
-                    input: output.hash,
-                }
-                .into())
-            }
-        } else {
-            Err(ProverError::GuestError(
-                "Failed to generate proof".to_string(),
-            ))
+        info!("Start aggregating proofs");
+        // add_assumption makes the receipt to be verified available to the prover.
+        let env = {
+            let mut env = ExecutorEnv::builder();
+            for assumption in assumptions {
+                env.add_assumption(assumption);
+            }
+            env.write(&input).unwrap().build().unwrap()
        };
+        let opts = ProverOpts::groth16();
+        let receipt = default_prover()
+            .prove_with_opts(env, RISC0_AGGREGATION_ELF, &opts)
+            .unwrap()
+            .receipt;
+
+        info!(
+            "Generated aggregation receipt journal: {:?}",
+            receipt.journal
+        );
+        let aggregation_image_id = compute_image_id(RISC0_AGGREGATION_ELF).unwrap();
+        let enc_proof =
+            snarks::verify_groth16_snark_from_receipt(aggregation_image_id, receipt.clone())
+                .await
+                .map_err(|err| format!("Failed to verify SNARK: {err:?}"))?;
+        let snark_proof = format!("0x{}", hex::encode(enc_proof));
+
+        let proof_gen_result = Ok(Risc0Response {
+            proof: snark_proof,
+            receipt: serde_json::to_string(&receipt).unwrap(),
+            uuid: "".to_owned(),
+            input: B256::from_slice(&receipt.journal.digest().as_bytes()),
+        }
+        .into());
+
        #[cfg(feature = "bonsai-auto-scaling")]
        if config.bonsai {
            // shutdown bonsai
diff --git a/provers/risc0/driver/src/methods/risc0_aggregation.rs b/provers/risc0/driver/src/methods/risc0_aggregation.rs
index f3b1fe64..06ad39e2 100644
--- a/provers/risc0/driver/src/methods/risc0_aggregation.rs
+++ b/provers/risc0/driver/src/methods/risc0_aggregation.rs
@@ -1,5 +1,5 @@
 pub const RISC0_AGGREGATION_ELF: &[u8] =
     include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-aggregation");
 pub const RISC0_AGGREGATION_ID: [u32; 8] = [
-    834745027, 3860709824, 1052791454, 925104520, 3609882255, 551703375, 2495735124, 1897996989,
+    440526723, 3767976668, 67051936, 881100330, 2605787818, 1152192925, 943988177, 1141581874,
 ];
diff --git a/provers/risc0/driver/src/methods/risc0_guest.rs b/provers/risc0/driver/src/methods/risc0_guest.rs
index 19d5fdfd..15915265 100644
--- a/provers/risc0/driver/src/methods/risc0_guest.rs
+++ b/provers/risc0/driver/src/methods/risc0_guest.rs
@@ -1,5 +1,5 @@
 pub const RISC0_GUEST_ELF: &[u8] =
     include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-guest");
 pub const RISC0_GUEST_ID: [u32; 8] = [
-    2724640415, 1388818056, 2370444677, 1329173777, 2657825669, 1524407056, 1629931902, 314750851,
+    2426111784, 2252773481, 4093155148, 2853313326, 836865213, 1159934005, 790932950, 229907112,
 ];
diff --git a/provers/risc0/driver/src/snarks.rs b/provers/risc0/driver/src/snarks.rs
index 056a1e8c..e3e597a8 100644
--- a/provers/risc0/driver/src/snarks.rs
+++ b/provers/risc0/driver/src/snarks.rs
@@ -150,9 +150,31 @@ pub async fn stark2snark(
     Ok(snark_data)
 }
 
-pub async fn verify_groth16_snark(
+pub async fn verify_groth16_from_snark_receipt(
     image_id: Digest,
     snark_receipt: SnarkReceipt,
+) -> Result<Vec<u8>> {
+    let seal = encode(snark_receipt.snark.to_vec())?;
+    let journal_digest = snark_receipt.journal.digest();
+    let post_state_digest = snark_receipt.post_state_digest.digest();
+    verify_groth16_snark_impl(image_id, seal, journal_digest, post_state_digest).await
+}
+
+pub async fn verify_groth16_snark_from_receipt(
+    image_id: Digest,
+    receipt: Receipt,
+) -> Result<Vec<u8>> {
+    let seal = receipt.inner.groth16().unwrap().seal.clone();
+    let journal_digest = receipt.journal.digest();
+    let post_state_digest = receipt.claim()?.as_value().unwrap().post.digest();
+    verify_groth16_snark_impl(image_id, seal, journal_digest, post_state_digest).await
+}
+
+pub async fn verify_groth16_snark_impl(
+    image_id: Digest,
+    seal: Vec<u8>,
+    journal_digest: Digest,
+    post_state_digest: Digest,
 ) -> Result<Vec<u8>> {
     let verifier_rpc_url =
         std::env::var("GROTH16_VERIFIER_RPC_URL").expect("env GROTH16_VERIFIER_RPC_URL");
@@ -167,19 +189,15 @@ pub async fn verify_groth16_snark(
         500,
     )?);
 
-    let seal = encode(snark_receipt.snark.to_vec())?;
-    let journal_digest = snark_receipt.journal.digest();
+    let enc_seal = encode(seal)?;
     tracing_info!("Verifying SNARK:");
-    tracing_info!("Seal: {}", hex::encode(&seal));
+    tracing_info!("Seal: {}", hex::encode(&enc_seal));
     tracing_info!("Image ID: {}", hex::encode(image_id.as_bytes()));
-    tracing_info!(
-        "Post State Digest: {}",
-        hex::encode(&snark_receipt.post_state_digest)
-    );
+    tracing_info!("Post State Digest: {}", hex::encode(&post_state_digest));
     tracing_info!("Journal Digest: {}", hex::encode(journal_digest));
     let verify_call_res = IRiscZeroVerifier::new(groth16_verifier_addr, http_client)
        .verify(
-            seal.clone().into(),
+            enc_seal.clone().into(),
            image_id.as_bytes().try_into().unwrap(),
            journal_digest.into(),
        )
@@ -191,10 +209,14 @@ pub async fn verify_groth16_snark(
         tracing_err!("SNARK verification failed: {:?}!", verify_call_res);
     }
 
-    Ok((seal, B256::from_slice(image_id.as_bytes()))
+    Ok(make_risc0_groth16_proof(enc_seal, image_id))
+}
+
+pub fn make_risc0_groth16_proof(seal: Vec<u8>, image_id: Digest) -> Vec<u8> {
+    (seal, B256::from_slice(image_id.as_bytes()))
        .abi_encode()
        .iter()
        .skip(32)
        .copied()
-        .collect())
+        .collect()
 }
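`make_risc0_groth16_proof` ABI-encodes the `(seal, image_id)` pair and drops the first 32 bytes. A minimal illustration of why that works, assuming standard ABI tuple encoding (member 0, the dynamic `bytes` seal, contributes an offset head word; member 1, the `bytes32` image id, is encoded in place); padding is omitted here for brevity:

```rust
/// Hand-rolled head/tail layout of `(seal, image_id).abi_encode()`.
fn abi_encode_pair(seal: &[u8], image_id: [u8; 32]) -> Vec<u8> {
    let mut out = Vec::new();
    let mut offset = [0u8; 32];
    offset[31] = 0x40; // head word 0: offset of the dynamic `seal` tail
    out.extend_from_slice(&offset);
    out.extend_from_slice(&image_id); // head word 1: bytes32 value in place
    let mut len = [0u8; 32];
    len[24..].copy_from_slice(&(seal.len() as u64).to_be_bytes());
    out.extend_from_slice(&len); // tail: seal length...
    out.extend_from_slice(seal); // ...then the seal bytes
    out
}

fn main() {
    let encoded = abi_encode_pair(&[0xaa; 4], [0x11; 32]);
    // `.skip(32)` removes only the offset word, so the emitted proof is
    // image_id ++ len(seal) ++ seal.
    assert_eq!(&encoded[32..64], &[0x11; 32]);
}
```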
diff --git a/provers/risc0/guest/src/aggregation.rs b/provers/risc0/guest/src/aggregation.rs
index 3f65701e..240711d7 100644
--- a/provers/risc0/guest/src/aggregation.rs
+++ b/provers/risc0/guest/src/aggregation.rs
@@ -1,21 +1,27 @@
+//! Aggregates multiple block proofs
 #![no_main]
 harness::entrypoint!(main);
-use risc0_zkvm::{serde, guest::env};
-use raiko_lib::protocol_instance::words_to_bytes_le;
-use raiko_lib::protocol_instance::aggregation_output;
-use raiko_lib::input::ZkAggregationGuestInput;
-use raiko_lib::primitives::B256;
 
-fn main() {
+use risc0_zkvm::{guest::env, serde};
+
+use raiko_lib::{
+    input::ZkAggregationGuestInput,
+    primitives::B256,
+    protocol_instance::{aggregation_output, words_to_bytes_le},
+};
+
+pub fn main() {
     // Read the aggregation input
-    let input: ZkAggregationGuestInput = env::read();
+    let input = env::read::<ZkAggregationGuestInput>();
     // Verify the proofs.
     for block_input in input.block_inputs.iter() {
-        // Verify that n has a known factorization.
-        env::verify(input.image_id, &serde::to_vec(&block_input).unwrap()).unwrap();
+        env::verify(input.image_id, &serde::to_vec(block_input).unwrap()).unwrap();
     }
     // The aggregation output
-    env::commit(&aggregation_output(B256::from(words_to_bytes_le(&input.image_id)), input.block_inputs));
+    env::commit_slice(&aggregation_output(
+        B256::from(words_to_bytes_le(&input.image_id)),
+        input.block_inputs,
+    ));
 }
diff --git a/provers/sgx/guest/src/one_shot.rs b/provers/sgx/guest/src/one_shot.rs
index 18778b45..156f92f9 100644
--- a/provers/sgx/guest/src/one_shot.rs
+++ b/provers/sgx/guest/src/one_shot.rs
@@ -146,11 +146,11 @@ pub async fn one_shot(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()> {
     let sig = sign_message(&prev_privkey, pi_hash)?;
 
     // Create the proof for the onchain SGX verifier
+    // 4(id) + 20(new) + 65(sig) = 89
     const SGX_PROOF_LEN: usize = 89;
     let mut proof = Vec::with_capacity(SGX_PROOF_LEN);
     proof.extend(args.sgx_instance_id.to_be_bytes());
     proof.extend(new_instance);
-    proof.extend(new_instance);
     proof.extend(sig);
     let proof = hex::encode(proof);
@@ -194,11 +194,11 @@ pub async fn aggregate(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()> {
     for proof in input.proofs.iter() {
         // TODO: verify protocol instance data so we can trust the old/new instance data
         assert_eq!(
-            recover_signer_unchecked(&proof.proof.clone()[44..].try_into().unwrap(), &proof.input,)
+            recover_signer_unchecked(&proof.proof.clone()[24..].try_into().unwrap(), &proof.input,)
                .unwrap(),
            cur_instance,
        );
-        cur_instance = Address::from_slice(&proof.proof.clone()[24..44]);
+        cur_instance = Address::from_slice(&proof.proof.clone()[4..24]);
     }
 
     // Current public key needs to match latest proof new public key
@@ -224,7 +224,8 @@ pub async fn aggregate(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()> {
     let sig = sign_message(&prev_privkey, aggregation_hash.into())?;
 
     // Create the proof for the onchain SGX verifier
-    const SGX_PROOF_LEN: usize = 89;
+    // 4(id) + 20(old) + 20(new) + 65(sig) = 109
+    const SGX_PROOF_LEN: usize = 109;
     let mut proof = Vec::with_capacity(SGX_PROOF_LEN);
     proof.extend(args.sgx_instance_id.to_be_bytes());
     proof.extend(old_instance);
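The aggregate SGX proof keeps both instance addresses, which is why its buffer grows from 89 to 109 bytes, while each inner block proof stays in the 89-byte one-shot layout that the loop above parses at offsets 4 and 24. A worked sketch of the two layouts:

```rust
/// One-shot SGX proof (89 bytes):
///   [0..4)   instance id (u32, big-endian)
///   [4..24)  new instance address (20 bytes)
///   [24..89) signature (65 bytes)
/// Aggregate SGX proof (109 bytes):
///   [0..4)    instance id
///   [4..24)   old instance address
///   [24..44)  new instance address
///   [44..109) signature
fn main() {
    let (id, new_instance, sig) = ([0u8; 4], [0xaa; 20], [0xbb; 65]);
    let mut one_shot = Vec::with_capacity(89);
    one_shot.extend(id);
    one_shot.extend(new_instance);
    one_shot.extend(sig);
    assert_eq!(one_shot.len(), 89);
    // The aggregator recovers the signer from proof[24..] and reads the
    // chained instance address from proof[4..24], as in the loop above.
    assert_eq!(&one_shot[4..24], &new_instance);
}
```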
diff --git a/provers/sgx/guest/src/one_shot.rs b/provers/sgx/guest/src/one_shot.rs
index 18778b45..156f92f9 100644
--- a/provers/sgx/guest/src/one_shot.rs
+++ b/provers/sgx/guest/src/one_shot.rs
@@ -146,11 +146,11 @@ pub async fn one_shot(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()>
     let sig = sign_message(&prev_privkey, pi_hash)?;
 
     // Create the proof for the onchain SGX verifier
+    // 4(id) + 20(new) + 65(sig) = 89
     const SGX_PROOF_LEN: usize = 89;
     let mut proof = Vec::with_capacity(SGX_PROOF_LEN);
     proof.extend(args.sgx_instance_id.to_be_bytes());
     proof.extend(new_instance);
-    proof.extend(new_instance);
     proof.extend(sig);
     let proof = hex::encode(proof);
@@ -194,11 +194,11 @@ pub async fn aggregate(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()>
     for proof in input.proofs.iter() {
         // TODO: verify protocol instance data so we can trust the old/new instance data
         assert_eq!(
-            recover_signer_unchecked(&proof.proof.clone()[44..].try_into().unwrap(), &proof.input,)
+            recover_signer_unchecked(&proof.proof.clone()[24..].try_into().unwrap(), &proof.input,)
                 .unwrap(),
             cur_instance,
         );
-        cur_instance = Address::from_slice(&proof.proof.clone()[24..44]);
+        cur_instance = Address::from_slice(&proof.proof.clone()[4..24]);
     }
 
     // Current public key needs to match latest proof new public key
@@ -224,7 +224,8 @@ pub async fn aggregate(global_opts: GlobalOpts, args: OneShotArgs) -> Result<()>
     let sig = sign_message(&prev_privkey, aggregation_hash.into())?;
 
     // Create the proof for the onchain SGX verifier
-    const SGX_PROOF_LEN: usize = 89;
+    const SGX_PROOF_LEN: usize = 109;
+    // 4(id) + 20(old) + 20(new) + 65(sig) = 109
     let mut proof = Vec::with_capacity(SGX_PROOF_LEN);
     proof.extend(args.sgx_instance_id.to_be_bytes());
     proof.extend(old_instance);
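For reference, the offsets patched above follow directly from the per-block proof layout documented in the first hunk: a 4-byte big-endian instance id, the 20-byte new instance address, then the 65-byte signature (89 bytes total); the aggregated proof inserts the 20-byte old instance address between id and new instance (109 bytes). An illustrative parser for the per-block layout, with a hypothetical helper name:

    fn split_sgx_block_proof(proof: &[u8; 89]) -> (u32, &[u8; 20], &[u8; 65]) {
        (
            // 4-byte instance id, written with sgx_instance_id.to_be_bytes()
            u32::from_be_bytes(proof[..4].try_into().unwrap()),
            // 20-byte new instance address, read back as proof[4..24]
            proof[4..24].try_into().unwrap(),
            // 65-byte recoverable signature, read back as proof[24..]
            proof[24..].try_into().unwrap(),
        )
    }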
diff --git a/provers/sp1/driver/Cargo.toml b/provers/sp1/driver/Cargo.toml
index aea4fdae..ef71e803 100644
--- a/provers/sp1/driver/Cargo.toml
+++ b/provers/sp1/driver/Cargo.toml
@@ -25,11 +25,11 @@ serde_json = { workspace = true, optional = true }
 sp1-sdk = { workspace = true, optional = true }
 anyhow = { workspace = true, optional = true }
 once_cell = { workspace = true, optional = true }
-sha3 = { workspace = true, optional = true, default-features = false}
+sha3 = { workspace = true, optional = true, default-features = false }
 serde_with = { workspace = true, optional = true }
-dotenv = {workspace = true, optional = true}
-cfg-if = {workspace = true }
-bincode = {workspace = true }
+dotenv = { workspace = true, optional = true }
+cfg-if = { workspace = true }
+bincode = { workspace = true }
 reth-primitives = { workspace = true }
 tokio = { workspace = true, optional = true }
 tracing = { workspace = true, optional = true }
@@ -56,6 +56,6 @@ enable = [
     "dotenv",
     "serde_with",
     "tokio",
-    "tracing"
+    "tracing",
 ]
 neon = ["sp1-sdk?/neon"]
diff --git a/provers/sp1/driver/src/lib.rs b/provers/sp1/driver/src/lib.rs
index c8f0fe60..f3f388e3 100644
--- a/provers/sp1/driver/src/lib.rs
+++ b/provers/sp1/driver/src/lib.rs
@@ -3,7 +3,10 @@
 use once_cell::sync::Lazy;
 use raiko_lib::{
-    input::{AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput},
+    input::{
+        AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput,
+        ZkAggregationGuestInput,
+    },
     prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult},
     Measurement,
 };
@@ -14,6 +17,7 @@ use sp1_sdk::{
     action,
     network::client::NetworkClient,
     proto::network::{ProofMode, UnclaimReason},
+    SP1Proof, SP1ProofWithPublicValues, SP1VerifyingKey,
 };
 use sp1_sdk::{HashableKey, ProverClient, SP1Stdin};
 use std::{
@@ -21,7 +25,7 @@ use std::{
     env, fs,
     path::{Path, PathBuf},
 };
-use tracing::{debug, info};
+use tracing::{debug, error, info};
 
 pub const ELF: &[u8] = include_bytes!("../../guest/elf/sp1-guest");
 pub const AGGREGATION_ELF: &[u8] = include_bytes!("../../guest/elf/sp1-aggregation");
@@ -76,9 +80,15 @@ impl From<Sp1Response> for Proof {
     fn from(value: Sp1Response) -> Self {
         Self {
             proof: value.proof,
-            quote: None,
-            input: None,
-            uuid: None,
+            quote: value
+                .sp1_proof
+                .as_ref()
+                .map(|p| serde_json::to_string(&p.proof).unwrap()),
+            input: value
+                .sp1_proof
+                .as_ref()
+                .map(|p| B256::from_slice(p.public_values.as_slice())),
+            uuid: value.vkey.map(|v| serde_json::to_string(&v).unwrap()),
             kzg_proof: None,
         }
     }
@@ -87,6 +97,9 @@ impl From<Sp1Response> for Proof {
 #[derive(Clone, Serialize, Deserialize)]
 pub struct Sp1Response {
     pub proof: Option<String>,
+    /// for aggregation
+    pub sp1_proof: Option<SP1ProofWithPublicValues>,
+    pub vkey: Option<SP1VerifyingKey>,
 }
 
 pub struct Sp1Prover;
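In the `From<Sp1Response> for Proof` conversion above, the aggregation metadata rides along in the generic `Proof` fields: `quote` holds the JSON-serialized compressed `SP1Proof`, `input` the 32-byte public-values digest, and `uuid` the JSON-serialized `SP1VerifyingKey`; `aggregate` below parses them back out of each block proof. A round-trip sketch under those assumptions (the helper name and bindings are hypothetical):

    fn pack_for_aggregation(
        proved: &SP1ProofWithPublicValues,
        vk: &SP1VerifyingKey,
    ) -> serde_json::Result<(String, String)> {
        // Mirrors the conversion above: quote = compressed proof, uuid = vkey.
        let quote = serde_json::to_string(&proved.proof)?;
        let uuid = serde_json::to_string(vk)?;
        Ok((quote, uuid))
    }

On the aggregation side the same pair is recovered with `serde_json::from_str::<SP1Proof>(&quote)` and `serde_json::from_str::<SP1VerifyingKey>(&uuid)`.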
@@ -131,8 +144,7 @@ impl Prover for Sp1Prover {
             RecursionMode::Compressed => prove_action.compressed().run(),
             RecursionMode::Plonk => prove_action.plonk().run(),
         }
-        .map_err(|e| ProverError::GuestError(format!("Sp1: local proving failed: {}", e)))
-        .unwrap()
+        .map_err(|e| ProverError::GuestError(format!("Sp1: local proving failed: {e}")))?
         } else {
             let network_prover = sp1_sdk::NetworkProver::new();
@@ -151,17 +163,22 @@ impl Prover for Sp1Prover {
             .await?;
             }
             info!(
-                "Sp1 Prover: block {:?} - proof id {:?}",
-                output.header.number, proof_id
+                "Sp1 Prover: block {:?} - proof id {proof_id:?}",
+                output.header.number
             );
             network_prover
                 .wait_proof::<SP1ProofWithPublicValues>(&proof_id, None)
                 .await
-                .map_err(|e| ProverError::GuestError(format!("Sp1: network proof failed {:?}", e)))
-                .unwrap()
+                .map_err(|e| ProverError::GuestError(format!("Sp1: network proof failed {e:?}")))?
         };
 
-        let proof_bytes = prove_result.bytes();
+        let proof_bytes = match param.recursion {
+            RecursionMode::Compressed => {
+                info!("Compressed proof is used in aggregation mode only");
+                vec![]
+            }
+            _ => prove_result.bytes(),
+        };
         if param.verify {
             let time = Measurement::start("verify", false);
             let pi_hash = prove_result
@@ -179,17 +196,15 @@
             time.stop_with("==> Verification complete");
         }
 
-        let proof_string = if proof_bytes.is_empty() {
-            None
-        } else {
+        let proof_string = (!proof_bytes.is_empty()).then_some(
             // 0x + 64 bytes of the vkey + the proof
             // vkey itself contains 0x prefix
-            Some(format!(
+            format!(
                 "{}{}",
                 vk.bytes32(),
                 reth_primitives::hex::encode(proof_bytes)
-            ))
-        };
+            ),
+        );
 
         info!(
             "Sp1 Prover: block {:?} completed! proof: {proof_string:?}",
@@ -198,6 +213,8 @@
         Ok::<_, ProverError>(
             Sp1Response {
                 proof: proof_string,
+                sp1_proof: Some(prove_result),
+                vkey: Some(vk),
             }
             .into(),
         )
@@ -227,12 +244,109 @@ impl Prover for Sp1Prover {
     }
 
     async fn aggregate(
-        _input: AggregationGuestInput,
+        input: AggregationGuestInput,
         _output: &AggregationGuestOutput,
-        _config: &ProverConfig,
+        config: &ProverConfig,
         _store: Option<&mut dyn IdWrite>,
     ) -> ProverResult<Proof> {
-        todo!()
+        let param = Sp1Param::deserialize(config.get("sp1").unwrap()).unwrap();
+        let mode = param.prover.clone().unwrap_or_else(get_env_mock);
+
+        info!("aggregate proof with param: {param:?}");
+
+        let block_inputs: Vec<B256> = input
+            .proofs
+            .iter()
+            .map(|proof| proof.input.unwrap())
+            .collect::<Vec<_>>();
+        let block_proof_vk = serde_json::from_str::<SP1VerifyingKey>(
+            &input.proofs.first().unwrap().uuid.clone().unwrap(),
+        )
+        .map_err(|e| ProverError::GuestError(format!("Failed to parse SP1 vk: {e}")))?;
+        let stark_vk = block_proof_vk.vk.clone();
+        let image_id = block_proof_vk.hash_u32();
+        let aggregation_input = ZkAggregationGuestInput {
+            image_id,
+            block_inputs,
+        };
+        info!(
+            "Aggregating {:?} proofs with input: {:?}",
+            input.proofs.len(),
+            aggregation_input
+        );
+
+        let mut stdin = SP1Stdin::new();
+        stdin.write(&aggregation_input);
+        for proof in input.proofs.iter() {
+            let sp1_proof = serde_json::from_str::<SP1Proof>(&proof.quote.clone().unwrap())
+                .map_err(|e| ProverError::GuestError(format!("Failed to parse SP1 proof: {e}")))?;
+            match sp1_proof {
+                SP1Proof::Compressed(block_proof) => {
+                    stdin.write_proof(block_proof.into(), stark_vk.clone());
+                }
+                _ => {
+                    error!("unsupported proof type for aggregation: {:?}", sp1_proof);
+                }
+            }
+        }
+
+        // Generate the proof for the given program.
+        let client = param
+            .prover
+            .map(|mode| match mode {
+                ProverMode::Mock => ProverClient::mock(),
+                ProverMode::Local => ProverClient::local(),
+                ProverMode::Network => ProverClient::network(),
+            })
+            .unwrap_or_else(ProverClient::new);
+
+        let (pk, vk) = client.setup(AGGREGATION_ELF);
+        info!(
+            "sp1 aggregate: {:?} based {:?} blocks with vk {:?}",
+            reth_primitives::hex::encode_prefixed(stark_vk.hash_bytes()),
+            input.proofs.len(),
+            vk.bytes32()
+        );
+
+        let prove_result = client
+            .prove(&pk, stdin)
+            .plonk()
+            .run()
+            .expect("proving failed");
+
+        let proof_bytes = prove_result.bytes();
+        if param.verify {
+            let time = Measurement::start("verify", false);
+            let aggregation_pi = prove_result.clone().borrow_mut().public_values.raw();
+            let fixture = RaikoProofFixture {
+                vkey: vk.bytes32().to_string(),
+                public_values: reth_primitives::hex::encode_prefixed(&aggregation_pi),
+                proof: reth_primitives::hex::encode_prefixed(&proof_bytes),
+            };
+
+            verify_sol(&fixture)?;
+            time.stop_with("==> Verification complete");
+        }
+
+        let proof = (!proof_bytes.is_empty()).then_some(
+            // 0x + 64 bytes of the aggregation vkey + 64 bytes of the block
+            // programs' STARK vk hash + the proof; the vkey itself contains
+            // the 0x prefix
+            format!(
+                "{}{}{}",
+                vk.bytes32(),
+                reth_primitives::hex::encode(stark_vk.hash_bytes()),
+                reth_primitives::hex::encode(proof_bytes)
+            ),
+        );
+
+        Ok::<_, ProverError>(
+            Sp1Response {
+                proof,
+                sp1_proof: None,
+                vkey: None,
+            }
+            .into(),
+        )
     }
 }
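The aggregated proof string assembled above concatenates three hex segments: the 0x-prefixed PLONK vkey of the aggregation program (`vk.bytes32()`), the 32-byte hash of the block program's STARK vk, and the PLONK proof bytes. A hypothetical decoder for that calldata layout:

    // Splits "0x" + 64 hex chars (aggregation vkey) + 64 hex chars (block
    // programs' STARK vk hash) + the remaining PLONK proof hex.
    fn split_aggregated_proof(proof_hex: &str) -> (&str, &str, &str) {
        let body = proof_hex.strip_prefix("0x").unwrap_or(proof_hex);
        (&body[..64], &body[64..128], &body[128..])
    }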
diff --git a/provers/sp1/guest/Cargo.lock b/provers/sp1/guest/Cargo.lock
index eab71cc4..cfa9f3a4 100644
--- a/provers/sp1/guest/Cargo.lock
+++ b/provers/sp1/guest/Cargo.lock
@@ -3504,7 +3504,7 @@ dependencies = [
  "size",
  "snowbridge-amcl",
  "sp1-derive",
- "sp1-primitives",
+ "sp1-primitives 1.1.1",
  "static_assertions",
  "strum",
  "strum_macros",
@@ -3549,8 +3549,9 @@ dependencies = [
 
 [[package]]
 name = "sp1-lib"
-version = "1.2.0-rc1"
-source = "git+https://github.com/succinctlabs/sp1?branch=dev#e8efd0019c8be52c6c4cecfea6259ab90db4148a"
+version = "1.2.0-rc2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b85660c40c7b40a65c706816d9157ef1b084099a80275c9b4d650f53067e667f"
 dependencies = [
  "anyhow",
  "bincode",
@@ -3562,9 +3563,9 @@ dependencies = [
 
 [[package]]
 name = "sp1-lib"
-version = "1.2.0-rc2"
+version = "2.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b85660c40c7b40a65c706816d9157ef1b084099a80275c9b4d650f53067e667f"
+checksum = "413956de14568d7fb462213b9505ad4607d75c875301b9eca567cfb2e58eaac1"
 dependencies = [
  "anyhow",
  "bincode",
@@ -3588,10 +3589,25 @@ dependencies = [
  "p3-symmetric",
 ]
 
+[[package]]
+name = "sp1-primitives"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "efbeba375fe59917f162f1808c280d2e39e4698dc7eeac83936b6e70c2f8dbbc"
+dependencies = [
+ "itertools 0.13.0",
+ "lazy_static",
+ "p3-baby-bear",
+ "p3-field",
+ "p3-poseidon2",
+ "p3-symmetric",
+]
+
 [[package]]
 name = "sp1-zkvm"
-version = "1.2.0-rc1"
-source = "git+https://github.com/succinctlabs/sp1?branch=dev#e8efd0019c8be52c6c4cecfea6259ab90db4148a"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "66c525f67cfd3f65950f01c713a72c41a5d44d289155644c8ace4ec264098039"
 dependencies = [
  "bincode",
  "cfg-if",
@@ -3604,7 +3620,8 @@ dependencies = [
  "rand",
  "serde",
  "sha2",
- "sp1-lib 1.2.0-rc1",
+ "sp1-lib 2.0.0",
+ "sp1-primitives 2.0.0",
 ]
 
 [[package]]
diff --git a/provers/sp1/guest/Cargo.toml b/provers/sp1/guest/Cargo.toml
index 21856d7f..3063cc5b 100644
--- a/provers/sp1/guest/Cargo.toml
+++ b/provers/sp1/guest/Cargo.toml
@@ -37,9 +37,9 @@ path = "src/benchmark/bn254_mul.rs"
 
 [dependencies]
 raiko-lib = { path = "../../../lib", features = ["std", "sp1"] }
-sp1-zkvm = { git = "https://github.com/succinctlabs/sp1", branch = "dev" }
-sp1-core = { version = "1.1.1"}
-sha2-v0-10-8 = { git = "https://github.com/sp1-patches/RustCrypto-hashes", package = "sha2", branch = "patch-v0.10.8" }
+sp1-zkvm = { version = "2.0.0", features = ["verify"] }
+sp1-core = { version = "1.1.1" }
+sha2 = { git = "https://github.com/sp1-patches/RustCrypto-hashes", package = "sha2", branch = "patch-v0.10.8" }
 secp256k1 = { git = "https://github.com/sp1-patches/rust-secp256k1", branch = "patch-secp256k1-v0.29.0" }
 harness-core = { path = "../../../harness/core" }
 harness = { path = "../../../harness/macro", features = ["sp1"] }
@@ -50,7 +50,10 @@ revm-precompile = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-
     "c-kzg",
 ] }
 bincode = "1.3.3"
-reth-primitives = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false, features = ["alloy-compat", "taiko"] }
+reth-primitives = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false, features = [
+    "alloy-compat",
+    "taiko",
+] }
 lazy_static = "1.4.0"
 num-bigint = { version = "0.4.6", default-features = false }
diff --git a/provers/sp1/guest/elf/sp1-aggregation b/provers/sp1/guest/elf/sp1-aggregation
index 0dbad9cd..ed3c2c31 100755
Binary files a/provers/sp1/guest/elf/sp1-aggregation and b/provers/sp1/guest/elf/sp1-aggregation differ
diff --git a/provers/sp1/guest/elf/sp1-guest b/provers/sp1/guest/elf/sp1-guest
index 5d263346..6cca8898 100755
Binary files a/provers/sp1/guest/elf/sp1-guest and b/provers/sp1/guest/elf/sp1-guest differ
diff --git a/provers/sp1/guest/src/aggregation.rs b/provers/sp1/guest/src/aggregation.rs
index b69a50bc..84d4bde3 100644
--- a/provers/sp1/guest/src/aggregation.rs
+++ b/provers/sp1/guest/src/aggregation.rs
@@ -1,15 +1,14 @@
 //! Aggregates multiple block proofs
-
 #![no_main]
 sp1_zkvm::entrypoint!(main);
 
-use sha2::Sha256;
-use sha2::Digest;
+use sha2::{Digest, Sha256};
 
-use raiko_lib::protocol_instance::words_to_bytes_le;
-use raiko_lib::protocol_instance::aggregation_output;
-use raiko_lib::input::ZkAggregationGuestInput;
-use raiko_lib::primitives::B256;
+use raiko_lib::{
+    input::ZkAggregationGuestInput,
+    primitives::B256,
+    protocol_instance::{aggregation_output, words_to_bytes_be},
+};
 
 pub fn main() {
     // Read the aggregation input
@@ -17,9 +16,16 @@ pub fn main() {
     // Verify the block proofs.
     for block_input in input.block_inputs.iter() {
-        sp1_zkvm::lib::verify::verify_sp1_proof(&input.image_id, &Sha256::digest(block_input).into());
+        sp1_zkvm::lib::verify::verify_sp1_proof(
+            &input.image_id,
+            &Sha256::digest(block_input).into(),
+        );
     }
     // The aggregation output
-    sp1_zkvm::io::commit_slice(&aggregation_output(B256::from(words_to_bytes_le(&input.image_id)), input.block_inputs));
-}
\ No newline at end of file
+    sp1_zkvm::io::commit_slice(&aggregation_output(
+        B256::from(words_to_bytes_be(&input.image_id)),
+        input.block_inputs,
+    ));
+}
+
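Note the switch from `words_to_bytes_le` to `words_to_bytes_be` in this guest: the image id words are now rendered big-endian, presumably to match SP1's big-endian `vk.bytes32()` rendering of the vkey hash (the risc0 guest keeps the little-endian form). A sketch of the difference, assuming the helpers behave as their names suggest:

    #[test]
    fn image_id_word_order() {
        let image_id: [u32; 8] = [0x0102_0304; 8];
        // Big-endian rendering (sp1 guest): word bytes in natural order.
        assert_eq!(words_to_bytes_be(&image_id)[..4], [0x01, 0x02, 0x03, 0x04]);
        // Little-endian rendering (risc0 guest): word bytes reversed.
        assert_eq!(words_to_bytes_le(&image_id)[..4], [0x04, 0x03, 0x02, 0x01]);
    }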
diff --git a/provers/sp1/guest/src/benchmark/bn254_add.rs b/provers/sp1/guest/src/benchmark/bn254_add.rs
index 096b6546..1f572963 100644
--- a/provers/sp1/guest/src/benchmark/bn254_add.rs
+++ b/provers/sp1/guest/src/benchmark/bn254_add.rs
@@ -17,11 +17,11 @@ fn main() {
     ]);
 
     let op = Sp1Operator {};
-    
+
     let ct = CycleTracker::start("bn128_run_add");
     let res = op.bn128_run_add(&input).unwrap();
     ct.end();
-    
+
     let hi = res[..32].to_vec();
     let lo = res[32..].to_vec();
diff --git a/provers/sp1/guest/src/benchmark/bn254_mul.rs b/provers/sp1/guest/src/benchmark/bn254_mul.rs
index 664947de..ae1ede10 100644
--- a/provers/sp1/guest/src/benchmark/bn254_mul.rs
+++ b/provers/sp1/guest/src/benchmark/bn254_mul.rs
@@ -19,7 +19,7 @@ fn main() {
     let ct = CycleTracker::start("bn128_run_mul");
     let res = op.bn128_run_mul(&input).unwrap();
     ct.end();
-    
+
     let hi = res[..32].to_vec();
     let lo = res[32..].to_vec();
     sp1_zkvm::io::commit(&hi);
diff --git a/provers/sp1/guest/src/benchmark/sha256.rs b/provers/sp1/guest/src/benchmark/sha256.rs
index 9c5908b1..e6c57433 100644
--- a/provers/sp1/guest/src/benchmark/sha256.rs
+++ b/provers/sp1/guest/src/benchmark/sha256.rs
@@ -13,7 +13,7 @@ fn main() {
     ]);
 
     let op = Sp1Operator {};
-    
+
     let ct = CycleTracker::start("sha256_run");
     let res = op.sha256_run(&input).unwrap();
     ct.end();
diff --git a/provers/sp1/guest/src/sys.rs b/provers/sp1/guest/src/sys.rs
index 04a3c18d..f9eed1c9 100644
--- a/provers/sp1/guest/src/sys.rs
+++ b/provers/sp1/guest/src/sys.rs
@@ -39,4 +39,5 @@ pub unsafe extern "C" fn free(_size: *const c_void) {
 #[no_mangle]
 pub extern "C" fn __ctzsi2(x: u32) -> u32 {
     x.trailing_zeros()
-}
\ No newline at end of file
+}
+
diff --git a/provers/sp1/guest/src/zk_op.rs b/provers/sp1/guest/src/zk_op.rs
index 9fad10a9..71330b1f 100644
--- a/provers/sp1/guest/src/zk_op.rs
+++ b/provers/sp1/guest/src/zk_op.rs
@@ -1,5 +1,5 @@
-use num_bigint::BigUint;
 use ::secp256k1::SECP256K1;
+use num_bigint::BigUint;
 use reth_primitives::public_key_to_address;
 use revm_precompile::{bn128::ADD_INPUT_LEN, utilities::right_pad, zk_op::ZkvmOperator, Error};
 use secp256k1::{
@@ -9,7 +9,6 @@ use secp256k1::{
 use sha2 as sp1_sha2;
 use sp1_core::utils::ec::{weierstrass::bn254::Bn254, AffinePoint};
 
-
 #[derive(Debug)]
 pub struct Sp1Operator;
 
@@ -154,4 +153,5 @@ harness::zk_suits!(
             assert!(G1_LE == [p.x.to_bytes_le(), p.y.to_bytes_le()].concat());
         }
     }
-);
\ No newline at end of file
+);
+
diff --git a/script/prove-block.sh b/script/prove-block.sh
index 7b0d387e..8e3113e8 100755
--- a/script/prove-block.sh
+++ b/script/prove-block.sh
@@ -58,6 +58,16 @@ elif [ "$proof" == "sp1" ]; then
         "verify": false
     }
 '
+elif [ "$proof" == "sp1-aggregation" ]; then
+    proofParam='
+    "proof_type": "sp1",
+    "blob_proof_type": "proof_of_equivalence",
+    "sp1": {
+        "recursion": "compressed",
+        "prover": "network",
+        "verify": false
+    }
+'
 elif [ "$proof" == "sgx" ]; then
     proofParam='
     "proof_type": "sgx",
@@ -134,13 +144,13 @@ for block in $(eval echo {$rangeStart..$rangeEnd}); do
     fi
 
     echo "- proving block $block"
-    curl --location --request POST 'http://localhost:8080/proof' \
+    curl --location --request POST 'http://localhost:8080/v3/proof' \
        --header 'Content-Type: application/json' \
        --header 'Authorization: Bearer 4cbd753fbcbc2639de804f8ce425016a50e0ecd53db00cb5397912e83f5e570e' \
        --data-raw "{
          \"network\": \"$chain\",
          \"l1_network\": \"$l1_network\",
-         \"block_number\": $block,
+         \"block_numbers\": [[$block, null], [$(($block+1)), null]],
          \"prover\": \"$prover\",
          \"graffiti\": \"$graffiti\",
          $proofParam
diff --git a/tasks/src/adv_sqlite.rs b/tasks/src/adv_sqlite.rs
index 120c0d43..96f9e4bb 100644
--- a/tasks/src/adv_sqlite.rs
+++ b/tasks/src/adv_sqlite.rs
@@ -159,6 +159,7 @@ use std::{
 };
 
 use chrono::{DateTime, Utc};
+use raiko_core::interfaces::AggregationOnlyRequest;
 use raiko_lib::{
     primitives::B256,
     prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult},
@@ -575,7 +576,7 @@ impl TaskDb {
                 ":blockhash": blockhash.to_vec(),
                 ":proofsys_id": proof_system as u8,
                 ":prover": prover,
-                ":status_id": status as i32,
+                ":status_id": i32::from(status),
                 ":proof": proof.map(hex::encode)
             })?;
 
@@ -943,6 +944,36 @@ impl TaskManager for SqliteTaskManager {
         let task_db = self.arc_task_db.lock().await;
         task_db.list_stored_ids()
     }
+
+    async fn enqueue_aggregation_task(
+        &mut self,
+        _request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<()> {
+        todo!()
+    }
+
+    async fn get_aggregation_task_proving_status(
+        &mut self,
+        _request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<TaskProvingStatusRecords> {
+        todo!()
+    }
+
+    async fn update_aggregation_task_progress(
+        &mut self,
+        _request: &AggregationOnlyRequest,
+        _status: TaskStatus,
+        _proof: Option<&[u8]>,
+    ) -> TaskManagerResult<()> {
+        todo!()
+    }
+
+    async fn get_aggregation_task_proof(
+        &mut self,
+        _request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<Vec<u8>> {
+        todo!()
+    }
 }
 
 #[cfg(test)]
diff --git a/tasks/src/lib.rs b/tasks/src/lib.rs
index 2abd2e74..cc7523e3 100644
--- a/tasks/src/lib.rs
+++ b/tasks/src/lib.rs
@@ -4,8 +4,7 @@ use std::{
 };
 
 use chrono::{DateTime, Utc};
-use num_enum::{FromPrimitive, IntoPrimitive};
-use raiko_core::interfaces::ProofType;
+use raiko_core::interfaces::{AggregationOnlyRequest, ProofType};
 use raiko_lib::{
     primitives::{ChainId, B256},
     prover::{IdStore, IdWrite, ProofKey, ProverResult},
@@ -61,24 +60,83 @@ impl From for TaskManagerError {
 #[allow(non_camel_case_types)]
 #[rustfmt::skip]
-#[derive(PartialEq, Debug, Copy, Clone, IntoPrimitive, FromPrimitive, Deserialize, Serialize, ToSchema)]
-#[repr(i32)]
+#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, ToSchema, Eq, PartialOrd, Ord)]
 #[serde(rename_all = "snake_case")]
 pub enum TaskStatus {
-    Success = 0,
-    Registered = 1000,
-    WorkInProgress = 2000,
-    ProofFailure_Generic = -1000,
-    ProofFailure_OutOfMemory = -1100,
-    NetworkFailure = -2000,
-    Cancelled = -3000,
-    Cancelled_NeverStarted = -3100,
-    Cancelled_Aborted = -3200,
-    CancellationInProgress = -3210,
-    InvalidOrUnsupportedBlock = -4000,
-    UnspecifiedFailureReason = -9999,
-    #[num_enum(default)]
-    SqlDbCorruption = -99999,
+    Success,
+    Registered,
+    WorkInProgress,
+    ProofFailure_Generic,
+    ProofFailure_OutOfMemory,
+    NetworkFailure,
+    Cancelled,
+    Cancelled_NeverStarted,
+    Cancelled_Aborted,
+    CancellationInProgress,
+    InvalidOrUnsupportedBlock,
+    NonDbFailure(String),
+    UnspecifiedFailureReason,
+    SqlDbCorruption,
+}
+
+impl From<TaskStatus> for i32 {
+    fn from(status: TaskStatus) -> i32 {
+        match status {
+            TaskStatus::Success => 0,
+            TaskStatus::Registered => 1000,
+            TaskStatus::WorkInProgress => 2000,
+            TaskStatus::ProofFailure_Generic => -1000,
+            TaskStatus::ProofFailure_OutOfMemory => -1100,
+            TaskStatus::NetworkFailure => -2000,
+            TaskStatus::Cancelled => -3000,
+            TaskStatus::Cancelled_NeverStarted => -3100,
+            TaskStatus::Cancelled_Aborted => -3200,
+            TaskStatus::CancellationInProgress => -3210,
+            TaskStatus::InvalidOrUnsupportedBlock => -4000,
+            TaskStatus::NonDbFailure(_) => -5000,
+            TaskStatus::UnspecifiedFailureReason => -9999,
+            TaskStatus::SqlDbCorruption => -99999,
+        }
+    }
+}
+
+impl From<i32> for TaskStatus {
+    fn from(value: i32) -> TaskStatus {
+        match value {
+            0 => TaskStatus::Success,
+            1000 => TaskStatus::Registered,
+            2000 => TaskStatus::WorkInProgress,
+            -1000 => TaskStatus::ProofFailure_Generic,
+            -1100 => TaskStatus::ProofFailure_OutOfMemory,
+            -2000 => TaskStatus::NetworkFailure,
+            -3000 => TaskStatus::Cancelled,
+            -3100 => TaskStatus::Cancelled_NeverStarted,
+            -3200 => TaskStatus::Cancelled_Aborted,
+            -3210 => TaskStatus::CancellationInProgress,
+            -4000 => TaskStatus::InvalidOrUnsupportedBlock,
+            -5000 => TaskStatus::NonDbFailure("".to_string()),
+            -9999 => TaskStatus::UnspecifiedFailureReason,
+            -99999 => TaskStatus::SqlDbCorruption,
+            _ => TaskStatus::UnspecifiedFailureReason,
+        }
+    }
+}
+
+impl FromIterator<TaskStatus> for TaskStatus {
+    fn from_iter<T: IntoIterator<Item = TaskStatus>>(iter: T) -> Self {
+        iter.into_iter()
+            .min()
+            .unwrap_or(TaskStatus::UnspecifiedFailureReason)
+    }
+}
+
+impl<'a> FromIterator<&'a TaskStatus> for TaskStatus {
+    fn from_iter<T: IntoIterator<Item = &'a TaskStatus>>(iter: T) -> Self {
+        iter.into_iter()
+            .min()
+            .cloned()
+            .unwrap_or(TaskStatus::UnspecifiedFailureReason)
+    }
 }
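With the `#[repr(i32)]`/`num_enum` machinery replaced by explicit `From` impls, `TaskStatus` now derives `Ord`, so the two `FromIterator` impls reduce a collection of sub-task statuses to its minimum, i.e. the earliest-declared variant present, falling back to `UnspecifiedFailureReason` for an empty iterator. A usage sketch:

    // Derived Ord follows declaration order: Success < Registered < ...
    let overall: TaskStatus = [TaskStatus::WorkInProgress, TaskStatus::Success]
        .into_iter()
        .collect();
    assert_eq!(overall, TaskStatus::Success);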
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
@@ -166,6 +224,32 @@ pub trait TaskManager: IdStore + IdWrite {
 
     /// List all stored ids.
     async fn list_stored_ids(&mut self) -> TaskManagerResult<Vec<(ProofKey, String)>>;
+
+    /// Enqueue a new aggregation task to the tasks database.
+    async fn enqueue_aggregation_task(
+        &mut self,
+        request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<()>;
+
+    /// Update a specific aggregation task's progress.
+    async fn update_aggregation_task_progress(
+        &mut self,
+        request: &AggregationOnlyRequest,
+        status: TaskStatus,
+        proof: Option<&[u8]>,
+    ) -> TaskManagerResult<()>;
+
+    /// Returns the latest triplet (status, proof - if any, last update time).
+    async fn get_aggregation_task_proving_status(
+        &mut self,
+        request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<TaskProvingStatusRecords>;
+
+    /// Returns the proof for the given aggregation task.
+    async fn get_aggregation_task_proof(
+        &mut self,
+        request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<Vec<u8>>;
 }
 
 pub fn ensure(expression: bool, message: &str) -> TaskManagerResult<()> {
@@ -297,6 +381,68 @@ impl TaskManager for TaskManagerWrapper {
             TaskManagerInstance::Sqlite(manager) => manager.list_stored_ids().await,
         }
     }
+
+    async fn enqueue_aggregation_task(
+        &mut self,
+        request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<()> {
+        match &mut self.manager {
+            TaskManagerInstance::InMemory(ref mut manager) => {
+                manager.enqueue_aggregation_task(request).await
+            }
+            TaskManagerInstance::Sqlite(ref mut manager) => {
+                manager.enqueue_aggregation_task(request).await
+            }
+        }
+    }
+
+    async fn update_aggregation_task_progress(
+        &mut self,
+        request: &AggregationOnlyRequest,
+        status: TaskStatus,
+        proof: Option<&[u8]>,
+    ) -> TaskManagerResult<()> {
+        match &mut self.manager {
+            TaskManagerInstance::InMemory(ref mut manager) => {
+                manager
+                    .update_aggregation_task_progress(request, status, proof)
+                    .await
+            }
+            TaskManagerInstance::Sqlite(ref mut manager) => {
+                manager
+                    .update_aggregation_task_progress(request, status, proof)
+                    .await
+            }
+        }
+    }
+
+    async fn get_aggregation_task_proving_status(
+        &mut self,
+        request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<TaskProvingStatusRecords> {
+        match &mut self.manager {
+            TaskManagerInstance::InMemory(ref mut manager) => {
+                manager.get_aggregation_task_proving_status(request).await
+            }
+            TaskManagerInstance::Sqlite(ref mut manager) => {
+                manager.get_aggregation_task_proving_status(request).await
+            }
+        }
+    }
+
+    async fn get_aggregation_task_proof(
+        &mut self,
+        request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<Vec<u8>> {
+        match &mut self.manager {
+            TaskManagerInstance::InMemory(ref mut manager) => {
+                manager.get_aggregation_task_proof(request).await
+            }
+            TaskManagerInstance::Sqlite(ref mut manager) => {
+                manager.get_aggregation_task_proof(request).await
+            }
+        }
+    }
 }
 
 pub fn get_task_manager(opts: &TaskManagerOpts) -> TaskManagerWrapper {
diff --git a/tasks/src/mem_db.rs b/tasks/src/mem_db.rs
index ad655000..f3bee788 100644
--- a/tasks/src/mem_db.rs
+++ b/tasks/src/mem_db.rs
@@ -13,6 +13,7 @@ use std::{
 };
 
 use chrono::Utc;
+use raiko_core::interfaces::AggregationOnlyRequest;
 use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult};
 use tokio::sync::Mutex;
 use tracing::{debug, info};
@@ -29,14 +30,16 @@ pub struct InMemoryTaskManager {
 
 #[derive(Debug)]
 pub struct InMemoryTaskDb {
-    enqueue_task: HashMap<TaskDescriptor, TaskProvingStatusRecords>,
+    tasks_queue: HashMap<TaskDescriptor, TaskProvingStatusRecords>,
+    aggregation_tasks_queue: HashMap<AggregationOnlyRequest, TaskProvingStatusRecords>,
     store: HashMap<ProofKey, String>,
 }
 
 impl InMemoryTaskDb {
     fn new() -> InMemoryTaskDb {
         InMemoryTaskDb {
-            enqueue_task: HashMap::new(),
+            tasks_queue: HashMap::new(),
+            aggregation_tasks_queue: HashMap::new(),
             store: HashMap::new(),
         }
     }
@@ -44,7 +47,7 @@ impl InMemoryTaskDb {
     fn enqueue_task(&mut self, key: &TaskDescriptor) {
         let task_status = (TaskStatus::Registered, None, Utc::now());
 
-        match self.enqueue_task.get(key) {
+        match self.tasks_queue.get(key) {
             Some(task_proving_records) => {
                 debug!(
                     "Task already exists: {:?}",
@@ -53,7 +56,7 @@ impl InMemoryTaskDb {
             } // do nothing
             None => {
                 info!("Enqueue new task: {key:?}");
-                self.enqueue_task.insert(key.clone(), vec![task_status]);
+                self.tasks_queue.insert(key.clone(), vec![task_status]);
             }
         }
     }
@@ -64,9 +67,9 @@ impl InMemoryTaskDb {
         status: TaskStatus,
         proof: Option<&[u8]>,
     ) -> TaskManagerResult<()> {
found")?; - self.enqueue_task.entry(key).and_modify(|entry| { + self.tasks_queue.entry(key).and_modify(|entry| { if let Some(latest) = entry.last() { if latest.0 != status { entry.push((status, proof.map(hex::encode), Utc::now())); @@ -81,14 +84,14 @@ impl InMemoryTaskDb { &mut self, key: &TaskDescriptor, ) -> TaskManagerResult { - Ok(self.enqueue_task.get(key).cloned().unwrap_or_default()) + Ok(self.tasks_queue.get(key).cloned().unwrap_or_default()) } fn get_task_proof(&mut self, key: &TaskDescriptor) -> TaskManagerResult> { - ensure(self.enqueue_task.contains_key(key), "no task found")?; + ensure(self.tasks_queue.contains_key(key), "no task found")?; let proving_status_records = self - .enqueue_task + .tasks_queue .get(key) .ok_or_else(|| TaskManagerError::SqlError("no task in db".to_owned()))?; @@ -107,20 +110,22 @@ impl InMemoryTaskDb { } fn size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { - Ok((self.enqueue_task.len(), vec![])) + Ok((self.tasks_queue.len(), vec![])) } fn prune(&mut self) -> TaskManagerResult<()> { - self.enqueue_task.clear(); + self.tasks_queue.clear(); Ok(()) } fn list_all_tasks(&mut self) -> TaskManagerResult> { Ok(self - .enqueue_task + .tasks_queue .iter() .flat_map(|(descriptor, statuses)| { - statuses.last().map(|status| (descriptor.clone(), status.0)) + statuses + .last() + .map(|status| (descriptor.clone(), status.0.clone())) }) .collect()) } @@ -145,6 +150,91 @@ impl InMemoryTaskDb { .cloned() .ok_or(TaskManagerError::NoData) } + + fn enqueue_aggregation_task( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult<()> { + let task_status = (TaskStatus::Registered, None, Utc::now()); + + match self.aggregation_tasks_queue.get(request) { + Some(task_proving_records) => { + debug!( + "Task already exists: {:?}", + task_proving_records.last().unwrap().0 + ); + } // do nothing + None => { + info!("Enqueue new task: {request}"); + self.aggregation_tasks_queue + .insert(request.clone(), vec![task_status]); + } + } + Ok(()) + } + + fn get_aggregation_task_proving_status( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult { + Ok(self + .aggregation_tasks_queue + .get(request) + .cloned() + .unwrap_or_default()) + } + + fn update_aggregation_task_progress( + &mut self, + request: &AggregationOnlyRequest, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()> { + ensure( + self.aggregation_tasks_queue.contains_key(request), + "no task found", + )?; + + self.aggregation_tasks_queue + .entry(request.clone()) + .and_modify(|entry| { + if let Some(latest) = entry.last() { + if latest.0 != status { + entry.push((status, proof.map(hex::encode), Utc::now())); + } + } + }); + + Ok(()) + } + + fn get_aggregation_task_proof( + &mut self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult> { + ensure( + self.aggregation_tasks_queue.contains_key(request), + "no task found", + )?; + + let proving_status_records = self + .aggregation_tasks_queue + .get(request) + .ok_or_else(|| TaskManagerError::SqlError("no task in db".to_owned()))?; + + let (_, proof, ..) 
+        let (_, proof, ..) = proving_status_records
+            .iter()
+            .filter(|(status, ..)| (status == &TaskStatus::Success))
+            .last()
+            .ok_or_else(|| TaskManagerError::SqlError("no successful task in db".to_owned()))?;
+
+        let Some(proof) = proof else {
+            return Ok(vec![]);
+        };
+
+        hex::decode(proof)
+            .map_err(|_| TaskManagerError::SqlError("couldn't decode from hex".to_owned()))
+    }
 }
 
 #[async_trait::async_trait]
@@ -248,6 +338,40 @@ impl TaskManager for InMemoryTaskManager {
         let mut db = self.db.lock().await;
         db.list_stored_ids()
     }
+
+    async fn enqueue_aggregation_task(
+        &mut self,
+        request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<()> {
+        let mut db = self.db.lock().await;
+        db.enqueue_aggregation_task(request)
+    }
+
+    async fn get_aggregation_task_proving_status(
+        &mut self,
+        request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<TaskProvingStatusRecords> {
+        let mut db = self.db.lock().await;
+        db.get_aggregation_task_proving_status(request)
+    }
+
+    async fn update_aggregation_task_progress(
+        &mut self,
+        request: &AggregationOnlyRequest,
+        status: TaskStatus,
+        proof: Option<&[u8]>,
+    ) -> TaskManagerResult<()> {
+        let mut db = self.db.lock().await;
+        db.update_aggregation_task_progress(request, status, proof)
+    }
+
+    async fn get_aggregation_task_proof(
+        &mut self,
+        request: &AggregationOnlyRequest,
+    ) -> TaskManagerResult<Vec<u8>> {
+        let mut db = self.db.lock().await;
+        db.get_aggregation_task_proof(request)
+    }
 }
 
 #[cfg(test)]