Skip to content

Commit

Permalink
[zk-token-sdk] Allow discrete log to be executed in the current thread (
Browse files Browse the repository at this point in the history
  • Loading branch information
samkim-crypto authored Mar 29, 2024
1 parent c5b9196 commit fb1ee78
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 55 deletions.
111 changes: 57 additions & 54 deletions zk-token-sdk/src/encryption/discrete_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use {
},
itertools::Itertools,
serde::{Deserialize, Serialize},
std::collections::HashMap,
std::{collections::HashMap, num::NonZeroUsize},
thiserror::Error,
};

Expand Down Expand Up @@ -57,14 +57,14 @@ pub struct DiscreteLog {
/// Target point for discrete log
pub target: RistrettoPoint,
/// Number of threads used for discrete log computation
num_threads: usize,
num_threads: Option<NonZeroUsize>,
/// Range bound for discrete log search derived from the max value to search for and
/// `num_threads`
range_bound: usize,
range_bound: NonZeroUsize,
/// Ristretto point representing each step of the discrete log search
step_point: RistrettoPoint,
/// Ristretto point compression batch size
compression_batch_size: usize,
compression_batch_size: NonZeroUsize,
}

#[derive(Serialize, Deserialize, Default)]
Expand Down Expand Up @@ -107,34 +107,37 @@ impl DiscreteLog {
Self {
generator,
target,
num_threads: 1,
range_bound: TWO16 as usize,
num_threads: None,
range_bound: (TWO16 as usize).try_into().unwrap(),
step_point: G,
compression_batch_size: 32,
compression_batch_size: 32.try_into().unwrap(),
}
}

/// Adjusts number of threads in a discrete log instance.
#[cfg(not(target_arch = "wasm32"))]
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), DiscreteLogError> {
pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
// number of threads must be a positive power-of-two integer
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > MAX_THREAD {
if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
return Err(DiscreteLogError::DiscreteLogThreads);
}

self.num_threads = num_threads;
self.range_bound = (TWO16 as usize).checked_div(num_threads).unwrap();
self.step_point = Scalar::from(num_threads as u64) * G;
self.num_threads = Some(num_threads);
self.range_bound = (TWO16 as usize)
.checked_div(num_threads.get())
.and_then(|range_bound| range_bound.try_into().ok())
.unwrap(); // `num_threads` cannot exceed `TWO16`, so `range_bound` always non-zero
self.step_point = Scalar::from(num_threads.get() as u64) * G;

Ok(())
}

/// Adjusts inversion batch size in a discrete log instance.
pub fn set_compression_batch_size(
&mut self,
compression_batch_size: usize,
compression_batch_size: NonZeroUsize,
) -> Result<(), DiscreteLogError> {
if compression_batch_size >= TWO16 as usize || compression_batch_size == 0 {
if compression_batch_size.get() >= TWO16 as usize {
return Err(DiscreteLogError::DiscreteLogBatchSize);
}
self.compression_batch_size = compression_batch_size;
Expand All @@ -145,41 +148,41 @@ impl DiscreteLog {
/// Solves the discrete log problem under the assumption that the solution
/// is a positive 32-bit number.
pub fn decode_u32(self) -> Option<u64> {
#[cfg(not(target_arch = "wasm32"))]
{
let mut starting_point = self.target;
let handles = (0..self.num_threads)
.map(|i| {
let ristretto_iterator = RistrettoIterator::new(
(starting_point, i as u64),
(-(&self.step_point), self.num_threads as u64),
);

let handle = thread::spawn(move || {
Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
});

starting_point -= G;
handle
})
.collect::<Vec<_>>();

handles
.into_iter()
.map_while(|h| h.join().ok())
.find(|x| x.is_some())
.flatten()
}
#[cfg(target_arch = "wasm32")]
{
let ristretto_iterator = RistrettoIterator::new(
(self.target, 0_u64),
(-(&self.step_point), self.num_threads as u64),
);
if let Some(num_threads) = self.num_threads {
#[cfg(not(target_arch = "wasm32"))]
{
let mut starting_point = self.target;
let handles = (0..num_threads.get())
.map(|i| {
let ristretto_iterator = RistrettoIterator::new(
(starting_point, i as u64),
(-(&self.step_point), num_threads.get() as u64),
);

let handle = thread::spawn(move || {
Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
});

starting_point -= G;
handle
})
.collect::<Vec<_>>();

handles
.into_iter()
.map_while(|h| h.join().ok())
.find(|x| x.is_some())
.flatten()
}
#[cfg(target_arch = "wasm32")]
unreachable!() // `self.num_threads` always `None` on wasm target
} else {
let ristretto_iterator =
RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));

Self::decode_range(
ristretto_iterator,
Expand All @@ -191,15 +194,15 @@ impl DiscreteLog {

fn decode_range(
ristretto_iterator: RistrettoIterator,
range_bound: usize,
compression_batch_size: usize,
range_bound: NonZeroUsize,
compression_batch_size: NonZeroUsize,
) -> Option<u64> {
let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
let mut decoded = None;

for batch in &ristretto_iterator
.take(range_bound)
.chunks(compression_batch_size)
.take(range_bound.get())
.chunks(compression_batch_size.get())
{
// batch compression currently errors if any point in the batch is the identity point
let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
Expand Down Expand Up @@ -298,7 +301,7 @@ mod tests {
let amount: u64 = 55;

let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
instance.num_threads(4).unwrap();
instance.num_threads(4.try_into().unwrap()).unwrap();

// Very informal measurements for now
let start_computation = Instant::now();
Expand Down
2 changes: 1 addition & 1 deletion zk-token-sdk/src/encryption/elgamal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ mod tests {
let ciphertext = ElGamal::encrypt(&public, amount);

let mut instance = ElGamal::decrypt(&secret, &ciphertext);
instance.num_threads(4).unwrap();
instance.num_threads(4.try_into().unwrap()).unwrap();
assert_eq!(57_u64, instance.decode_u32().unwrap());
}

Expand Down

0 comments on commit fb1ee78

Please sign in to comment.