Skip to content

Commit

Permalink
clippy and fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
chudkowsky committed Sep 17, 2024
1 parent bc13c9c commit 3d46b39
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 94 deletions.
22 changes: 15 additions & 7 deletions common/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub struct JWTResponse {
pub expiration: u64,
pub session_key: Option<VerifyingKey>,
}
#[derive(Serialize, Deserialize, Clone,Debug)]
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum JobStatus {
Pending,
Running,
Expand All @@ -20,17 +20,25 @@ pub enum JobStatus {
Unknown,
}

#[derive(Clone,Serialize,Deserialize)]
#[derive(Clone, Serialize, Deserialize)]
pub struct ProverResult {
pub proof: String,
pub program_hash: Felt,
pub program_output: Vec<Felt>,
pub program_output_hash: Felt,
}
#[derive(Serialize,Deserialize)]
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
pub enum JobResponse {
InProgress { id: u64, status: JobStatus },
Completed { result: ProverResult, status: JobStatus },
Failed { error: String },
}
InProgress {
id: u64,
status: JobStatus,
},
Completed {
result: ProverResult,
status: JobStatus,
},
Failed {
error: String,
},
}
6 changes: 3 additions & 3 deletions prover-sdk/tests/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ pub async fn fetch_job(sdk: ProverSDK, job: u64) -> Option<ProverResult> {
let response = sdk.get_job(job).await.unwrap();
let response = response.text().await.unwrap();
let json_response: JobResponse = serde_json::from_str(&response).unwrap();

if let JobResponse::Completed { result, .. } = json_response {
return Some(result);
}
None
}
None
}
126 changes: 60 additions & 66 deletions prover-sdk/tests/prove_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::fs;

use common::prover_input::*;
use helpers::fetch_job;
use prover_sdk::{access_key::ProverAccessKey, sdk::ProverSDK};
Expand Down Expand Up @@ -42,67 +40,63 @@ async fn test_cairo_prove() {
assert_eq!("true", result.unwrap());
}

