Skip to content

Commit

Permalink
perf: reduce byte lookup columns (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
tamirhemo authored Sep 26, 2023
1 parent cb93984 commit 022f376
Show file tree
Hide file tree
Showing 17 changed files with 282 additions and 168 deletions.
5 changes: 2 additions & 3 deletions curta/src/chip/hash/blake/blake2b/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ impl<F: PrimeField64, E: CubicParameters<F>> AirParameters for BLAKE2BAirParamet

type Instruction = U32Instruction;

const NUM_FREE_COLUMNS: usize = 2493;
const EXTENDED_COLUMNS: usize = 4755;
const NUM_FREE_COLUMNS: usize = 3539;
const EXTENDED_COLUMNS: usize = 1617;
const NUM_ARITHMETIC_COLUMNS: usize = 0;
}

Expand Down Expand Up @@ -205,7 +205,6 @@ impl<
for i in 0..num_rows {
writer.write_row_instructions(&trace_generator.air_data, i);
}
table.write_multiplicities(&writer);

// Fill blake2b public values into the output buffer
self.pub_values_target
Expand Down
1 change: 0 additions & 1 deletion curta/src/chip/hash/blake/blake2b/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,6 @@ mod tests {
msg_to_check %= msgs.len();
}
}
table.write_multiplicities(&writer);
});

let public_inputs = writer.0.public.read().unwrap().clone();
Expand Down
7 changes: 2 additions & 5 deletions curta/src/chip/hash/sha/sha256/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ pub struct SHA256AirParameters<F, E>(pub PhantomData<(F, E)>);

pub type U32Target = <U32Register as Register>::Value<Target>;

pub const SHA256_COLUMNS: usize = 551 + 927;

#[derive(Debug, Clone)]
pub struct MessageChunks {
pub values: Vec<u8>,
Expand Down Expand Up @@ -72,8 +70,8 @@ impl<F: PrimeField64, E: CubicParameters<F>> AirParameters for SHA256AirParamete

type Instruction = U32Instruction;

const NUM_FREE_COLUMNS: usize = 551;
const EXTENDED_COLUMNS: usize = 927;
const NUM_FREE_COLUMNS: usize = 745;
const EXTENDED_COLUMNS: usize = 345;
const NUM_ARITHMETIC_COLUMNS: usize = 0;
}

Expand Down Expand Up @@ -187,7 +185,6 @@ impl<
for i in 0..num_rows {
writer.write_row_instructions(&trace_generator.air_data, i);
}
table.write_multiplicities(&writer);

// Fill sha public values into the output buffer
self.pub_values_target
Expand Down
8 changes: 5 additions & 3 deletions curta/src/chip/hash/sha/sha256/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,11 +604,14 @@ mod tests {

type Instruction = U32Instruction;

const NUM_FREE_COLUMNS: usize = 551;
const EXTENDED_COLUMNS: usize = 927;
const NUM_FREE_COLUMNS: usize = 745;
const EXTENDED_COLUMNS: usize = 345;
const NUM_ARITHMETIC_COLUMNS: usize = 0;
}

// 551 + 927
// 745 + 345

#[test]
fn test_sha_256_stark() {
type F = GoldilocksField;
Expand Down Expand Up @@ -690,7 +693,6 @@ mod tests {
assert_eq!(hash, digest.map(u32_to_le_field_bytes));
}
}
table.write_multiplicities(&writer);
});

let public_inputs = writer.0.public.read().unwrap().clone();
Expand Down
45 changes: 45 additions & 0 deletions curta/src/chip/trace/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,51 @@ impl<F: Field> TraceWriter<F> {
self.write_instruction(instruction, 0);
}
}

