Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: generator serialization #79

Merged
merged 7 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 13 additions & 45 deletions curta/src/chip/hash/sha/sha256/builder_gadget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,18 @@ use plonky2::field::extension::Extendable;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::target::Target;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use serde::{Deserialize, Serialize};

use super::generator::{SHA256AirParameters, SHA256Generator, SHA256HintGenerator};
use super::generator::{SHA256Generator, SHA256HintGenerator, SHA256StarkData};
use super::SHA256PublicData;
use crate::chip::builder::AirBuilder;
use crate::chip::trace::generator::ArithmeticGenerator;
use crate::chip::AirParameters;
use crate::math::prelude::CubicParameters;
use crate::plonky2::stark::config::{CurtaConfig, StarkyConfig};
use crate::plonky2::stark::config::CurtaConfig;
use crate::plonky2::stark::gadget::StarkGadget;
use crate::plonky2::stark::generator::simple::SimpleStarkWitnessGenerator;
use crate::plonky2::stark::Starky;

#[derive(Debug, Clone, Copy)]
pub struct CurtaBytes<const N: usize>(pub [Target; N]);

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHA256BuilderGadget<F, E, const D: usize> {
pub padded_messages: Vec<Target>,
pub digests: Vec<Target>,
Expand Down Expand Up @@ -80,52 +76,24 @@ impl<F: RichField + Extendable<D>, E: CubicParameters<F>, const D: usize> SHA256
let public_sha_targets =
SHA256PublicData::add_virtual(self, &gadget.digests, &gadget.chunk_sizes);

// Make the air
let mut air_builder = AirBuilder::<SHA256AirParameters<F, E>>::new();
let clk = air_builder.clock();

let (mut operations, table) = air_builder.byte_operations();

let mut bus = air_builder.new_bus();
let channel_idx = bus.new_channel(&mut air_builder);

let sha_gadget =
air_builder.process_sha_256_batch(&clk, &mut bus, channel_idx, &mut operations);

air_builder.register_byte_lookup(operations, &table);
air_builder.constrain_bus(bus);

let (air, trace_data) = air_builder.build();

let generator = ArithmeticGenerator::<SHA256AirParameters<F, E>>::new(trace_data);
let stark_data = SHA256Generator::<F, E, C, D>::stark_data();
let SHA256StarkData { stark, config, .. } = stark_data;

let public_input_target = public_sha_targets.public_input_targets(self);
let virtual_proof = self.add_virtual_stark_proof(&stark, &config);
self.verify_stark_proof(&config, &stark, &virtual_proof, &public_input_target);

let sha_generator = SHA256Generator {
gadget: sha_gadget,
table,
let sha_generator = SHA256Generator::<F, E, C, D> {
padded_messages: gadget.padded_messages,
chunk_sizes: gadget.chunk_sizes,
trace_generator: generator.clone(),
pub_values_target: public_sha_targets,
config,
proof_target: virtual_proof,
public_input_targets: public_input_target,
_marker: PhantomData,
};

self.add_simple_generator(sha_generator);

let stark = Starky::new(air);
let config =
StarkyConfig::<C, D>::standard_fast_config(SHA256AirParameters::<F, E>::num_rows());
let virtual_proof = self.add_virtual_stark_proof(&stark, &config);
self.verify_stark_proof(&config, &stark, &virtual_proof, &public_input_target);

let stark_generator = SimpleStarkWitnessGenerator::new(
config,
stark,
virtual_proof,
public_input_target,
generator,
);
self.add_simple_generator(stark_generator);
}
}

Expand Down
98 changes: 85 additions & 13 deletions curta/src/chip/hash/sha/sha256/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,20 @@ use plonky2::util::serialization::{Buffer, Read, Write};
use serde::{Deserialize, Serialize};

use super::{SHA256Gadget, SHA256PublicData, INITIAL_HASH, ROUND_CONSTANTS};
use crate::chip::builder::AirBuilder;
use crate::chip::register::Register;
use crate::chip::trace::generator::ArithmeticGenerator;
use crate::chip::uint::bytes::lookup_table::table::ByteLookupTable;
use crate::chip::uint::operations::instruction::U32Instruction;
use crate::chip::uint::register::U32Register;
use crate::chip::uint::util::u32_to_le_field_bytes;
use crate::chip::AirParameters;
use crate::chip::{AirParameters, Chip};
use crate::math::prelude::{CubicParameters, *};
use crate::plonky2::stark::config::{CurtaConfig, StarkyConfig};
use crate::plonky2::stark::proof::StarkProofTarget;
use crate::plonky2::stark::prover::StarkyProver;
use crate::plonky2::stark::verifier::set_stark_proof_target;
use crate::plonky2::stark::Starky;
use crate::utils::serde::{BufferRead, BufferWrite};