// #[tokio::test]
// async fn test_cairo0_prove() {
// let private_key = std::env::var("PRIVATE_KEY").unwrap();
// let url = std::env::var("PROVER_URL").unwrap();
// let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap();
// let url = Url::parse(&url).unwrap();
// let sdk = ProverSDK::new(url, access_key).await.unwrap();
// let program = std::fs::read_to_string("../examples/cairo0/fibonacci_compiled.json").unwrap();
// let program: Cairo0CompiledProgram = serde_json::from_str(&program).unwrap();
// let program_input_string = std::fs::read_to_string("../examples/cairo0/input.json").unwrap();
// let program_input: Value = serde_json::from_str(&program_input_string).unwrap();
// let layout = "recursive".to_string();
// let data = Cairo0ProverInput {
// program,
// layout,
// program_input,
// n_queries: Some(16),
// pow_bits: Some(20),
// };
// let job = sdk.prove_cairo0(data).await.unwrap();
// let result = fetch_job(sdk.clone(), job).await;
// let job = sdk.clone().verify(result).await.unwrap();
// let result = fetch_job(sdk.clone(), job).await;
// assert_eq!("true", result);
// }
// #[tokio::test]
// async fn test_cairo_multi_prove() {
// let private_key = std::env::var("PRIVATE_KEY").unwrap();
// let url = std::env::var("PROVER_URL").unwrap();
// let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap();
// let url = Url::parse(&url).unwrap();
// let sdk = ProverSDK::new(url, access_key).await.unwrap();
// let program = std::fs::read_to_string("../examples/cairo/fibonacci_compiled.json").unwrap();
// let program: CairoCompiledProgram = serde_json::from_str(&program).unwrap();
// let program_input_string = std::fs::read_to_string("../examples/cairo/input.json").unwrap();
// let mut program_input: Vec<Felt> = Vec::new();
// for part in program_input_string.split(',') {
// let felt = Felt::from_dec_str(part).unwrap();
// program_input.push(felt);
// }
// let layout = "recursive".to_string();
// let data = CairoProverInput {
// program,
// layout,
// program_input,
// n_queries: Some(16),
// pow_bits: Some(20),
// };
// let job1 = sdk.prove_cairo(data.clone()).await.unwrap();
// let job2 = sdk.prove_cairo(data.clone()).await.unwrap();
// let job3 = sdk.prove_cairo(data.clone()).await.unwrap();
// let result = fetch_job(sdk.clone(), job1).await;
// let job = sdk.clone().verify(result).await.unwrap();
// let result = fetch_job(sdk.clone(), job).await;
// assert_eq!("true", result);
// let result = fetch_job(sdk.clone(), job2).await;
// let job = sdk.clone().verify(result).await.unwrap();
// let result = fetch_job(sdk.clone(), job).await;
// assert_eq!("true", result);
// let result = fetch_job(sdk.clone(), job3).await;
// let job = sdk.clone().verify(result).await.unwrap();
// let result = fetch_job(sdk.clone(), job).await;
// assert_eq!("true", result);
// }
#[tokio::test]
async fn test_cairo0_prove() {
let private_key = std::env::var("PRIVATE_KEY").unwrap();
let url = std::env::var("PROVER_URL").unwrap();
let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap();
let url = Url::parse(&url).unwrap();
let sdk = ProverSDK::new(url, access_key).await.unwrap();
let program = std::fs::read_to_string("../examples/cairo0/fibonacci_compiled.json").unwrap();
let program: Cairo0CompiledProgram = serde_json::from_str(&program).unwrap();
let program_input_string = std::fs::read_to_string("../examples/cairo0/input.json").unwrap();
let program_input: Value = serde_json::from_str(&program_input_string).unwrap();
let layout = "recursive".to_string();
let data = Cairo0ProverInput {
program,
layout,
program_input,
n_queries: Some(16),
pow_bits: Some(20),
};
let job = sdk.prove_cairo0(data).await.unwrap();
let result = fetch_job(sdk.clone(), job).await;
let result = sdk.clone().verify(result.unwrap().proof).await.unwrap();
assert_eq!("true", result);
}
#[tokio::test]
async fn test_cairo_multi_prove() {
let private_key = std::env::var("PRIVATE_KEY").unwrap();
let url = std::env::var("PROVER_URL").unwrap();
let access_key = ProverAccessKey::from_hex_string(&private_key).unwrap();
let url = Url::parse(&url).unwrap();
let sdk = ProverSDK::new(url, access_key).await.unwrap();
let program = std::fs::read_to_string("../examples/cairo/fibonacci_compiled.json").unwrap();
let program: CairoCompiledProgram = serde_json::from_str(&program).unwrap();
let program_input_string = std::fs::read_to_string("../examples/cairo/input.json").unwrap();
let mut program_input: Vec<Felt> = Vec::new();
for part in program_input_string.split(',') {
let felt = Felt::from_dec_str(part).unwrap();
program_input.push(felt);
}
let layout = "recursive".to_string();
let data = CairoProverInput {
program,
layout,
program_input,
n_queries: Some(16),
pow_bits: Some(20),
};
let job1 = sdk.prove_cairo(data.clone()).await.unwrap();
let job2 = sdk.prove_cairo(data.clone()).await.unwrap();
let job3 = sdk.prove_cairo(data.clone()).await.unwrap();
let result = fetch_job(sdk.clone(), job1).await;
let result = sdk.clone().verify(result.unwrap().proof).await.unwrap();
assert_eq!("true", result);
let result = fetch_job(sdk.clone(), job2).await;
let result = sdk.clone().verify(result.unwrap().proof).await.unwrap();
assert_eq!("true", result);
let result = fetch_job(sdk.clone(), job3).await;
let result = sdk.clone().verify(result.unwrap().proof).await.unwrap();
assert_eq!("true", result);
}
6 changes: 3 additions & 3 deletions prover-sdk/tests/verify_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ async fn test_verify_invalid_proof() {
let url = Url::parse(&url).unwrap();
let sdk = ProverSDK::new(url, access_key).await.unwrap();
let result = sdk.clone().verify("wrong proof".to_string()).await;
assert!(result.is_ok(),"Failed to verify proof");
assert_eq!("false", result.unwrap());
assert!(result.is_ok(), "Failed to verify proof");
assert_eq!("false", result.unwrap());
}