/// An atomic fetch and modify operation on a register.
#[inline]
pub fn fetch_and_modify<T: Register>(
&self,
data: &T,
op: impl FnOnce(&T::Value<F>) -> T::Value<F>,
row_index: usize,
) {
match data.register() {
MemorySlice::Local(..) => {
let mut trace = self.0.trace.write().unwrap();
let window = trace.window(row_index);
let parser = TraceWindowParser::new(window, &[], &[], &[]);
let value = data.eval(&parser);

let new_value = op(&value);
data.register()
.assign(&mut trace.view_mut(), 0, T::align(&new_value), row_index);
}
MemorySlice::Next(..) => {
let mut trace = self.0.trace.write().unwrap();
let window = trace.window(row_index);
let parser = TraceWindowParser::new(window, &[], &[], &[]);
let value = data.eval(&parser);

let new_value = op(&value);
data.register()
.assign(&mut trace.view_mut(), 0, T::align(&new_value), row_index);
}
MemorySlice::Global(..) => {
let mut global = self.0.global.write().unwrap();
let value = data.read_from_slice(&global);
let new_value = op(&value);
data.assign_to_raw_slice(&mut global, &new_value);
}
MemorySlice::Public(..) => {
let mut public = self.0.public.write().unwrap();
let value = data.read_from_slice(&public);
let new_value = op(&value);
data.assign_to_raw_slice(&mut public, &new_value);
}
MemorySlice::Challenge(..) => unreachable!("Challenge registers are read-only"),
}
}
}

impl<T> Deref for TraceWriter<T> {
Expand Down
1 change: 0 additions & 1 deletion curta/src/chip/uint/bytes/gadget/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ impl<F: RichField + Extendable<D>, E: CubicParameters<F>, const D: usize> Simple
writer.write_row_instructions(&self.trace_generator.air_data, i);
}
writer.write_global_instructions(&self.trace_generator.air_data);
self.table.write_multiplicities(&writer);
}

fn serialize(
Expand Down
12 changes: 3 additions & 9 deletions curta/src/chip/uint/bytes/lookup_table/builder_operations.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
use alloc::sync::Arc;

use super::multiplicity_data::MultiplicityData;
use crate::chip::register::array::ArrayRegister;
use crate::chip::register::cubic::CubicRegister;
use crate::chip::register::element::ElementRegister;

#[derive(Debug, Clone)]
pub struct ByteLookupOperations {
pub multiplicity_data: Arc<MultiplicityData>,
pub row_acc_challenges: ArrayRegister<CubicRegister>,
pub values: Vec<CubicRegister>,
pub values: Vec<ElementRegister>,
}

impl ByteLookupOperations {
pub fn new(
multiplicity_data: Arc<MultiplicityData>,
row_acc_challenges: ArrayRegister<CubicRegister>,
) -> Self {
pub fn new(multiplicity_data: Arc<MultiplicityData>) -> Self {
let values = Vec::new();
ByteLookupOperations {
multiplicity_data,
row_acc_challenges,
values,
}
}
Expand Down
30 changes: 19 additions & 11 deletions curta/src/chip/uint/bytes/lookup_table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use super::bit_operations::not::Not;
use super::bit_operations::xor::Xor;
use super::decode::ByteDecodeInstruction;
use super::operations::instruction::ByteOperationInstruction;
use super::operations::value::ByteOperationDigestConstraint;
use crate::air::parser::AirParser;
use crate::air::AirConstraint;
use crate::chip::bool::SelectInstruction;
Expand All @@ -16,7 +17,6 @@ use crate::chip::register::bit::BitRegister;
use crate::chip::register::cubic::CubicRegister;
use crate::chip::register::memory::MemorySlice;
use crate::chip::trace::writer::TraceWriter;
use crate::chip::uint::bytes::operations::NUM_CHALLENGES;
use crate::chip::AirParameters;

pub mod builder_operations;
Expand All @@ -33,13 +33,15 @@ pub enum ByteInstructionSet {
BitNot(Not<8>),
BitSelect(SelectInstruction<BitRegister>),
Decode(ByteDecodeInstruction),
Digest(ByteOperationDigestConstraint),
}

pub trait ByteInstructions:
From<ByteInstructionSet>
+ From<ByteOperationInstruction>
+ From<SelectInstruction<BitRegister>>
+ From<ByteDecodeInstruction>
+ From<ByteOperationDigestConstraint>
{
}

Expand All @@ -52,11 +54,8 @@ impl<L: AirParameters> AirBuilder<L> {
+ From<SelectInstruction<BitRegister>>
+ From<ByteDecodeInstruction>,
{
let row_acc_challenges = self.alloc_challenge_array::<CubicRegister>(NUM_CHALLENGES);

let lookup_table = self.new_byte_lookup_table(row_acc_challenges);
let operations =
ByteLookupOperations::new(lookup_table.multiplicity_data.clone(), row_acc_challenges);
let lookup_table = self.new_byte_lookup_table();
let operations = ByteLookupOperations::new(lookup_table.multiplicity_data.clone());

(operations, lookup_table)
}
Expand All @@ -76,7 +75,7 @@ impl<L: AirParameters> AirBuilder<L> {
);
let lookup_values = self.lookup_values(&lookup_challenge, &operation_values.values);

self.cubic_lookup_from_table_and_values(lookup_table, lookup_values);
self.element_lookup_from_table_and_values(lookup_table, lookup_values);
}
}

Expand All @@ -89,6 +88,7 @@ impl<AP: AirParser> AirConstraint<AP> for ByteInstructionSet {
Self::BitNot(op) => op.eval(parser),
Self::BitSelect(op) => op.eval(parser),
Self::Decode(instruction) => instruction.eval(parser),
Self::Digest(instruction) => instruction.eval(parser),
}
}
}
Expand All @@ -102,6 +102,7 @@ impl<F: PrimeField64> Instruction<F> for ByteInstructionSet {
Self::BitNot(op) => Instruction::<F>::inputs(op),
Self::BitSelect(op) => Instruction::<F>::inputs(op),
Self::Decode(instruction) => Instruction::<F>::inputs(instruction),
Self::Digest(instruction) => Instruction::<F>::inputs(instruction),
}
}

