Skip to content

Commit

Permalink
chore: const filtering optimizations (#825)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Jul 12, 2024
1 parent 6855ea1 commit 8197340
Show file tree
Hide file tree
Showing 12 changed files with 201 additions and 137 deletions.
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions examples/onnx/lenet_5/input.json

Large diffs are not rendered by default.

Binary file added examples/onnx/lenet_5/network.onnx
Binary file not shown.
116 changes: 53 additions & 63 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,35 +250,35 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, CircuitError> {
if values[0].len() != values[1].len() {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}

region.flush()?;
// time this entire function run
let global_start = instant::Instant::now();

let mut values = values.clone();

// this section has been optimized to death, don't mess with it
let mut removal_indices = values[0].get_const_zero_indices()?;
let second_zero_indices = values[1].get_const_zero_indices()?;
let mut removal_indices = values[0].get_const_zero_indices();
let second_zero_indices = values[1].get_const_zero_indices();
removal_indices.extend(second_zero_indices);
removal_indices.par_sort_unstable();
removal_indices.dedup();

// if empty return a const
if removal_indices.len() == values[0].len() {
return Ok(create_zero_tensor(1));
}

// is already sorted
values[0].remove_indices(&mut removal_indices, true)?;
values[1].remove_indices(&mut removal_indices, true)?;

let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);

if values[0].len() != values[1].len() {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}

// if empty return a const
if values[0].is_empty() && values[1].is_empty() {
return Ok(create_zero_tensor(1));
}

let start = instant::Instant::now();
let mut inputs = vec![];
let block_width = config.custom_gates.output.num_inner_cols();
Expand Down Expand Up @@ -343,7 +343,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
.collect::<Result<Vec<_>, CircuitError>>()?;
}

let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
let last_elem = output.last()?;

region.increment(assigned_len);

Expand Down Expand Up @@ -1779,12 +1779,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
let mut values = values.clone();

// this section has been optimized to death, don't mess with it
let mut removal_indices = values[0].get_const_zero_indices()?;
removal_indices.par_sort_unstable();
removal_indices.dedup();

// is already sorted
values[0].remove_indices(&mut removal_indices, true)?;
values[0].remove_const_zero_values();

let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);
Expand Down Expand Up @@ -1841,7 +1836,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
}
}

let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
let last_elem = output.last()?;

region.increment(assigned_len);

Expand Down Expand Up @@ -1884,7 +1879,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>
let global_start = instant::Instant::now();

// this section has been optimized to death, don't mess with it
let removal_indices = values[0].get_const_zero_indices()?;
let removal_indices = values[0].get_const_zero_indices();

let elapsed = global_start.elapsed();
trace!("finding const zero indices took: {:?}", elapsed);
Expand Down Expand Up @@ -1945,7 +1940,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>
.collect::<Result<Vec<_>, CircuitError>>()?;
}

let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
let last_elem = output.last()?;

region.increment(assigned_len);

Expand Down Expand Up @@ -2256,22 +2251,22 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
let orig_lhs = lhs.clone();
let orig_rhs = rhs.clone();

