update msm wrappers
nonam3e committed Aug 26, 2024
1 parent f95e3b5 commit 7fd96b7
Showing 2 changed files with 182 additions and 3 deletions.
18 changes: 15 additions & 3 deletions wrappers/rust/icicle-core/src/msm/mod.rs
@@ -32,8 +32,8 @@ pub struct MSMConfig {
     /// (better) upper bound is known, it should be reflected in this variable. Default value: 0 (set to the bitsize of scalar field).
     pub bitsize: i32,
 
-    batch_size: i32,
-    are_bases_shared: bool,
+    pub batch_size: i32,
+    pub are_bases_shared: bool,
     /// MSMs in batch share the bases. If false, expecting #bases==#scalars
     are_scalars_on_device: bool,
     pub are_scalars_montgomery_form: bool,
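
With `batch_size` and `are_bases_shared` now public, callers can describe a batch layout directly on the config rather than relying on crate-internal defaults. A minimal sketch of how a caller might set them, assuming the `icicle_core::msm` path implied by this file (the concrete values are illustrative, not from this commit):

use icicle_core::msm::MSMConfig;

// Sketch: configure a batch of 4 MSMs where each MSM brings its own bases.
fn batch_config() -> MSMConfig {
    let mut cfg = MSMConfig::default();
    cfg.batch_size = 4; // number of MSMs computed in one call
    cfg.are_bases_shared = false; // if false, the wrapper expects #bases == #scalars
    cfg
}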
@@ -247,7 +247,7 @@ macro_rules! impl_msm {
         unsafe {
             $curve_prefix_ident::precompute_bases_ffi(
                 points.as_ptr(),
-                points.len() as i32,
+                points.len() as i32 / config.batch_size,
                 config,
                 output_bases.as_mut_ptr(),
             )
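
This division makes `precompute_bases` batch-aware: in a non-shared batch, `points` concatenates one base set per batch entry, so the FFI receives the per-MSM base count; with shared bases, `batch_size` stays at its default of 1 and the division is a no-op. A worked example of the arithmetic (the sizes are hypothetical):

fn main() {
    // Hypothetical non-shared batch: 4 MSMs with 1024 bases each.
    let batch_size: i32 = 4;
    let points_len: i32 = 4 * 1024; // `points` holds the 4 base sets back to back
    assert_eq!(points_len / batch_size, 1024); // per-MSM count passed to precompute_bases_ffi
}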
@@ -281,6 +281,18 @@ macro_rules! impl_msm_tests {
                 check_msm_batch::<$curve>()
             }
 
+            #[test]
+            fn test_msm_batch_shared() {
+                initialize();
+                check_msm_batch_shared::<$curve>()
+            }
+
+            #[test]
+            fn test_msm_batch_not_shared() {
+                initialize();
+                check_msm_batch_not_shared::<$curve>()
+            }
+
             #[test]
             fn test_msm_skewed_distributions() {
                 initialize();
167 changes: 167 additions & 0 deletions wrappers/rust/icicle-core/src/msm/tests.rs
@@ -2,6 +2,7 @@ use crate::curve::{Affine, Curve, Projective};
 use crate::msm::{msm, precompute_bases, MSMConfig, CUDA_MSM_LARGE_BUCKET_FACTOR, MSM};
 use crate::test_utilities;
 use crate::traits::{FieldImpl, GenerateRandom, MontgomeryConvertible};
+use icicle_runtime::memory::HostOrDeviceSlice;
 use icicle_runtime::{
     memory::{DeviceVec, HostSlice},
     runtime,
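
The new import appears to exist for the `.len()` calls on `DeviceVec` in the tests below: `len` comes from the `HostOrDeviceSlice` trait, which must be in scope for the method to resolve. A small sketch of the same pattern (the helper function is hypothetical):

use icicle_runtime::memory::{DeviceVec, HostOrDeviceSlice};

// Without the trait in scope, `v.len()` would not resolve on a DeviceVec.
fn device_len<T>(v: &DeviceVec<T>) -> usize {
    v.len()
}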
@@ -178,6 +179,172 @@
         .unwrap();
 }
 
+pub fn check_msm_batch_shared<C: Curve + MSM<C>>()
+where
+    <C::ScalarField as FieldImpl>::Config: GenerateRandom<C::ScalarField>,
+{
+    // let test_sizes = [1000, 1 << 16]; // TODO - use these sizes after implementing fast msm
+    let test_sizes = [100];
+    // let batch_sizes = [1, 3, 1 << 4]; // TODO - use these sizes after implementing fast msm
+    let batch_sizes = [1, 3];
+    let mut stream = IcicleStream::create().unwrap();
+    let precompute_factor = 8;
+    let mut cfg = MSMConfig::default();
+    cfg.stream_handle = *stream;
+    cfg.is_async = true;
+    cfg.ext.set_int(CUDA_MSM_LARGE_BUCKET_FACTOR, 5);
+    cfg.c = 4;
+    runtime::warmup(&stream).unwrap();
+    stream.synchronize().unwrap();
+    for test_size in test_sizes {
+        // (1) compute MSM with and w/o precompute on main device
+        test_utilities::test_set_main_device();
+        cfg.precompute_factor = precompute_factor;
+        let points = generate_random_affine_points_with_zeroes::<C>(test_size, 10);
+        let mut precomputed_points_d =
+            DeviceVec::<Affine<C>>::device_malloc(cfg.precompute_factor as usize * test_size).unwrap();
+        precompute_bases(HostSlice::from_slice(&points), &cfg, &mut precomputed_points_d).unwrap();
+        for batch_size in batch_sizes {
+            let scalars = <C::ScalarField as FieldImpl>::Config::generate_random(test_size * batch_size);
+            // a version of batched msm without using `cfg.are_bases_shared`, requires copying bases:
+            // let points_cloned: Vec<Affine<C>> = std::iter::repeat(points.clone())
+            //     .take(batch_size)
+            //     .flatten()
+            //     .collect();
+            let scalars_h = HostSlice::from_slice(&scalars);
+
+            let mut msm_results_1 = DeviceVec::<Projective<C>>::device_malloc(batch_size).unwrap();
+            let mut msm_results_2 = DeviceVec::<Projective<C>>::device_malloc(batch_size).unwrap();
+            let mut points_d = DeviceVec::<Affine<C>>::device_malloc(test_size).unwrap();
+            points_d
+                .copy_from_host_async(HostSlice::from_slice(&points), &stream)
+                .unwrap();
+
+            cfg.precompute_factor = precompute_factor;
+            msm(scalars_h, &precomputed_points_d[..], &cfg, &mut msm_results_1[..]).unwrap();
+            cfg.precompute_factor = 1;
+            msm(scalars_h, &points_d[..], &cfg, &mut msm_results_2[..]).unwrap();
+
+            let mut msm_host_result_1 = vec![Projective::<C>::zero(); batch_size];
+            let mut msm_host_result_2 = vec![Projective::<C>::zero(); batch_size];
+            msm_results_1
+                .copy_to_host_async(HostSlice::from_mut_slice(&mut msm_host_result_1), &stream)
+                .unwrap();
+            msm_results_2
+                .copy_to_host_async(HostSlice::from_mut_slice(&mut msm_host_result_2), &stream)
+                .unwrap();
+            stream.synchronize().unwrap();
+
+            // (2) compute on ref device and compare to both cases (with or w/o precompute)
+            test_utilities::test_set_ref_device();
+            let mut msm_ref_result = vec![Projective::<C>::zero(); batch_size];
+            let mut ref_msm_config = MSMConfig::default();
+            ref_msm_config.c = 4;
+            msm(
+                scalars_h,
+                HostSlice::from_slice(&points),
+                &ref_msm_config,
+                HostSlice::from_mut_slice(&mut msm_ref_result),
+            )
+            .unwrap();
+
+            assert_eq!(msm_host_result_1, msm_ref_result);
+            assert_eq!(msm_host_result_2, msm_ref_result);
+        }
+    }
+    stream.destroy().unwrap();
+}
+
+pub fn check_msm_batch_not_shared<C: Curve + MSM<C>>()
+where
+    <C::ScalarField as FieldImpl>::Config: GenerateRandom<C::ScalarField>,
+{
+    // let test_sizes = [1000, 1 << 16]; // TODO - use these sizes after implementing fast msm
+    let test_sizes = [100];
+    // let batch_sizes = [1, 3, 1 << 4]; // TODO - use these sizes after implementing fast msm
+    let batch_sizes = [3, 5];
+    let mut stream = IcicleStream::create().unwrap();
+    let precompute_factor = 8;
+    let mut cfg = MSMConfig::default();
+    cfg.stream_handle = *stream;
+    cfg.is_async = true;
+    cfg.ext.set_int(CUDA_MSM_LARGE_BUCKET_FACTOR, 5);
+    cfg.c = 4;
+    runtime::warmup(&stream).unwrap();
+    stream.synchronize().unwrap();
+    for test_size in test_sizes {
+        // (1) compute MSM with and w/o precompute on main device
+        test_utilities::test_set_main_device();
+        for batch_size in batch_sizes {
+            cfg.precompute_factor = precompute_factor;
+            let scalars = <C::ScalarField as FieldImpl>::Config::generate_random(test_size * batch_size);
+            let scalars_h = HostSlice::from_slice(&scalars);
+
+            let points = generate_random_affine_points_with_zeroes::<C>(test_size * batch_size, 10);
+            println!("points len: {}", points.len());
+            let mut precomputed_points_d =
+                DeviceVec::<Affine<C>>::device_malloc(cfg.precompute_factor as usize * test_size * batch_size).unwrap();
+            cfg.batch_size = batch_size as i32;
+            cfg.are_bases_shared = false;
+            precompute_bases(HostSlice::from_slice(&points), &cfg, &mut precomputed_points_d).unwrap();
+            println!("precomputed points len: {}", precomputed_points_d.len());
+
+            let mut msm_results_1 = DeviceVec::<Projective<C>>::device_malloc(batch_size).unwrap();
+            let mut msm_results_2 = DeviceVec::<Projective<C>>::device_malloc(batch_size).unwrap();
+            let mut points_d = DeviceVec::<Affine<C>>::device_malloc(test_size * batch_size).unwrap();
+            points_d
+                .copy_from_host_async(HostSlice::from_slice(&points), &stream)
+                .unwrap();
+
+            cfg.precompute_factor = precompute_factor;
+            msm(scalars_h, &precomputed_points_d[..], &cfg, &mut msm_results_1[..]).unwrap();
+            cfg.precompute_factor = 1;
+            msm(scalars_h, &points_d[..], &cfg, &mut msm_results_2[..]).unwrap();
+
+            let mut msm_host_result_1 = vec![Projective::<C>::zero(); batch_size];
+            let mut msm_host_result_2 = vec![Projective::<C>::zero(); batch_size];
+            msm_results_1
+                .copy_to_host_async(HostSlice::from_mut_slice(&mut msm_host_result_1), &stream)
+                .unwrap();
+            msm_results_2
+                .copy_to_host_async(HostSlice::from_mut_slice(&mut msm_host_result_2), &stream)
+                .unwrap();
+            stream.synchronize().unwrap();
+
+            // (2) compute on ref device and compare to both cases (with or w/o precompute)
+            test_utilities::test_set_ref_device();
+            let mut msm_ref_result = vec![Projective::<C>::zero(); batch_size];
+            let mut ref_msm_config = MSMConfig::default();
+            ref_msm_config.c = 4;
+            msm(
+                scalars_h,
+                HostSlice::from_slice(&points),
+                &ref_msm_config,
+                HostSlice::from_mut_slice(&mut msm_ref_result),
+            )
+            .unwrap();
+
+            assert_eq!(msm_host_result_1, msm_ref_result);
+            assert_eq!(msm_host_result_2, msm_ref_result);
+        }
+    }
+    stream.destroy().unwrap();
+}
+
 pub fn check_msm_skewed_distributions<C: Curve + MSM<C>>()
 where
     <C::ScalarField as FieldImpl>::Config: GenerateRandom<C::ScalarField>,
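
Taken together, the two checkers exercise the two batch layouts: check_msm_batch_shared feeds test_size bases that every batch entry reuses (it never sets `batch_size`, relying on the wrapper to work the batch out from the slice sizes), while check_msm_batch_not_shared feeds test_size * batch_size bases and sets `cfg.batch_size` and `cfg.are_bases_shared` before `precompute_bases`, since the new FFI call derives the per-MSM base count from them. A condensed sketch of the two call shapes, assuming the API used in the tests above (the function, parameter names, and slice lengths are illustrative):

use icicle_core::curve::{Affine, Curve, Projective};
use icicle_core::msm::{msm, MSMConfig, MSM};
use icicle_runtime::memory::{HostOrDeviceSlice, HostSlice};

fn both_layouts<C: Curve + MSM<C>>(
    scalars: &HostSlice<C::ScalarField>,     // test_size * batch_size scalars
    shared_bases: &HostSlice<Affine<C>>,     // test_size bases, reused by every entry
    own_bases: &HostSlice<Affine<C>>,        // test_size * batch_size bases
    results: &mut HostSlice<Projective<C>>,  // batch_size results
) {
    // Shared bases: default config, as check_msm_batch_shared does.
    let mut cfg = MSMConfig::default();
    msm(scalars, shared_bases, &cfg, results).unwrap();

    // Per-entry bases: state the batch explicitly, as check_msm_batch_not_shared does.
    cfg.batch_size = results.len() as i32;
    cfg.are_bases_shared = false;
    msm(scalars, own_bases, &cfg, results).unwrap();
}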
