Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chunked-msm V3 #562

Merged
merged 5 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/docs/icicle/primitives/msm.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct MSMConfig {
int c;
int bitsize;
int batch_size;
bool are_bases_shared;
bool are_points_shared_in_batch;
bool are_scalars_on_device;
bool are_scalars_montgomery_form;
bool are_points_on_device;
Expand All @@ -57,7 +57,7 @@ You can obtain a default `MSMConfig` using:
0, // c
0, // bitsize
1, // batch_size
true, // are_bases_shared
true, // are_points_shared_in_batch
false, // are_scalars_on_device
false, // are_scalars_montgomery_form
false, // are_points_on_device
Expand Down Expand Up @@ -87,7 +87,7 @@ The API is template and can work with all ICICLE curves (if corresponding lib is

The MSM supports batch mode, i.e. running multiple MSMs in parallel. As long as enough memory is available, it is always better to use batch mode than to run single MSMs serially. We support running a batch of MSMs that share the same points as well as a batch of MSMs that use different points.

Config fields `are_bases_shared` and `batch_size` are used to configure msm for batch mode.
Config fields `are_points_shared_in_batch` and `batch_size` are used to configure msm for batch mode.

### G2 MSM

Expand Down
4 changes: 2 additions & 2 deletions docs/docs/icicle/rust-bindings/msm.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct MSMConfig {
pub c: i32,
pub bitsize: i32,
batch_size: i32,
are_bases_shared: bool,
are_points_shared_in_batch: bool,
are_scalars_on_device: bool,
pub are_scalars_montgomery_form: bool,
are_points_on_device: bool,
Expand Down Expand Up @@ -86,7 +86,7 @@ fn main() {

## Batched msm

For batch msm, simply allocate the results array with size corresponding to batch size and set the `are_bases_shared` flag in config struct.
For batch MSM, simply allocate the results array with a size corresponding to the batch size and set the `are_points_shared_in_batch` flag in the config struct.

## Precomputation

Expand Down
2 changes: 1 addition & 1 deletion icicle/backend/cpu/src/curve/cpu_msm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ eIcicleError cpu_msm(

for (int i = 0; i < config.batch_size; i++) {
int batch_start_idx = msm_size * i;
int bases_start_idx = config.are_bases_shared ? 0 : batch_start_idx;
int bases_start_idx = config.are_points_shared_in_batch ? 0 : batch_start_idx;
msm->run_msm(&scalars[batch_start_idx], &bases[bases_start_idx], msm_size, i, &results[i]);
}
delete msm;
Expand Down
62 changes: 47 additions & 15 deletions icicle/include/icicle/api/babybear.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,37 @@ extern "C" void babybear_generate_scalars(babybear::scalar_t* scalars, int size)
extern "C" void babybear_scalar_convert_montgomery(
const babybear::scalar_t* input, uint64_t size, bool is_into, const VecOpsConfig& config, babybear::scalar_t* output);

extern "C" eIcicleError babybear_ntt_init_domain(
babybear::scalar_t* primitive_root, const NTTInitDomainConfig& config);
extern "C" eIcicleError babybear_ntt_init_domain(babybear::scalar_t* primitive_root, const NTTInitDomainConfig& config);

extern "C" eIcicleError babybear_ntt(
const babybear::scalar_t* input, int size, NTTDir dir, const NTTConfig<babybear::scalar_t>& config, babybear::scalar_t* output);
const babybear::scalar_t* input,
int size,
NTTDir dir,
const NTTConfig<babybear::scalar_t>& config,
babybear::scalar_t* output);

extern "C" eIcicleError babybear_ntt_release_domain();

extern "C" eIcicleError babybear_extension_vector_mul(
const babybear::extension_t* vec_a, const babybear::extension_t* vec_b, uint64_t n, const VecOpsConfig& config, babybear::extension_t* result);
const babybear::extension_t* vec_a,
const babybear::extension_t* vec_b,
uint64_t n,
const VecOpsConfig& config,
babybear::extension_t* result);

extern "C" eIcicleError babybear_extension_vector_add(
const babybear::extension_t* vec_a, const babybear::extension_t* vec_b, uint64_t n, const VecOpsConfig& config, babybear::extension_t* result);
const babybear::extension_t* vec_a,
const babybear::extension_t* vec_b,
uint64_t n,
const VecOpsConfig& config,
babybear::extension_t* result);

extern "C" eIcicleError babybear_extension_vector_sub(
const babybear::extension_t* vec_a, const babybear::extension_t* vec_b, uint64_t n, const VecOpsConfig& config, babybear::extension_t* result);
const babybear::extension_t* vec_a,
const babybear::extension_t* vec_b,
uint64_t n,
const VecOpsConfig& config,
babybear::extension_t* result);

extern "C" eIcicleError babybear_extension_matrix_transpose(
const babybear::extension_t* input,
Expand All @@ -41,20 +56,35 @@ extern "C" eIcicleError babybear_extension_matrix_transpose(
extern "C" eIcicleError babybear_extension_bit_reverse(
const babybear::extension_t* input, uint64_t n, const VecOpsConfig& config, babybear::extension_t* output);


extern "C" void babybear_extension_generate_scalars(babybear::extension_t* scalars, int size);

extern "C" eIcicleError babybear_extension_scalar_convert_montgomery(
const babybear::extension_t* input, uint64_t size, bool is_into, const VecOpsConfig& config, babybear::extension_t* output);
extern "C" eIcicleError babybear_extension_scalar_convert_montgomery(
const babybear::extension_t* input,
uint64_t size,
bool is_into,
const VecOpsConfig& config,
babybear::extension_t* output);

extern "C" eIcicleError babybear_vector_mul(
const babybear::scalar_t* vec_a, const babybear::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config, babybear::scalar_t* result);
const babybear::scalar_t* vec_a,
const babybear::scalar_t* vec_b,
uint64_t n,
const VecOpsConfig& config,
babybear::scalar_t* result);

extern "C" eIcicleError babybear_vector_add(
const babybear::scalar_t* vec_a, const babybear::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config, babybear::scalar_t* result);
const babybear::scalar_t* vec_a,
const babybear::scalar_t* vec_b,
uint64_t n,
const VecOpsConfig& config,
babybear::scalar_t* result);

extern "C" eIcicleError babybear_vector_sub(
const babybear::scalar_t* vec_a, const babybear::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config, babybear::scalar_t* result);
const babybear::scalar_t* vec_a,
const babybear::scalar_t* vec_b,
uint64_t n,
const VecOpsConfig& config,
babybear::scalar_t* result);