Expand All @@ -113,6 +114,7 @@ impl<F: PrimeField64> Instruction<F> for ByteInstructionSet {
Self::BitNot(op) => Instruction::<F>::trace_layout(op),
Self::BitSelect(op) => Instruction::<F>::trace_layout(op),
Self::Decode(instruction) => Instruction::<F>::trace_layout(instruction),
Self::Digest(instruction) => Instruction::<F>::trace_layout(instruction),
}
}

Expand All @@ -124,6 +126,7 @@ impl<F: PrimeField64> Instruction<F> for ByteInstructionSet {
Self::BitNot(op) => Instruction::<F>::write(op, writer, row_index),
Self::BitSelect(op) => Instruction::<F>::write(op, writer, row_index),
Self::Decode(instruction) => Instruction::<F>::write(instruction, writer, row_index),
Self::Digest(instruction) => Instruction::<F>::write(instruction, writer, row_index),
}
}
}
Expand Down Expand Up @@ -164,6 +167,12 @@ impl From<ByteDecodeInstruction> for ByteInstructionSet {
}
}

impl From<ByteOperationDigestConstraint> for ByteInstructionSet {
fn from(instruction: ByteOperationDigestConstraint) -> Self {
Self::Digest(instruction)
}
}

#[cfg(test)]
mod tests {
use rand::{thread_rng, Rng};
Expand All @@ -186,8 +195,8 @@ mod tests {

type Instruction = ByteInstructionSet;

const NUM_FREE_COLUMNS: usize = 281;
const EXTENDED_COLUMNS: usize = 447;
const NUM_FREE_COLUMNS: usize = 377;
const EXTENDED_COLUMNS: usize = 159;
const NUM_ARITHMETIC_COLUMNS: usize = 0;
}

Expand Down Expand Up @@ -327,7 +336,6 @@ mod tests {
);
a_not.assign_to_raw_slice(&mut public_write, &F::from_canonical_u8(!a_pub_val));
b_not.assign_to_raw_slice(&mut public_write, &F::from_canonical_u8(!b_pub_val));
let public_inputs = public_write.clone();
drop(public_write);

for i in 0..num_rows {
Expand Down Expand Up @@ -365,10 +373,10 @@ mod tests {
writer.write_row_instructions(&generator.air_data, i);
}
writer.write_global_instructions(&generator.air_data);
table.write_multiplicities(&writer);

let stark = Starky::new(air);
let config = SC::standard_fast_config(num_rows);
let public_inputs = writer.public.read().unwrap().clone();

// Generate proof and verify as a stark
test_starky(&stark, &config, &generator, &public_inputs);
Expand Down
46 changes: 3 additions & 43 deletions curta/src/chip/uint/bytes/lookup_table/multiplicity_data.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use core::sync::atomic::{AtomicUsize, Ordering};
use std::collections::HashMap;

use itertools::Itertools;
use plonky2_maybe_rayon::ParallelIterator;
use serde::{Deserialize, Serialize};

use crate::chip::register::array::ArrayRegister;
Expand All @@ -14,35 +12,16 @@ use crate::chip::uint::bytes::operations::{
OPCODE_XOR,
};
use crate::math::prelude::*;
use crate::maybe_rayon::*;

#[derive(Debug, Serialize, Deserialize)]
pub struct MultiplicityValues(pub Vec<[AtomicUsize; NUM_BIT_OPPS + 1]>);

#[derive(Debug, Serialize, Deserialize)]
pub struct MultiplicityData {
multiplicities: ArrayRegister<ElementRegister>,
pub multiplicities_values: MultiplicityValues,
operations_multipcitiy_dict: HashMap<ByteOperation<u8>, (usize, usize)>,
pub operations_dict: HashMap<usize, Vec<ByteOperation<u8>>>,
}

impl MultiplicityValues {
pub fn new(num_rows: usize) -> Self {
Self(
(0..num_rows)
.map(|_| core::array::from_fn(|_| AtomicUsize::new(0)))
.collect(),
)
}

pub fn update(&self, row: usize, col: usize) {
self.0[row][col].fetch_add(1, Ordering::Relaxed);
}
}

impl MultiplicityData {
pub fn new(num_rows: usize, multiplicities: ArrayRegister<ElementRegister>) -> Self {
pub fn new(multiplicities: ArrayRegister<ElementRegister>) -> Self {
let mut operations_multipcitiy_dict = HashMap::new();
let mut operations_dict = HashMap::new();
for (row_index, (a, b)) in (0..=u8::MAX).cartesian_product(0..=u8::MAX).enumerate() {
Expand All @@ -62,38 +41,19 @@ impl MultiplicityData {
}
operations_dict.insert(row_index, operations);
}
let multiplicity_values = MultiplicityValues::new(num_rows);

Self {
multiplicities,
multiplicities_values: multiplicity_values,
operations_dict,
operations_multipcitiy_dict,
}
}

pub fn update(&self, operation: &ByteOperation<u8>) {
pub fn update<F: Field>(&self, operation: &ByteOperation<u8>, writer: &TraceWriter<F>) {
let (row, col) = self.operations_multipcitiy_dict[operation];
self.multiplicities_values.update(row, col);
writer.fetch_and_modify(&self.multiplicities.get(col), |x| *x + F::ONE, row);
}

pub fn multiplicities(&self) -> &ArrayRegister<ElementRegister> {
&self.multiplicities
}

pub fn write_multiplicities<F: Field>(&self, writer: &TraceWriter<F>) {
let multiplicities_array = self.multiplicities;
writer
.write_trace()
.unwrap()
.rows_par_mut()
.zip_eq(self.multiplicities_values.0.par_iter().map(|arr| {
core::array::from_fn::<_, { NUM_BIT_OPPS + 1 }, _>(|i| {
F::from_canonical_usize(arr[i].load(Ordering::Relaxed))
})
}))
.for_each(|(row, multiplicities)| {
multiplicities_array.assign_to_raw_slice(row, &multiplicities);
});
}
}
Loading

0 comments on commit 022f376

Please sign in to comment.