// get indices of zeros
let first_zero_indices = lhs.get_const_zero_indices()?;
let second_zero_indices = rhs.get_const_zero_indices()?;
let mut removal_indices = match op {
let start = instant::Instant::now();
let first_zero_indices = HashSet::from_iter(lhs.get_const_zero_indices());
let second_zero_indices = HashSet::from_iter(rhs.get_const_zero_indices());

let removal_indices = match op {
BaseOp::Add | BaseOp::Mult => {
let mut removal_indices = first_zero_indices.clone();
removal_indices.extend(second_zero_indices.clone());
removal_indices
// join the zero indices
first_zero_indices
.union(&second_zero_indices)
.cloned()
.collect()
}
BaseOp::Sub => second_zero_indices.clone(),
_ => return Err(CircuitError::UnsupportedOp),
};
removal_indices.dedup();

let removal_indices: HashSet<&usize> = HashSet::from_iter(removal_indices.iter());
let removal_indices_ptr = &removal_indices;
trace!("setting up indices took {:?}", start.elapsed());

if lhs.len() != rhs.len() {
return Err(CircuitError::DimMismatch(format!(
Expand All @@ -2280,20 +2275,19 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
)));
}

let mut inputs = vec![];
for (i, input) in [lhs.clone(), rhs.clone()].iter().enumerate() {
let inp = {
let inputs = [lhs.clone(), rhs.clone()]
.iter()
.enumerate()
.map(|(i, input)| {
let res = region.assign_with_omissions(
&config.custom_gates.inputs[i],
input,
removal_indices_ptr,
&removal_indices,
)?;

res.get_inner()?
};

inputs.push(inp);
}
Ok(res.get_inner()?)
})
.collect::<Result<Vec<_>, CircuitError>>()?;

// Now we can assign the dot product
// time the calc
Expand All @@ -2308,15 +2302,20 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let elapsed = start.elapsed();
trace!("pairwise {} calc took {:?}", op.as_str(), start.elapsed());

let assigned_len = inputs[0].len() - removal_indices.len();
let start = instant::Instant::now();
let assigned_len = op_result.len() - removal_indices.len();
let mut output = region.assign_with_omissions(
&config.custom_gates.output,
&op_result.into(),
removal_indices_ptr,
&removal_indices,
)?;
trace!("pairwise {} calc took {:?}", op.as_str(), elapsed);
trace!(
"pairwise {} input assign took {:?}",
op.as_str(),
start.elapsed()
);

// Enable the selectors
if !region.is_dummy() {
Expand All @@ -2337,16 +2336,11 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
let a_tensor = orig_lhs.get_inner_tensor()?;
let b_tensor = orig_rhs.get_inner_tensor()?;

let first_zero_indices: HashSet<&usize> = HashSet::from_iter(first_zero_indices.iter());
let second_zero_indices: HashSet<&usize> = HashSet::from_iter(second_zero_indices.iter());

trace!("setting up indices took {:?}", start.elapsed());

// infill the zero indices with the correct values from values[0] or values[1]
if !removal_indices_ptr.is_empty() {
if !removal_indices.is_empty() {
output
.get_inner_tensor_mut()?
.par_enum_map_mut_filtered(removal_indices_ptr, |i| {
.par_enum_map_mut_filtered(&removal_indices, |i| {
let val = match op {
BaseOp::Add => {
let a_is_null = first_zero_indices.contains(&i);
Expand Down Expand Up @@ -2386,6 +2380,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
end,
region.row()
);
trace!("----------------------------");

Ok(output)
}
Expand Down Expand Up @@ -3777,7 +3772,7 @@ pub(crate) fn boolean_identity<
values: &[ValTensor<F>; 1],
assign: bool,
) -> Result<ValTensor<F>, CircuitError> {
let output = if assign || !values[0].get_const_indices()?.is_empty() {
let output = if assign || !values[0].get_const_indices().is_empty() {
// get zero constants indices
let output = region.assign(&config.custom_gates.output, &values[0])?;
region.increment(output.len());
Expand Down Expand Up @@ -3942,11 +3937,10 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::

let x = values[0].clone();

let removal_indices = values[0].get_const_indices()?;
let removal_indices: HashSet<&usize> = HashSet::from_iter(removal_indices.iter());
let removal_indices_ptr = &removal_indices;
let removal_indices = values[0].get_const_indices();
let removal_indices: HashSet<usize> = HashSet::from_iter(removal_indices);

let w = region.assign_with_omissions(&config.static_lookups.input, &x, removal_indices_ptr)?;
let w = region.assign_with_omissions(&config.static_lookups.input, &x, &removal_indices)?;

let output = w.get_inner_tensor()?.par_enum_map(|i, e| {
Ok::<_, TensorError>(if let Some(f) = e.get_felt_eval() {
Expand All @@ -3964,7 +3958,7 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::
let mut output = region.assign_with_omissions(
&config.static_lookups.output,
&output.into(),
removal_indices_ptr,
&removal_indices,
)?;

let is_dummy = region.is_dummy();
Expand Down Expand Up @@ -3994,11 +3988,7 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::
})?
.into();

region.assign_with_omissions(
&config.static_lookups.index,
&table_index,
removal_indices_ptr,
)?;
region.assign_with_omissions(&config.static_lookups.index, &table_index, &removal_indices)?;

if !is_dummy {
(0..assigned_len)
Expand Down
41 changes: 20 additions & 21 deletions src/circuit/ops/region.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use halo2_proofs::{
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use portable_atomic::AtomicI64 as AtomicInt;
use std::{
cell::RefCell,
Expand Down Expand Up @@ -515,18 +517,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign(
Ok(var.assign(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)
)?)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
}
Ok(values.clone())
}
Expand All @@ -542,18 +544,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign(
Ok(var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
&mut self.assigned_constants,
)
)?)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
}
Ok(values.clone())
}
Expand All @@ -564,7 +566,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
self.assign_dynamic_lookup(var, values)
}

Expand All @@ -573,27 +575,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
ommissions: &HashSet<&usize>,
) -> Result<ValTensor<F>, Error> {
ommissions: &HashSet<usize>,
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign_with_omissions(
Ok(var.assign_with_omissions(
&mut region.borrow_mut(),
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)
)?)
} else {
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
let mut values_clone = values.clone();
let mut indices = ommissions.clone().into_iter().collect_vec();
values_clone.remove_indices(&mut indices, false)?;

for o in ommissions {
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
values_map.remove(&value);
}
}
let values_map = values.create_constants_map();

self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);

Ok(values.clone())
}
Expand Down
Loading

0 comments on commit 8197340

Please sign in to comment.