extern "C" eIcicleError babybear_matrix_transpose(
const babybear::scalar_t* input,
Expand All @@ -66,7 +96,9 @@ extern "C" eIcicleError babybear_matrix_transpose(
extern "C" eIcicleError babybear_bit_reverse(
const babybear::scalar_t* input, uint64_t n, const VecOpsConfig& config, babybear::scalar_t* output);


extern "C" eIcicleError babybear_extension_ntt(
const babybear::extension_t* input, int size, NTTDir dir, const NTTConfig<babybear::scalar_t>& config, babybear::extension_t* output);

const babybear::extension_t* input,
int size,
NTTDir dir,
const NTTConfig<babybear::scalar_t>& config,
babybear::extension_t* output);
82 changes: 59 additions & 23 deletions icicle/include/icicle/api/bls12_377.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,24 @@
#include "icicle/vec_ops.h"

extern "C" eIcicleError bls12_377_precompute_msm_bases(
const bls12_377::affine_t* bases,
int nof_bases,
const MSMConfig& config,
bls12_377::affine_t* output_bases);
const bls12_377::affine_t* bases, int nof_bases, const MSMConfig& config, bls12_377::affine_t* output_bases);

extern "C" eIcicleError bls12_377_msm(
const bls12_377::scalar_t* scalars, const bls12_377::affine_t* points, int msm_size, const MSMConfig& config, bls12_377::projective_t* out);
const bls12_377::scalar_t* scalars,
const bls12_377::affine_t* points,
int msm_size,
const MSMConfig& config,
bls12_377::projective_t* out);

extern "C" eIcicleError bls12_377_g2_precompute_msm_bases(
const bls12_377::g2_affine_t* bases,
int nof_bases,
const MSMConfig& config,
bls12_377::g2_affine_t* output_bases);
const bls12_377::g2_affine_t* bases, int nof_bases, const MSMConfig& config, bls12_377::g2_affine_t* output_bases);

extern "C" eIcicleError bls12_377_g2_msm(
const bls12_377::scalar_t* scalars, const bls12_377::g2_affine_t* points, int msm_size, const MSMConfig& config, bls12_377::g2_projective_t* out);
const bls12_377::scalar_t* scalars,
const bls12_377::g2_affine_t* points,
int msm_size,
const MSMConfig& config,
bls12_377::g2_projective_t* out);

extern "C" bool bls12_377_eq(bls12_377::projective_t* point1, bls12_377::projective_t* point2);

Expand All @@ -40,7 +42,11 @@ extern "C" eIcicleError bls12_377_affine_convert_montgomery(
const bls12_377::affine_t* input, size_t n, bool is_into, const VecOpsConfig& config, bls12_377::affine_t* output);

