Skip to content

Commit

Permalink
Merge pull request #8 from twiby/optim/karatsuba
Browse files Browse the repository at this point in the history
karatsuba: diverse optimizations
  • Loading branch information
twiby authored Aug 13, 2024
2 parents a1ce2b2 + d7ad4c1 commit 6743c3f
Show file tree
Hide file tree
Showing 6 changed files with 542 additions and 211 deletions.
89 changes: 0 additions & 89 deletions benches/num-bigint.rs

This file was deleted.

58 changes: 58 additions & 0 deletions src/biguint/ops/implem_choices/add/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,62 @@ mod tests {
add_assign(&mut temp, &b[1..]);
assert_eq!(temp, ret_part);
}

/// Randomize some tests to compare the result with num-bigint
#[cfg(feature = "rand")]
fn coherence_with_num_bigint(n: usize) {
use num_bigint::BigUint;
use rand::distributions::Standard;
use rand::prelude::Distribution;
use rand::{thread_rng, Rng};

fn gen_n_random_values<T>(n: usize) -> Vec<T>
where
Standard: Distribution<T>,
{
let mut ret = Vec::<T>::with_capacity(n);
for _ in 0..n {
ret.push(thread_rng().gen::<T>());
}
ret
}

println!("STEP {n}");

const SIZE: usize = 1000;
let vec_a = gen_n_random_values::<u32>(SIZE);
let vec_b = gen_n_random_values::<u32>(SIZE);

let a = BigUint::new(vec_a.clone());
let b = BigUint::new(vec_b.clone());
let c = &a + &b;
let should_get = c.to_u32_digits();

let mut got = vec_a.clone();
if add_assign(&mut got, &vec_b) {
got.push(1);
}

if should_get != got {
assert_eq!(should_get.len(), got.len());
for (i, (a, b)) in should_get.iter().zip(got.iter()).enumerate() {
if a > b {
println!("digit {i}, diff {}", a - b);
} else if b > a {
println!("digit {i}, diff {}", b - a);
}
}
}

assert_eq!(should_get, got);
}

/// Randomize some tests to compare the result with num-bigint
#[test]
#[cfg(feature = "rand")]
fn coherence_with_num_bigint_many() {
for n in 0..100 {
coherence_with_num_bigint(n);
}
}
}
189 changes: 117 additions & 72 deletions src/biguint/ops/implem_choices/mul/karatsuba.rs
Original file line number Diff line number Diff line change
@@ -1,98 +1,143 @@
use crate::traits::Digit;
use std::ops::Deref;
use std::ops::DerefMut;

use super::super::add_assign;
use super::super::sub_assign;
use super::schoolbook_add_assign_mul;
use super::schoolbook_mul;

// Below this number of digits, multiplication is schoolbook
#[cfg(debug_assertions)]
const KARATSUBA_INTERNAL_THRESHOLD: usize = 2;
#[cfg(debug_assertions)]
const KARATSUBA_EXTERNAL_THRESHOLD: usize = 2;
const KARATSUBA_THRESHOLD: usize = 7;

#[cfg(not(debug_assertions))]
const KARATSUBA_INTERNAL_THRESHOLD: usize = 20;
#[cfg(not(debug_assertions))]
const KARATSUBA_EXTERNAL_THRESHOLD: usize = 156;

pub(super) const KARATSUBA_EXTERNAL_THRESHOLD_SQUARED: usize =
KARATSUBA_EXTERNAL_THRESHOLD * KARATSUBA_EXTERNAL_THRESHOLD;
const KARATSUBA_THRESHOLD: usize = 25;