#[tokio::test]
Expand Down Expand Up @@ -46,6 +46,6 @@ async fn test_verify_valid_proof() {
assert!(result.is_some());
let result = result.unwrap();
let result = sdk.clone().verify(result.proof).await;
assert!(result.is_ok(),"Failed to verify proof");
assert!(result.is_ok(), "Failed to verify proof");
assert_eq!("true", result.unwrap());
}
12 changes: 9 additions & 3 deletions prover/src/threadpool/prove.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use super::run::RunPaths;
use super::CairoVersionedInput;
use crate::errors::ProverError;
use crate::utils::proof_parser::{extract_program_hash, extract_program_output, program_output_hash};
use crate::utils::proof_parser::{
extract_program_hash, extract_program_output, program_output_hash,
};
use crate::utils::{config::Template, job::JobStore};
use common::models::{JobStatus, ProverResult};
use serde_json::Value;
Expand Down Expand Up @@ -48,15 +50,19 @@ pub async fn prove(
let program_hash = extract_program_hash(stark_proof.clone());
let program_output = extract_program_output(stark_proof.clone());
let program_output_hash = program_output_hash(program_output.clone());
let prover_result = ProverResult{
let prover_result = ProverResult {
proof: final_result,
program_hash,
program_output,
program_output_hash,
};

job_store
.update_job_status(job_id, JobStatus::Completed,serde_json::to_string_pretty(&prover_result).ok())
.update_job_status(
job_id,
JobStatus::Completed,
serde_json::to_string_pretty(&prover_result).ok(),
)
.await;
if sender.receiver_count() > 0 {
sender
Expand Down
21 changes: 13 additions & 8 deletions prover/src/utils/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,20 @@ pub struct Job {
pub created: Instant,
}

#[derive(Serialize,Deserialize)]
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
pub enum JobResponse {
InProgress { id: u64, status: JobStatus },
Completed { result: ProverResult, status: JobStatus },
Failed { error: String },
InProgress {
id: u64,
status: JobStatus,
},
Completed {
result: ProverResult,
status: JobStatus,
},
Failed {
error: String,
},
}

#[derive(Default, Clone)]
Expand Down Expand Up @@ -113,10 +121,7 @@ pub async fn get_job(
StatusCode::OK,
Json(JobResponse::Completed {
status: job.status.clone(),
result: serde_json::from_str(&job
.result
.clone()
.unwrap()).unwrap(),
result: serde_json::from_str(&job.result.clone().unwrap()).unwrap(),
}),
),
JobStatus::Failed => (
Expand Down
2 changes: 1 addition & 1 deletion prover/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod config;
pub mod job;
pub mod proof_parser;
pub mod shutdown;
pub mod proof_parser;
10 changes: 7 additions & 3 deletions prover/src/verifier.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use crate::{auth::jwt::Claims, errors::ProverError, extractors::workdir::TempDirHandle};
use axum::{response::IntoResponse, Json};
use crate::{auth::jwt::Claims, extractors::workdir::TempDirHandle};
use axum::Json;

use std::process::Command;

pub async fn verify_proof(TempDirHandle(dir):TempDirHandle,_claims:Claims,Json(proof): Json<String>) -> Json<bool> {
pub async fn verify_proof(
TempDirHandle(dir): TempDirHandle,
_claims: Claims,
Json(proof): Json<String>,
) -> Json<bool> {
// Define the path for the proof file
let file = dir.into_path().join("proof");

Expand Down

0 comments on commit 3d46b39

Please sign in to comment.