extern "C" eIcicleError bls12_377_projective_convert_montgomery(
const bls12_377::projective_t* input, size_t n, bool is_into, const VecOpsConfig& config, bls12_377::projective_t* output);
const bls12_377::projective_t* input,
size_t n,
bool is_into,
const VecOpsConfig& config,
bls12_377::projective_t* output);

extern "C" bool bls12_377_g2_eq(bls12_377::g2_projective_t* point1, bls12_377::g2_projective_t* point2);

Expand All @@ -51,35 +57,67 @@ extern "C" void bls12_377_g2_generate_projective_points(bls12_377::g2_projective
extern "C" void bls12_377_g2_generate_affine_points(bls12_377::g2_affine_t* points, int size);

extern "C" eIcicleError bls12_377_g2_affine_convert_montgomery(
const bls12_377::g2_affine_t* input, size_t n, bool is_into, const VecOpsConfig& config, bls12_377::g2_affine_t* output);
const bls12_377::g2_affine_t* input,
size_t n,
bool is_into,
const VecOpsConfig& config,
bls12_377::g2_affine_t* output);

extern "C" eIcicleError bls12_377_g2_projective_convert_montgomery(
const bls12_377::g2_projective_t* input, size_t n, bool is_into, const VecOpsConfig& config, bls12_377::g2_projective_t* output);
const bls12_377::g2_projective_t* input,
size_t n,
bool is_into,
const VecOpsConfig& config,
bls12_377::g2_projective_t* output);

extern "C" eIcicleError bls12_377_ecntt(
const bls12_377::projective_t* input, int size, NTTDir dir, const NTTConfig<bls12_377::scalar_t>& config, bls12_377::projective_t* output);
const bls12_377::projective_t* input,
int size,
NTTDir dir,
const NTTConfig<bls12_377::scalar_t>& config,
bls12_377::projective_t* output);

extern "C" void bls12_377_generate_scalars(bls12_377::scalar_t* scalars, int size);

extern "C" void bls12_377_scalar_convert_montgomery(
const bls12_377::scalar_t* input, uint64_t size, bool is_into, const VecOpsConfig& config, bls12_377::scalar_t* output);
const bls12_377::scalar_t* input,
uint64_t size,
bool is_into,
const VecOpsConfig& config,
bls12_377::scalar_t* output);

extern "C" eIcicleError bls12_377_ntt_init_domain(
bls12_377::scalar_t* primitive_root, const NTTInitDomainConfig& config);
extern "C" eIcicleError
bls12_377_ntt_init_domain(bls12_377::scalar_t* primitive_root, const NTTInitDomainConfig& config);

extern "C" eIcicleError bls12_377_ntt(
const bls12_377::scalar_t* input, int size, NTTDir dir, const NTTConfig<bls12_377::scalar_t>& config, bls12_377::scalar_t* output);
const bls12_377::scalar_t* input,
int size,
NTTDir dir,
const NTTConfig<bls12_377::scalar_t>& config,
bls12_377::scalar_t* output);

extern "C" eIcicleError bls12_377_ntt_release_domain();

extern "C" eIcicleError bls12_377_vector_mul(
const bls12_377::scalar_t* vec_a, const bls12_377::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config, bls12_377::scalar_t* result);
const bls12_377::scalar_t* vec_a,
const bls12_377::scalar_t* vec_b,
uint64_t n,
const VecOpsConfig& config,
bls12_377::scalar_t* result);

extern "C" eIcicleError bls12_377_vector_add(
const bls12_377::scalar_t* vec_a, const bls12_377::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config, bls12_377::scalar_t* result);
const bls12_377::scalar_t* vec_a,
const bls12_377::scalar_t* vec_b,
uint64_t n,
const VecOpsConfig& config,
bls12_377::scalar_t* result);

extern "C" eIcicleError bls12_377_vector_sub(
const bls12_377::scalar_t* vec_a, const bls12_377::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config, bls12_377::scalar_t* result);
const bls12_377::scalar_t* vec_a,
const bls12_377::scalar_t* vec_b,
uint64_t n,
const VecOpsConfig& config,
bls12_377::scalar_t* result);

extern "C" eIcicleError bls12_377_matrix_transpose(
const bls12_377::scalar_t* input,
Expand All @@ -90,5 +128,3 @@ extern "C" eIcicleError bls12_377_matrix_transpose(

extern "C" eIcicleError bls12_377_bit_reverse(
const bls12_377::scalar_t* input, uint64_t n, const VecOpsConfig& config, bls12_377::scalar_t* output);


Loading
Loading