diff --git a/bin/cairo-prove/src/fetch.rs b/bin/cairo-prove/src/fetch.rs index 9879f13..229173b 100644 --- a/bin/cairo-prove/src/fetch.rs +++ b/bin/cairo-prove/src/fetch.rs @@ -1,6 +1,10 @@ +use std::time::Duration; + use prover_sdk::sdk::ProverSDK; use serde::Deserialize; use serde_json::Value; +use tokio::time::sleep; +use tracing::info; use crate::errors::ProveErrors; @@ -9,10 +13,11 @@ pub struct JobId { pub job_id: u64, } -pub async fn fetch_job(sdk: ProverSDK, job: String) -> Result { +pub async fn fetch_job_sse(sdk: ProverSDK, job: String) -> Result { let job: JobId = serde_json::from_str(&job)?; - println!("Job ID: {}", job.job_id); + info!("Job ID: {}", job.job_id); sdk.sse(job.job_id).await?; + info!("Job completed"); let response = sdk.get_job(job.job_id).await?; let response = response.text().await?; let json_response: Value = serde_json::from_str(&response)?; @@ -30,3 +35,36 @@ pub async fn fetch_job(sdk: ProverSDK, job: String) -> Result Result { + let job: JobId = serde_json::from_str(&job)?; + info!("Fetching job: {}", job.job_id); + let mut counter = 0; + loop { + let response = sdk.get_job(job.job_id).await?; + let response = response.text().await?; + let json_response: Value = serde_json::from_str(&response)?; + if let Some(status) = json_response.get("status").and_then(Value::as_str) { + match status { + "Completed" => { + return Ok(json_response + .get("result") + .and_then(Value::as_str) + .unwrap_or("No result found") + .to_string()); + } + "Pending" | "Running" => { + info!("Job is still in progress. Status: {}", status); + info!( + "Time passed: {} Waiting for 10 seconds before retrying...", + counter * 10 + ); + counter += 1; + sleep(Duration::from_secs(10)).await; + } + _ => { + return Err(ProveErrors::Custom(json_response.to_string())); + } + } + } + } +} diff --git a/bin/cairo-prove/src/lib.rs b/bin/cairo-prove/src/lib.rs index d706e39..4f0ef05 100644 --- a/bin/cairo-prove/src/lib.rs +++ b/bin/cairo-prove/src/lib.rs @@ -53,6 +53,8 @@ pub struct Args { pub prover_access_key: String, #[arg(long, env, default_value = "false")] pub wait: bool, + #[arg(long, env, default_value = "false")] + pub sse: bool, } fn validate_input(input: &str) -> Result, ProveErrors> { diff --git a/bin/cairo-prove/src/main.rs b/bin/cairo-prove/src/main.rs index f99c626..91c5c02 100644 --- a/bin/cairo-prove/src/main.rs +++ b/bin/cairo-prove/src/main.rs @@ -1,6 +1,9 @@ use cairo_prove::errors::ProveErrors; use cairo_prove::prove::prove; -use cairo_prove::{fetch::fetch_job, Args}; +use cairo_prove::{ + fetch::{fetch_job_polling, fetch_job_sse}, + Args, +}; use clap::Parser; use prover_sdk::access_key::ProverAccessKey; use prover_sdk::sdk::ProverSDK; @@ -12,9 +15,14 @@ pub async fn main() -> Result<(), ProveErrors> { let sdk = ProverSDK::new(args.prover_url.clone(), access_key).await?; let job = prove(args.clone(), sdk.clone()).await?; if args.wait { - let job = fetch_job(sdk, job).await?; + let job = if args.sse { + fetch_job_sse(sdk, job).await? + } else { + fetch_job_polling(sdk, job).await? + }; let path: std::path::PathBuf = args.program_output; std::fs::write(path, job)?; } + Ok(()) }