Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update msm wrappers #587

Merged
merged 6 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions wrappers/rust/icicle-core/src/msm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ pub struct MSMConfig {
/// (better) upper bound is known, it should be reflected in this variable. Default value: 0 (set to the bitsize of scalar field).
pub bitsize: i32,

batch_size: i32,
are_points_shared_in_batch: bool,
pub batch_size: i32,
pub are_points_shared_in_batch: bool,
/// MSMs in batch share the bases. If false, expecting #bases==#scalars
are_scalars_on_device: bool,
pub are_scalars_montgomery_form: bool,
Expand Down Expand Up @@ -247,7 +247,7 @@ macro_rules! impl_msm {
unsafe {
$curve_prefix_ident::precompute_bases_ffi(
points.as_ptr(),
points.len() as i32,
points.len() as i32 / config.batch_size,
config,
output_bases.as_mut_ptr(),
)
Expand Down Expand Up @@ -276,9 +276,15 @@ macro_rules! impl_msm_tests {
}

#[test]
fn test_msm_batch() {
fn test_msm_batch_shared() {
initialize();
check_msm_batch::<$curve>()
check_msm_batch_shared::<$curve>()
}

#[test]
fn test_msm_batch_not_shared() {
initialize();
check_msm_batch_not_shared::<$curve>()
}

#[test]
Expand Down
91 changes: 84 additions & 7 deletions wrappers/rust/icicle-core/src/msm/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::curve::{Affine, Curve, Projective};
use crate::msm::{msm, precompute_bases, MSMConfig, CUDA_MSM_LARGE_BUCKET_FACTOR, MSM};
use crate::test_utilities;
use crate::traits::{FieldImpl, GenerateRandom, MontgomeryConvertible};
use icicle_runtime::memory::HostOrDeviceSlice;
use icicle_runtime::{
memory::{DeviceVec, HostSlice},
runtime,
Expand Down Expand Up @@ -94,7 +95,7 @@ where
});
}

pub fn check_msm_batch<C: Curve + MSM<C>>()
pub fn check_msm_batch_shared<C: Curve + MSM<C>>()
where
<C::ScalarField as FieldImpl>::Config: GenerateRandom<C::ScalarField>,
{
Expand Down Expand Up @@ -122,18 +123,94 @@ where
precompute_bases(HostSlice::from_slice(&points), &cfg, &mut precomputed_points_d).unwrap();
for batch_size in batch_sizes {
let scalars = <C::ScalarField as FieldImpl>::Config::generate_random(test_size * batch_size);
// a version of batched msm without using `cfg.points_size`, requires copying bases
let points_cloned: Vec<Affine<C>> = std::iter::repeat(points.clone())
.take(batch_size)
.flatten()
.collect();
let scalars_h = HostSlice::from_slice(&scalars);

let mut msm_results_1 = DeviceVec::<Projective<C>>::device_malloc(batch_size).unwrap();
let mut msm_results_2 = DeviceVec::<Projective<C>>::device_malloc(batch_size).unwrap();
let mut points_d = DeviceVec::<Affine<C>>::device_malloc(test_size).unwrap();
points_d
.copy_from_host_async(HostSlice::from_slice(&points), &stream)
.unwrap();

cfg.precompute_factor = precompute_factor;
msm(scalars_h, &precomputed_points_d[..], &cfg, &mut msm_results_1[..]).unwrap();
cfg.precompute_factor = 1;
msm(scalars_h, &points_d[..], &cfg, &mut msm_results_2[..]).unwrap();

let mut msm_host_result_1 = vec![Projective::<C>::zero(); batch_size];
let mut msm_host_result_2 = vec![Projective::<C>::zero(); batch_size];
msm_results_1
.copy_to_host_async(HostSlice::from_mut_slice(&mut msm_host_result_1), &stream)
.unwrap();
msm_results_2
.copy_to_host_async(HostSlice::from_mut_slice(&mut msm_host_result_2), &stream)
.unwrap();
stream
.synchronize()
.unwrap();

// (2) compute on ref device and compare to both cases (with or w/o precompute)
test_utilities::test_set_ref_device();
let mut msm_ref_result = vec![Projective::<C>::zero(); batch_size];
let mut ref_msm_config = MSMConfig::default();
ref_msm_config.c = 4;
msm(
scalars_h,
HostSlice::from_slice(&points),
&MSMConfig::default(),
HostSlice::from_mut_slice(&mut msm_ref_result),
)
.unwrap();

assert_eq!(msm_host_result_1, msm_ref_result);
assert_eq!(msm_host_result_2, msm_ref_result);
}
}
stream
.destroy()
.unwrap();
}

pub fn check_msm_batch_not_shared<C: Curve + MSM<C>>()
where
<C::ScalarField as FieldImpl>::Config: GenerateRandom<C::ScalarField>,
{
let test_sizes = [1000, 1 << 16];
let batch_sizes = [1, 3, 1 << 4];
let mut stream = IcicleStream::create().unwrap();
let precompute_factor = 8;
let mut cfg = MSMConfig::default();
cfg.stream_handle = *stream;
cfg.is_async = true;
cfg.ext
.set_int(CUDA_MSM_LARGE_BUCKET_FACTOR, 5);
cfg.c = 4;
runtime::warmup(&stream).unwrap();
stream
.synchronize()
.unwrap();
for test_size in test_sizes {
// (1) compute MSM with and w/o precompute on main device
test_utilities::test_set_main_device();
for batch_size in batch_sizes {
cfg.precompute_factor = precompute_factor;
let scalars = <C::ScalarField as FieldImpl>::Config::generate_random(test_size * batch_size);
let scalars_h = HostSlice::from_slice(&scalars);

let points = generate_random_affine_points_with_zeroes::<C>(test_size * batch_size, 10);
println!("points len: {}", points.len());
let mut precomputed_points_d =
DeviceVec::<Affine<C>>::device_malloc(cfg.precompute_factor as usize * test_size * batch_size).unwrap();
cfg.batch_size = batch_size as i32;
cfg.are_points_shared_in_batch = false;
precompute_bases(HostSlice::from_slice(&points), &cfg, &mut precomputed_points_d).unwrap();
println!("precomputed points len: {}", (precomputed_points_d).len());

let mut msm_results_1 = DeviceVec::<Projective<C>>::device_malloc(batch_size).unwrap();
let mut msm_results_2 = DeviceVec::<Projective<C>>::device_malloc(batch_size).unwrap();
let mut points_d = DeviceVec::<Affine<C>>::device_malloc(test_size * batch_size).unwrap();
points_d
.copy_from_host_async(HostSlice::from_slice(&points_cloned), &stream)
.copy_from_host_async(HostSlice::from_slice(&points), &stream)
.unwrap();

cfg.precompute_factor = precompute_factor;
Expand Down
Loading