Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configurable pow_bits and n_queries #62

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions bin/cairo-prove/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ pub struct Args {
pub wait: bool,
#[arg(long, env, default_value = "false")]
pub sse: bool,
#[arg(long, env)]
pub n_queries: Option<u32>,
#[arg(long, env)]
pub pow_bits: Option<u32>,
}

fn validate_input(input: &str) -> Result<Vec<Felt>, ProveErrors> {
Expand Down
4 changes: 4 additions & 0 deletions bin/cairo-prove/src/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub async fn prove(args: Args, sdk: ProverSDK) -> Result<u64, ProveErrors> {
program: program_serialized,
layout: args.layout,
program_input,
pow_bits: args.pow_bits,
n_queries: args.n_queries,
};
sdk.prove_cairo0(data).await?
}
Expand All @@ -38,6 +40,8 @@ pub async fn prove(args: Args, sdk: ProverSDK) -> Result<u64, ProveErrors> {
program: program_serialized,
layout: args.layout,
program_input: input,
pow_bits: args.pow_bits,
n_queries: args.n_queries,
};
sdk.prove_cairo(data).await?
}
Expand Down
2 changes: 2 additions & 0 deletions common/src/prover_input/cairo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub struct CairoProverInput {
pub program: CairoCompiledProgram,
pub program_input: Vec<Felt>,
pub layout: String,
pub n_queries: Option<u32>,
pub pow_bits: Option<u32>,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
Expand Down
2 changes: 2 additions & 0 deletions common/src/prover_input/cairo0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ pub struct Cairo0ProverInput {
pub program: Cairo0CompiledProgram,
pub program_input: serde_json::Value,
pub layout: String,
pub n_queries: Option<u32>,
pub pow_bits: Option<u32>,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
Expand Down
6 changes: 6 additions & 0 deletions prover-sdk/tests/prove_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ async fn test_cairo_prove() {
program,
layout,
program_input,
n_queries: Some(16),
pow_bits: Some(20),
};
let job = sdk.prove_cairo(data).await.unwrap();
let result = fetch_job(sdk.clone(), job).await;
Expand All @@ -50,6 +52,8 @@ async fn test_cairo0_prove() {
program,
layout,
program_input,
n_queries: Some(16),
pow_bits: Some(20),
};
let job = sdk.prove_cairo0(data).await.unwrap();
let result = fetch_job(sdk.clone(), job).await;
Expand Down Expand Up @@ -77,6 +81,8 @@ async fn test_cairo_multi_prove() {
program,
layout,
program_input,
n_queries: Some(16),
pow_bits: Some(20),
};
let job1 = sdk.prove_cairo(data.clone()).await.unwrap();
let job2 = sdk.prove_cairo(data.clone()).await.unwrap();
Expand Down
2 changes: 2 additions & 0 deletions prover-sdk/tests/verify_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ async fn test_verify_valid_proof() {
program,
layout,
program_input,
n_queries: Some(16),
pow_bits: Some(20),
};
let job = sdk.clone().prove_cairo(data).await.unwrap();
let result = fetch_job(sdk.clone(), job).await;
Expand Down
24 changes: 12 additions & 12 deletions prover/src/prove/cairo.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
use crate::auth::jwt::Claims;
use crate::extractors::workdir::TempDirHandle;
use crate::server::AppState;
use crate::threadpool::CairoVersionedInput;
use crate::threadpool::{CairoVersionedInput, ExecuteParams};
use axum::Json;
use axum::{extract::State, http::StatusCode, response::IntoResponse};
use common::prover_input::CairoProverInput;
use serde_json::json;

pub async fn root(
State(app_state): State<AppState>,
TempDirHandle(path): TempDirHandle,
TempDirHandle(dir): TempDirHandle,
_claims: Claims,
Json(program_input): Json<CairoProverInput>,
) -> impl IntoResponse {
let thread_pool = app_state.thread_pool.clone();
let job_store = app_state.job_store.clone();
let job_id = job_store.create_job().await;
let thread = thread_pool.lock().await;
thread
.execute(
job_id,
job_store,
path,
CairoVersionedInput::Cairo(program_input),
app_state.sse_tx.clone(),
)
.await
.into_response();
let execution_params = ExecuteParams {
job_id,
job_store,
dir,
program_input: CairoVersionedInput::Cairo(program_input.clone()),
sse_tx: app_state.sse_tx.clone(),
n_queries: program_input.clone().n_queries,
pow_bits: program_input.pow_bits,
};
thread.execute(execution_params).await.into_response();

let body = json!({
"job_id": job_id
Expand Down
24 changes: 12 additions & 12 deletions prover/src/prove/cairo0.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
use crate::auth::jwt::Claims;
use crate::extractors::workdir::TempDirHandle;
use crate::server::AppState;
use crate::threadpool::CairoVersionedInput;
use crate::threadpool::{CairoVersionedInput, ExecuteParams};
use axum::Json;
use axum::{extract::State, http::StatusCode, response::IntoResponse};
use common::prover_input::Cairo0ProverInput;
use serde_json::json;

pub async fn root(
State(app_state): State<AppState>,
TempDirHandle(path): TempDirHandle,
TempDirHandle(dir): TempDirHandle,
_claims: Claims,
Json(program_input): Json<Cairo0ProverInput>,
) -> impl IntoResponse {
let thread_pool = app_state.thread_pool.clone();
let job_store = app_state.job_store.clone();
let job_id = job_store.create_job().await;
let thread = thread_pool.lock().await;
thread
.execute(
job_id,
job_store,
path,
CairoVersionedInput::Cairo0(program_input),
app_state.sse_tx.clone(),
)
.await
.into_response();
let execution_params = ExecuteParams {
job_id,
job_store,
dir,
program_input: CairoVersionedInput::Cairo0(program_input.clone()),
sse_tx: app_state.sse_tx.clone(),
n_queries: program_input.clone().n_queries,
pow_bits: program_input.pow_bits,
};
thread.execute(execution_params).await.into_response();
let body = json!({
"job_id": job_id
});
Expand Down
46 changes: 35 additions & 11 deletions prover/src/threadpool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ type ReceiverType = Arc<
TempDir,
CairoVersionedInput,
Arc<Mutex<Sender<String>>>,
Option<u32>,
Option<u32>,
)>,
>,
>;
Expand All @@ -32,8 +34,19 @@ type SenderType = Option<
TempDir,
CairoVersionedInput,
Arc<Mutex<Sender<String>>>,
Option<u32>,
Option<u32>,
)>,
>;
pub struct ExecuteParams {
pub job_id: u64,
pub job_store: JobStore,
pub dir: TempDir,
pub program_input: CairoVersionedInput,
pub sse_tx: Arc<Mutex<Sender<String>>>,
pub n_queries: Option<u32>,
pub pow_bits: Option<u32>,
}
pub struct ThreadPool {
workers: Vec<Worker>,
sender: SenderType,
Expand All @@ -59,20 +72,21 @@ impl ThreadPool {
}
}

pub async fn execute(
&self,
job_id: u64,
job_store: JobStore,
dir: TempDir,
program_input: CairoVersionedInput,
sse_tx: Arc<Mutex<Sender<String>>>,
) -> Result<(), ProverError> {
pub async fn execute(&self, params: ExecuteParams) -> Result<(), ProverError> {
self.sender
.as_ref()
.ok_or(ProverError::CustomError(
"Thread pool is shutdown".to_string(),
))?
.send((job_id, job_store, dir, program_input, sse_tx))
.send((
params.job_id,
params.job_store,
params.dir,
params.program_input,
params.sse_tx,
params.n_queries,
params.pow_bits,
))
.await?;
Ok(())
}
Expand Down Expand Up @@ -107,10 +121,20 @@ impl Worker {
loop {
let message = receiver.lock().await.recv().await;
match message {
Some((job_id, job_store, dir, program_input, sse_tx)) => {
Some((job_id, job_store, dir, program_input, sse_tx, n_queries, pow_bits)) => {
trace!("Worker {id} got a job; executing.");

if let Err(e) = prove(job_id, job_store, dir, program_input, sse_tx).await {
if let Err(e) = prove(
job_id,
job_store,
dir,
program_input,
sse_tx,
n_queries,
pow_bits,
)
.await
{
eprintln!("Worker {id} encountered an error: {:?}", e);
}

Expand Down
5 changes: 3 additions & 2 deletions prover/src/threadpool/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub async fn prove(
dir: TempDir,
program_input: CairoVersionedInput,
sse_tx: Arc<Mutex<Sender<String>>>,
n_queries: Option<u32>,
pow_bits: Option<u32>,
) -> Result<(), ProverError> {
job_store
.update_job_status(job_id, JobStatus::Running, None)
Expand All @@ -29,8 +31,7 @@ pub async fn prove(
program_input
.prepare_and_run(&RunPaths::from(&paths))
.await?;

Template::generate_from_public_input_file(&paths.public_input_file)?
Template::generate_from_public_input_file(&paths.public_input_file, n_queries, pow_bits)?
.save_to_file(&paths.params_file)?;

let prove_status = paths.prove_command().spawn()?.wait().await?;
Expand Down
20 changes: 18 additions & 2 deletions prover/src/utils/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,16 @@ pub struct Template {
}

impl Template {
pub fn generate_from_public_input_file(file: &PathBuf) -> Result<Self, ProverError> {
Self::generate_from_public_input(ProgramPublicInputAsNSteps::read_from_file(file)?)
pub fn generate_from_public_input_file(
file: &PathBuf,
n_queries: Option<u32>,
pow_bits: Option<u32>,
) -> Result<Self, ProverError> {
Self::generate_from_public_input(
ProgramPublicInputAsNSteps::read_from_file(file)?,
n_queries,
pow_bits,
)
}
pub fn save_to_file(&self, file: &PathBuf) -> Result<(), ProverError> {
let json_string = serde_json::to_string_pretty(self)?;
Expand All @@ -46,8 +54,16 @@ impl Template {
}
fn generate_from_public_input(
public_input: ProgramPublicInputAsNSteps,
n_queries: Option<u32>,
pow_bits: Option<u32>,
) -> Result<Self, ProverError> {
let mut template = Self::default();
if let Some(pow_bits) = pow_bits {
template.stark.fri.proof_of_work_bits = pow_bits;
}
if let Some(n_queries) = n_queries {
template.stark.fri.n_queries = n_queries;
}
let fri_step_list =
public_input.calculate_fri_step_list(template.stark.fri.last_layer_degree_bound);
template.stark.fri.fri_step_list = fri_step_list;
Expand Down
10 changes: 8 additions & 2 deletions scripts/e2e_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
set -eux
IMAGE_NAME="http-prover-test"
CONTAINER_ENGINE="${CONTAINER_ENGINE:-docker}"

# Check if the image already exists
if $CONTAINER_ENGINE images | grep -q "$IMAGE_NAME"; then
echo "Image $IMAGE_NAME already exists. Skipping build step."
Expand Down Expand Up @@ -46,9 +45,16 @@ $CONTAINER_ENGINE run -d --name http_prover_test $REPLACE_FLAG \
--message-expiration-time 3600 \
--session-expiration-time 3600 \
--authorized-keys $PUBLIC_KEY,$ADMIN_PUBLIC_KEY \
--admin-key $ADMIN_PUBLIC_KEY
--admin-key $ADMIN_PUBLIC_KEY

start_time=$(date +%s)

PRIVATE_KEY=$PRIVATE_KEY PROVER_URL="http://localhost:3040" ADMIN_PRIVATE_KEY=$ADMIN_PRIVATE_KEY cargo test --no-fail-fast --workspace --verbose

end_time=$(date +%s)

runtime=$((end_time - start_time))

echo "Total time for running tests: $runtime seconds"
$CONTAINER_ENGINE stop http_prover_test
$CONTAINER_ENGINE rm http_prover_test