#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
Expand All @@ -43,13 +49,22 @@ pub struct SHA256HintGenerator {

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct SHA256Generator<F: PrimeField64, E: CubicParameters<F>> {
pub gadget: SHA256Gadget,
pub table: ByteLookupTable,
pub struct SHA256Generator<F: PrimeField64, E: CubicParameters<F>, C, const D: usize> {
pub padded_messages: Vec<Target>,
pub chunk_sizes: Vec<usize>,
pub trace_generator: ArithmeticGenerator<SHA256AirParameters<F, E>>,
pub pub_values_target: SHA256PublicData<Target>,
pub config: StarkyConfig<C, D>,
pub proof_target: StarkProofTarget<D>,
pub public_input_targets: Vec<Target>,
pub _marker: PhantomData<(F, E)>,
}

pub struct SHA256StarkData<F: PrimeField64, E: CubicParameters<F>, C, const D: usize> {
pub stark: Starky<Chip<SHA256AirParameters<F, E>>>,
pub table: ByteLookupTable,
pub trace_generator: ArithmeticGenerator<SHA256AirParameters<F, E>>,
pub config: StarkyConfig<C, D>,
pub gadget: SHA256Gadget,
}

impl<F: PrimeField64, E: CubicParameters<F>> AirParameters for SHA256AirParameters<F, E> {
Expand All @@ -67,14 +82,55 @@ impl<F: PrimeField64, E: CubicParameters<F>> AirParameters for SHA256AirParamete
}
}

impl<F: RichField, E: CubicParameters<F>> SHA256Generator<F, E> {
impl<F: PrimeField64, E: CubicParameters<F>, C, const D: usize> SHA256Generator<F, E, C, D> {
pub fn id() -> String {
"SHA256Generator".to_string()
}

pub fn stark_data() -> SHA256StarkData<F, E, C, D>
where
F: RichField + Extendable<D>,
C: CurtaConfig<D, F = F>,
E: CubicParameters<F>,
{
let mut air_builder = AirBuilder::<SHA256AirParameters<F, E>>::new();
let clk = air_builder.clock();

let (mut operations, table) = air_builder.byte_operations();

let mut bus = air_builder.new_bus();
let channel_idx = bus.new_channel(&mut air_builder);

let gadget =
air_builder.process_sha_256_batch(&clk, &mut bus, channel_idx, &mut operations);

air_builder.register_byte_lookup(operations, &table);
air_builder.constrain_bus(bus);

let (air, trace_data) = air_builder.build();

let stark = Starky::new(air);
let config =
StarkyConfig::<C, D>::standard_fast_config(SHA256AirParameters::<F, E>::num_rows());

let trace_generator = ArithmeticGenerator::<SHA256AirParameters<F, E>>::new(trace_data);

SHA256StarkData {
stark,
table,
trace_generator,
config,
gadget,
}
}
}

impl<F: RichField + Extendable<D>, E: CubicParameters<F>, const D: usize> SimpleGenerator<F, D>
for SHA256Generator<F, E>
impl<
F: RichField + Extendable<D>,
C: CurtaConfig<D, F = F>,
E: CubicParameters<F>,
const D: usize,
> SimpleGenerator<F, D> for SHA256Generator<F, E, C, D>
{
fn id(&self) -> String {
Self::id()
Expand Down Expand Up @@ -106,6 +162,14 @@ impl<F: RichField + Extendable<D>, E: CubicParameters<F>, const D: usize> Simple
}

fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let SHA256StarkData {
stark,
table,
trace_generator,
config,
gadget,
} = Self::stark_data();

let padded_messages = self
.padded_messages
.iter()
Expand All @@ -120,17 +184,25 @@ impl<F: RichField + Extendable<D>, E: CubicParameters<F>, const D: usize> Simple
});

// Write trace values
let writer = self.trace_generator.new_writer();
self.table.write_table_entries(&writer);
let sha_public_values = self.gadget.write(message_chunks, &writer);
let writer = trace_generator.new_writer();
table.write_table_entries(&writer);
let sha_public_values = gadget.write(message_chunks, &writer);
for i in 0..SHA256AirParameters::<F, E>::num_rows() {
writer.write_row_instructions(&self.trace_generator.air_data, i);
writer.write_row_instructions(&trace_generator.air_data, i);
}
self.table.write_multiplicities(&writer);
table.write_multiplicities(&writer);

// Fill sha public values into the output buffer
self.pub_values_target
.set_targets(sha_public_values, out_buffer);

let public_inputs: Vec<_> = writer.public.read().unwrap().clone();

let proof =
StarkyProver::<F, C, D>::prove(&config, &stark, &trace_generator, &public_inputs)
.unwrap();

set_stark_proof_target(out_buffer, &self.proof_target, &proof);
}
}

