Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sse flag and poll as default #42

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions bin/cairo-prove/src/fetch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::time::Duration;

use prover_sdk::sdk::ProverSDK;
use serde::Deserialize;
use serde_json::Value;
use tokio::time::sleep;
use tracing::info;

use crate::errors::ProveErrors;

Expand All @@ -9,10 +13,11 @@ pub struct JobId {
pub job_id: u64,
}

pub async fn fetch_job(sdk: ProverSDK, job: String) -> Result<String, ProveErrors> {
pub async fn fetch_job_sse(sdk: ProverSDK, job: String) -> Result<String, ProveErrors> {
let job: JobId = serde_json::from_str(&job)?;
println!("Job ID: {}", job.job_id);
info!("Job ID: {}", job.job_id);
sdk.sse(job.job_id).await?;
info!("Job completed");
let response = sdk.get_job(job.job_id).await?;
let response = response.text().await?;
let json_response: Value = serde_json::from_str(&response)?;
Expand All @@ -30,3 +35,36 @@ pub async fn fetch_job(sdk: ProverSDK, job: String) -> Result<String, ProveError
Err(ProveErrors::Custom(json_response.to_string()))
}
}
pub async fn fetch_job_polling(sdk: ProverSDK, job: String) -> Result<String, ProveErrors> {
let job: JobId = serde_json::from_str(&job)?;
info!("Fetching job: {}", job.job_id);
let mut counter = 0;
loop {
let response = sdk.get_job(job.job_id).await?;
let response = response.text().await?;
let json_response: Value = serde_json::from_str(&response)?;
if let Some(status) = json_response.get("status").and_then(Value::as_str) {
match status {
"Completed" => {
return Ok(json_response
.get("result")
.and_then(Value::as_str)
.unwrap_or("No result found")
.to_string());
}
"Pending" | "Running" => {
info!("Job is still in progress. Status: {}", status);
info!(
"Time passed: {} Waiting for 10 seconds before retrying...",
counter * 10
);
counter += 1;
sleep(Duration::from_secs(10)).await;
}
_ => {
return Err(ProveErrors::Custom(json_response.to_string()));
}
}
}
}
}
2 changes: 2 additions & 0 deletions bin/cairo-prove/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ pub struct Args {
pub prover_access_key: String,
#[arg(long, env, default_value = "false")]
pub wait: bool,
#[arg(long, env, default_value = "false")]
pub sse: bool,
}

fn validate_input(input: &str) -> Result<Vec<Felt>, ProveErrors> {
Expand Down
12 changes: 10 additions & 2 deletions bin/cairo-prove/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use cairo_prove::errors::ProveErrors;
use cairo_prove::prove::prove;
use cairo_prove::{fetch::fetch_job, Args};
use cairo_prove::{
fetch::{fetch_job_polling, fetch_job_sse},
Args,
};
use clap::Parser;
use prover_sdk::access_key::ProverAccessKey;
use prover_sdk::sdk::ProverSDK;
Expand All @@ -12,9 +15,14 @@ pub async fn main() -> Result<(), ProveErrors> {
let sdk = ProverSDK::new(args.prover_url.clone(), access_key).await?;
let job = prove(args.clone(), sdk.clone()).await?;
if args.wait {
let job = fetch_job(sdk, job).await?;
let job = if args.sse {
fetch_job_sse(sdk, job).await?
} else {
fetch_job_polling(sdk, job).await?
};
let path: std::path::PathBuf = args.program_output;
std::fs::write(path, job)?;
}

Ok(())
}