-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from twiby/optim/karatsuba
karatsuba: diverse optimizations
- Loading branch information
Showing
6 changed files
with
542 additions
and
211 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,98 +1,143 @@ | ||
use crate::traits::Digit; | ||
use std::ops::Deref; | ||
use std::ops::DerefMut; | ||
|
||
use super::super::add_assign; | ||
use super::super::sub_assign; | ||
use super::schoolbook_add_assign_mul; | ||
use super::schoolbook_mul; | ||
|
||
// Below this number of digits, multiplication is schoolbook | ||
#[cfg(debug_assertions)] | ||
const KARATSUBA_INTERNAL_THRESHOLD: usize = 2; | ||
#[cfg(debug_assertions)] | ||
const KARATSUBA_EXTERNAL_THRESHOLD: usize = 2; | ||
const KARATSUBA_THRESHOLD: usize = 7; | ||
|
||
#[cfg(not(debug_assertions))] | ||
const KARATSUBA_INTERNAL_THRESHOLD: usize = 20; | ||
#[cfg(not(debug_assertions))] | ||
const KARATSUBA_EXTERNAL_THRESHOLD: usize = 156; | ||
|
||
pub(super) const KARATSUBA_EXTERNAL_THRESHOLD_SQUARED: usize = | ||
KARATSUBA_EXTERNAL_THRESHOLD * KARATSUBA_EXTERNAL_THRESHOLD; | ||
const KARATSUBA_THRESHOLD: usize = 25; | ||
|
||
pub(super) fn karatsuba<T: Digit>(rhs: &[T], lhs: &[T]) -> Vec<T> { | ||
let target_length = rhs.len().max(lhs.len()).next_power_of_two(); | ||
assert!(target_length < usize::MAX >> 1); | ||
fn allocate_buffer<T: Digit>(n: usize) -> impl Deref<Target = [T]> + DerefMut { | ||
vec![T::ZERO; n] | ||
} | ||
|
||
let mut x = rhs.to_vec(); | ||
let mut y = lhs.to_vec(); | ||
x.resize(target_length, T::ZERO); | ||
y.resize(target_length, T::ZERO); | ||
pub(super) fn karatsuba<T: Digit>(ret: &mut [T], rhs: &[T], lhs: &[T]) { | ||
if rhs.len() < lhs.len() { | ||
return karatsuba(ret, lhs, rhs); | ||
} | ||
|
||
let mut ret = vec![T::ZERO; target_length << 1]; | ||
let mut buff = vec![T::ZERO; target_length << 1]; | ||
_karatsuba::<KARATSUBA_INTERNAL_THRESHOLD, _>(&mut ret, &x, &y, &mut buff); | ||
ret.resize(rhs.len() + lhs.len(), T::ZERO); | ||
ret | ||
if rhs.len() == lhs.len() { | ||
let mut buff = allocate_buffer(lhs.len().next_power_of_two() << 1); | ||
symetric_karatsuba(ret, rhs, lhs, &mut buff); | ||
} else { | ||
let mut buff_1 = allocate_buffer(lhs.len() << 1); | ||
let mut buff_2 = allocate_buffer(lhs.len().next_power_of_two() << 1); | ||
asymetric_karatsuba(ret, rhs, lhs, &mut buff_1, &mut buff_2); | ||
} | ||
} | ||
fn _karatsuba<const THRESHOLD: usize, T: Digit>( | ||
ret: &mut [T], | ||
rhs: &[T], | ||
lhs: &[T], | ||
buff: &mut [T], | ||
|
||
/// multiplies big and small, puts the result in ret. | ||
/// | ||
/// we assume big is larger than small, and that ret is filled with zeros | ||
fn asymetric_karatsuba<'a, T: Digit>( | ||
mut ret: &mut [T], | ||
mut big: &'a [T], | ||
mut small: &'a [T], | ||
buff_1: &mut [T], | ||
buff_2: &mut [T], | ||
) { | ||
debug_assert!(rhs.len() == lhs.len()); | ||
debug_assert!(rhs.len().is_power_of_two()); | ||
debug_assert_eq!(ret.len(), 2 * rhs.len()); | ||
debug_assert_eq!(buff.len(), 2 * rhs.len()); | ||
debug_assert!(big.len() >= small.len()); | ||
let mut half_size = small.len(); | ||
let mut size = half_size << 1; | ||
let mut write_counter = 0; | ||
|
||
while !exit_karatsuba(small.len()) { | ||
symetric_karatsuba(&mut buff_1[..size], &big[..half_size], small, buff_2); | ||
|
||
if size > write_counter { | ||
ret[write_counter..size].copy_from_slice(&buff_1[write_counter..size]); | ||
add_assign(ret, &buff_1[..write_counter]); | ||
} else { | ||
add_assign(ret, &buff_1[..size]); | ||
} | ||
|
||
let size = rhs.len(); | ||
let half_size = size >> 1; | ||
big = &big[half_size..]; | ||
ret = &mut ret[half_size..]; | ||
write_counter = write_counter.max(size + 1) - half_size; | ||
|
||
if big.len() < small.len() { | ||
(small, big) = (big, small); | ||
half_size = small.len(); | ||
size = half_size << 1; | ||
} | ||
} | ||
|
||
if half_size > 0 { | ||
size = big.len() + small.len(); | ||
schoolbook_mul(&mut buff_1[..size], big, small); | ||
if size > write_counter { | ||
ret[write_counter..size].copy_from_slice(&buff_1[write_counter..size]); | ||
add_assign(ret, &buff_1[..write_counter]); | ||
} else { | ||
add_assign(ret, &buff_1[..size]); | ||
} | ||
} | ||
} | ||
|
||
#[inline] | ||
pub(crate) fn exit_karatsuba(size: usize) -> bool { | ||
size < KARATSUBA_THRESHOLD | ||
} | ||
|
||
/// multiplies big and small, puts the result in ret. | ||
/// | ||
/// we assume x and y have the same size | ||
/// ret doesn't have to be filled with zeros | ||
fn symetric_karatsuba<T: Digit>(ret: &mut [T], x: &[T], y: &[T], buff: &mut [T]) { | ||
// Early exit | ||
if size < THRESHOLD { | ||
schoolbook_add_assign_mul(ret, rhs, lhs); | ||
if exit_karatsuba(x.len()) { | ||
schoolbook_mul(ret, x, y); | ||
return; | ||
} | ||
|
||
let (x0, x1) = rhs.split_at(half_size); | ||
let (y0, y1) = lhs.split_at(half_size); | ||
|
||
// Compute (x0+x1) and (y0+y1), using ret as a buffer, | ||
// but specifically handle their last bit | ||
let (x_temp, y_temp) = ret[..size].split_at_mut(half_size); | ||
x_temp.copy_from_slice(x0); | ||
y_temp.copy_from_slice(y0); | ||
let x_carry = add_assign(x_temp, x1); | ||
let y_carry = add_assign(y_temp, y1); | ||
|
||
// compute z1 in a separate buffer | ||
// but specifically handle its last bit | ||
let (z1, new_buff) = buff.split_at_mut(size); | ||
let mut z1_last_bit = T::ZERO; | ||
_karatsuba::<THRESHOLD, _>(&mut z1[..size], x_temp, y_temp, new_buff); | ||
let size = x.len(); | ||
let half_size = (size >> 1) + (size % 2); | ||
let small_half_size = size >> 1; | ||
let size = half_size << 1; | ||
|
||
debug_assert_eq!(x.len(), y.len()); | ||
debug_assert_eq!(ret.len(), x.len() + y.len()); | ||
debug_assert!(buff.len() >= 2 * size); | ||
|
||
let (buff, sub_buff) = buff.split_at_mut(size); | ||
let (x0, x1) = x.split_at(half_size); | ||
let (y0, y1) = y.split_at(half_size); | ||
|
||
// Compute x0 + x1 and y0 + y1 in buff | ||
let (x_cross, y_cross) = buff.split_at_mut(half_size); | ||
x_cross.copy_from_slice(x0); | ||
y_cross.copy_from_slice(y0); | ||
let x_carry = add_assign(x_cross, x1); | ||
let y_carry = add_assign(y_cross, y1); | ||
|
||
// Compute z1 in ret | ||
let z1 = &mut ret[half_size..half_size + size + 2]; | ||
symetric_karatsuba(&mut z1[..size], x_cross, y_cross, sub_buff); | ||
z1[size] = T::from(x_carry && y_carry); | ||
if x_carry { | ||
z1_last_bit += T::from(add_assign(&mut z1[half_size..], &y_temp)); | ||
add_assign(&mut z1[half_size..], y_cross); | ||
} | ||
if y_carry { | ||
z1_last_bit += T::from(add_assign(&mut z1[half_size..], &x_temp)); | ||
} | ||
z1_last_bit += T::from(x_carry && y_carry); | ||
|
||
// z0 and z2 | ||
ret[..size].fill(T::ZERO); | ||
new_buff.fill(T::ZERO); | ||
_karatsuba::<THRESHOLD, _>(&mut ret[..size], x0, y0, new_buff); | ||
new_buff.fill(T::ZERO); | ||
_karatsuba::<THRESHOLD, _>(&mut ret[size..], x1, y1, new_buff); | ||
|
||
// subtract z0 and z2 from z1 | ||
if sub_assign(z1, &ret[..size]) { | ||
z1_last_bit -= T::ONE; | ||
} | ||
if sub_assign(z1, &ret[size..size * 2]) { | ||
z1_last_bit -= T::ONE; | ||
add_assign(&mut z1[half_size..], x_cross); | ||
} | ||
|
||
// add z1 | ||
add_assign(&mut ret[half_size..], z1); | ||
add_assign(&mut ret[half_size + size..], &[z1_last_bit]); | ||
// Compute z2 in buff | ||
let z2 = &mut buff[..2 * small_half_size]; | ||
symetric_karatsuba(z2, x1, y1, sub_buff); | ||
ret[half_size + size + 1..].copy_from_slice(&z2[half_size + 1..]); | ||
add_assign(&mut ret[size..], &z2[..half_size + 1]); | ||
sub_assign(&mut ret[half_size..], &z2); | ||
|
||
// Compute z0 in buff | ||
let z0 = &mut buff[..size]; | ||
symetric_karatsuba(z0, x0, y0, sub_buff); | ||
ret[..half_size].copy_from_slice(&z0[..half_size]); | ||
add_assign(&mut ret[half_size..], &z0[half_size..]); | ||
sub_assign(&mut ret[half_size..], &z0); | ||
} |
Oops, something went wrong.