Expand Down
19 changes: 19 additions & 0 deletions curta/src/chip/trace/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,25 @@ pub struct ArithmeticGenerator<L: AirParameters> {
}

impl<L: AirParameters> ArithmeticGenerator<L> {
pub fn reset(&self) {
let trace_new = AirTrace::new_with_capacity(L::num_columns(), L::num_rows());
let global_new = Vec::new();
let challenges_new = Vec::new();
let public_new = Vec::new();

let mut trace = self.writer.0.trace.write().unwrap();
*trace = trace_new;

let mut global = self.writer.0.global.write().unwrap();
*global = global_new;

let mut challenges = self.writer.0.challenges.write().unwrap();
*challenges = challenges_new;

let mut public = self.writer.0.public.write().unwrap();
*public = public_new;
}

pub fn new(air_data: AirTraceData<L>) -> Self {
let num_public_inputs = air_data.num_public_inputs;
let num_global_values = air_data.num_global_values;
Expand Down
2 changes: 1 addition & 1 deletion curta/src/chip/trace/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::trace::AirTrace;

#[derive(Debug, Serialize, Deserialize)]
pub struct WriterData<T> {
trace: RwLock<AirTrace<T>>,
pub(crate) trace: RwLock<AirTrace<T>>,
pub(crate) global: RwLock<Vec<T>>,
pub(crate) public: RwLock<Vec<T>>,
pub(crate) challenges: RwLock<Vec<T>>,
Expand Down
4 changes: 2 additions & 2 deletions curta/src/chip/uint/bytes/lookup_table/multiplicity_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ use crate::math::prelude::*;
use crate::maybe_rayon::*;

#[derive(Debug, Serialize, Deserialize)]
pub struct MultiplicityValues(Vec<[AtomicUsize; NUM_BIT_OPPS + 1]>);
pub struct MultiplicityValues(pub Vec<[AtomicUsize; NUM_BIT_OPPS + 1]>);

#[derive(Debug, Serialize, Deserialize)]
pub struct MultiplicityData {
multiplicities: ArrayRegister<ElementRegister>,
multiplicities_values: MultiplicityValues,
pub multiplicities_values: MultiplicityValues,
operations_multipcitiy_dict: HashMap<ByteOperation<u8>, (usize, usize)>,
pub operations_dict: HashMap<usize, Vec<ByteOperation<u8>>>,
}
Expand Down
11 changes: 11 additions & 0 deletions curta/src/chip/uint/bytes/lookup_table/table.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use alloc::sync::Arc;
use core::array::from_fn;
use core::sync::atomic::Ordering;

use itertools::Itertools;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -41,6 +42,16 @@ pub struct ByteLookupTable {
pub digests: Vec<CubicRegister>,
}

impl ByteLookupTable {
pub fn reset(&self) {
for v in self.multiplicity_data.multiplicities_values.0.iter() {
for q in v {
q.store(0, Ordering::SeqCst);
}
}
}
}

impl<L: AirParameters> AirBuilder<L> {
pub fn new_byte_lookup_table(
&mut self,
Expand Down
4 changes: 4 additions & 0 deletions curta/src/plonky2/stark/generator/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ where
) {
let public_inputs = witness.get_targets(&self.public_input_targets);

println!("{:?}", self.trace_generator.writer.public);

let proof = StarkyProver::<L::Field, C, D>::prove(
&self.config,
&self.stark,
Expand All @@ -88,6 +90,8 @@ where
.unwrap();

set_stark_proof_target(out_buffer, &self.proof_target, &proof);

self.trace_generator.reset();
}

fn serialize(
Expand Down
2 changes: 1 addition & 1 deletion curta/src/plonky2/stark/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,14 @@ pub(crate) mod tests {
use plonky2::util::timing::TimingTree;
use serde::de::DeserializeOwned;

use super::generator::simple::SimpleStarkWitnessGenerator;
use super::*;
use crate::air::fibonacci::FibonacciAir;
use crate::chip::builder::tests::ArithmeticGenerator;
use crate::chip::{AirParameters, Chip};
use crate::math::prelude::*;
use crate::plonky2::stark::config::PoseidonGoldilocksStarkConfig;
use crate::plonky2::stark::gadget::StarkGadget;
use crate::plonky2::stark::generator::simple::SimpleStarkWitnessGenerator;
use crate::plonky2::stark::prover::StarkyProver;
use crate::plonky2::stark::verifier::StarkyVerifier;
use crate::plonky2::{Plonky2Air, StarkyAir};
Expand Down