pub(super) fn karatsuba<T: Digit>(rhs: &[T], lhs: &[T]) -> Vec<T> {
let target_length = rhs.len().max(lhs.len()).next_power_of_two();
assert!(target_length < usize::MAX >> 1);
fn allocate_buffer<T: Digit>(n: usize) -> impl Deref<Target = [T]> + DerefMut {
vec![T::ZERO; n]
}

let mut x = rhs.to_vec();
let mut y = lhs.to_vec();
x.resize(target_length, T::ZERO);
y.resize(target_length, T::ZERO);
pub(super) fn karatsuba<T: Digit>(ret: &mut [T], rhs: &[T], lhs: &[T]) {
if rhs.len() < lhs.len() {
return karatsuba(ret, lhs, rhs);
}

let mut ret = vec![T::ZERO; target_length << 1];
let mut buff = vec![T::ZERO; target_length << 1];
_karatsuba::<KARATSUBA_INTERNAL_THRESHOLD, _>(&mut ret, &x, &y, &mut buff);
ret.resize(rhs.len() + lhs.len(), T::ZERO);
ret
if rhs.len() == lhs.len() {
let mut buff = allocate_buffer(lhs.len().next_power_of_two() << 1);
symetric_karatsuba(ret, rhs, lhs, &mut buff);
} else {
let mut buff_1 = allocate_buffer(lhs.len() << 1);
let mut buff_2 = allocate_buffer(lhs.len().next_power_of_two() << 1);
asymetric_karatsuba(ret, rhs, lhs, &mut buff_1, &mut buff_2);
}
}
fn _karatsuba<const THRESHOLD: usize, T: Digit>(
ret: &mut [T],
rhs: &[T],
lhs: &[T],
buff: &mut [T],

/// multiplies big and small, puts the result in ret.
///
/// we assume big is larger than small, and that ret is filled with zeros
fn asymetric_karatsuba<'a, T: Digit>(
mut ret: &mut [T],
mut big: &'a [T],
mut small: &'a [T],
buff_1: &mut [T],
buff_2: &mut [T],
) {
debug_assert!(rhs.len() == lhs.len());
debug_assert!(rhs.len().is_power_of_two());
debug_assert_eq!(ret.len(), 2 * rhs.len());
debug_assert_eq!(buff.len(), 2 * rhs.len());
debug_assert!(big.len() >= small.len());
let mut half_size = small.len();
let mut size = half_size << 1;
let mut write_counter = 0;

while !exit_karatsuba(small.len()) {
symetric_karatsuba(&mut buff_1[..size], &big[..half_size], small, buff_2);

if size > write_counter {
ret[write_counter..size].copy_from_slice(&buff_1[write_counter..size]);
add_assign(ret, &buff_1[..write_counter]);
} else {
add_assign(ret, &buff_1[..size]);
}

let size = rhs.len();
let half_size = size >> 1;
big = &big[half_size..];
ret = &mut ret[half_size..];
write_counter = write_counter.max(size + 1) - half_size;

if big.len() < small.len() {
(small, big) = (big, small);
half_size = small.len();
size = half_size << 1;
}
}

if half_size > 0 {
size = big.len() + small.len();
schoolbook_mul(&mut buff_1[..size], big, small);
if size > write_counter {
ret[write_counter..size].copy_from_slice(&buff_1[write_counter..size]);
add_assign(ret, &buff_1[..write_counter]);
} else {
add_assign(ret, &buff_1[..size]);
}
}
}

#[inline]
pub(crate) fn exit_karatsuba(size: usize) -> bool {
size < KARATSUBA_THRESHOLD
}

/// multiplies big and small, puts the result in ret.
///
/// we assume x and y have the same size
/// ret doesn't have to be filled with zeros
fn symetric_karatsuba<T: Digit>(ret: &mut [T], x: &[T], y: &[T], buff: &mut [T]) {
// Early exit
if size < THRESHOLD {
schoolbook_add_assign_mul(ret, rhs, lhs);
if exit_karatsuba(x.len()) {
schoolbook_mul(ret, x, y);
return;
}

let (x0, x1) = rhs.split_at(half_size);
let (y0, y1) = lhs.split_at(half_size);

// Compute (x0+x1) and (y0+y1), using ret as a buffer,
// but specifically handle their last bit
let (x_temp, y_temp) = ret[..size].split_at_mut(half_size);
x_temp.copy_from_slice(x0);
y_temp.copy_from_slice(y0);
let x_carry = add_assign(x_temp, x1);
let y_carry = add_assign(y_temp, y1);

// compute z1 in a separate buffer
// but specifically handle its last bit
let (z1, new_buff) = buff.split_at_mut(size);
let mut z1_last_bit = T::ZERO;
_karatsuba::<THRESHOLD, _>(&mut z1[..size], x_temp, y_temp, new_buff);
let size = x.len();
let half_size = (size >> 1) + (size % 2);
let small_half_size = size >> 1;
let size = half_size << 1;

debug_assert_eq!(x.len(), y.len());
debug_assert_eq!(ret.len(), x.len() + y.len());
debug_assert!(buff.len() >= 2 * size);

let (buff, sub_buff) = buff.split_at_mut(size);
let (x0, x1) = x.split_at(half_size);
let (y0, y1) = y.split_at(half_size);

// Compute x0 + x1 and y0 + y1 in buff
let (x_cross, y_cross) = buff.split_at_mut(half_size);
x_cross.copy_from_slice(x0);
y_cross.copy_from_slice(y0);
let x_carry = add_assign(x_cross, x1);
let y_carry = add_assign(y_cross, y1);

// Compute z1 in ret
let z1 = &mut ret[half_size..half_size + size + 2];
symetric_karatsuba(&mut z1[..size], x_cross, y_cross, sub_buff);
z1[size] = T::from(x_carry && y_carry);
if x_carry {
z1_last_bit += T::from(add_assign(&mut z1[half_size..], &y_temp));
add_assign(&mut z1[half_size..], y_cross);
}
if y_carry {
z1_last_bit += T::from(add_assign(&mut z1[half_size..], &x_temp));
}
z1_last_bit += T::from(x_carry && y_carry);

// z0 and z2
ret[..size].fill(T::ZERO);
new_buff.fill(T::ZERO);
_karatsuba::<THRESHOLD, _>(&mut ret[..size], x0, y0, new_buff);
new_buff.fill(T::ZERO);
_karatsuba::<THRESHOLD, _>(&mut ret[size..], x1, y1, new_buff);

// subtract z0 and z2 from z1
if sub_assign(z1, &ret[..size]) {
z1_last_bit -= T::ONE;
}
if sub_assign(z1, &ret[size..size * 2]) {
z1_last_bit -= T::ONE;
add_assign(&mut z1[half_size..], x_cross);
}

// add z1
add_assign(&mut ret[half_size..], z1);
add_assign(&mut ret[half_size + size..], &[z1_last_bit]);
// Compute z2 in buff
let z2 = &mut buff[..2 * small_half_size];
symetric_karatsuba(z2, x1, y1, sub_buff);
ret[half_size + size + 1..].copy_from_slice(&z2[half_size + 1..]);
add_assign(&mut ret[size..], &z2[..half_size + 1]);
sub_assign(&mut ret[half_size..], &z2);

// Compute z0 in buff
let z0 = &mut buff[..size];
symetric_karatsuba(z0, x0, y0, sub_buff);
ret[..half_size].copy_from_slice(&z0[..half_size]);
add_assign(&mut ret[half_size..], &z0[half_size..]);
sub_assign(&mut ret[half_size..], &z0);
}
Loading

0 comments on commit 6743c3f

Please sign in to comment.