From 73dd80b63bfb4193c42f1eed9c75172e052cffe6 Mon Sep 17 00:00:00 2001 From: Yuval Shekel Date: Thu, 1 Aug 2024 15:36:53 +0300 Subject: [PATCH] update ntt api to accept config by const --- icicle_v3/backend/cpu/include/cpu_ntt.h | 9 +++++---- icicle_v3/backend/cpu/src/curve/cpu_ecntt.cpp | 2 +- icicle_v3/backend/cpu/src/field/cpu_ntt.cpp | 3 ++- icicle_v3/include/icicle/api/babybear.h | 10 ++-------- icicle_v3/include/icicle/api/bls12_377.h | 7 ++----- icicle_v3/include/icicle/api/bls12_381.h | 7 ++----- icicle_v3/include/icicle/api/bn254.h | 11 ++++++----- icicle_v3/include/icicle/api/bw6_761.h | 7 ++----- icicle_v3/include/icicle/api/grumpkin.h | 3 --- icicle_v3/include/icicle/api/stark252.h | 5 +---- .../icicle/api/templates/curves/ecntt.template | 2 +- .../include/icicle/api/templates/fields/ntt.template | 2 +- .../icicle/api/templates/fields/ntt_ext.template | 2 +- .../icicle/api/templates/fields/vec_ops.template | 3 --- .../icicle/api/templates/fields/vec_ops_ext.template | 3 --- icicle_v3/include/icicle/backend/ecntt_backend.h | 2 +- icicle_v3/include/icicle/backend/ntt_backend.h | 9 +++++++-- icicle_v3/include/icicle/ntt.h | 2 +- icicle_v3/src/ecntt.cpp | 5 +++-- icicle_v3/src/ntt.cpp | 11 ++++++----- 20 files changed, 44 insertions(+), 61 deletions(-) diff --git a/icicle_v3/backend/cpu/include/cpu_ntt.h b/icicle_v3/backend/cpu/include/cpu_ntt.h index 839936d37..7ab52631b 100644 --- a/icicle_v3/backend/cpu/include/cpu_ntt.h +++ b/icicle_v3/backend/cpu/include/cpu_ntt.h @@ -35,11 +35,11 @@ namespace ntt_cpu { template eIcicleError - cpu_ntt_ref(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig& config, E* output); + cpu_ntt_ref(const Device& device, const E* input, uint64_t size, NTTDir dir, const NTTConfig& config, E* output); template eIcicleError - cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig& config, E* output); + cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, const NTTConfig& config, E* output); const S* get_twiddles() const { return twiddles.get(); } const int get_max_size() const { return max_size; } @@ -233,7 +233,7 @@ namespace ntt_cpu { template eIcicleError - cpu_ntt_ref(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig& config, E* output) + cpu_ntt_ref(const Device& device, const E* input, uint64_t size, NTTDir dir, const NTTConfig& config, E* output) { if (size & (size - 1)) { ICICLE_LOG_ERROR << "Size must be a power of 2. Size = " << size; @@ -350,7 +350,8 @@ namespace ntt_cpu { } template - eIcicleError cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig& config, E* output) + eIcicleError + cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, const NTTConfig& config, E* output) { return cpu_ntt_ref(device, input, size, dir, config, output); } diff --git a/icicle_v3/backend/cpu/src/curve/cpu_ecntt.cpp b/icicle_v3/backend/cpu/src/curve/cpu_ecntt.cpp index 73285ef89..85ac07431 100644 --- a/icicle_v3/backend/cpu/src/curve/cpu_ecntt.cpp +++ b/icicle_v3/backend/cpu/src/curve/cpu_ecntt.cpp @@ -10,7 +10,7 @@ using namespace curve_config; using namespace icicle; template -eIcicleError cpu_ntt(const Device& device, const E* input, int size, NTTDir dir, NTTConfig& config, E* output) +eIcicleError cpu_ntt(const Device& device, const E* input, int size, NTTDir dir, const NTTConfig& config, E* output) { auto err = ntt_cpu::cpu_ntt(device, input, size, dir, config, output); return err; diff --git a/icicle_v3/backend/cpu/src/field/cpu_ntt.cpp b/icicle_v3/backend/cpu/src/field/cpu_ntt.cpp index 7a07196e9..a1a857d4e 100644 --- a/icicle_v3/backend/cpu/src/field/cpu_ntt.cpp +++ b/icicle_v3/backend/cpu/src/field/cpu_ntt.cpp @@ -25,7 +25,8 @@ eIcicleError cpu_get_root_of_unity_from_domain(const Device& device, uint64_t lo } template -eIcicleError cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig& config, E* output) +eIcicleError +cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, const NTTConfig& config, E* output) { auto err = ntt_cpu::cpu_ntt(device, input, size, dir, config, output); return err; diff --git a/icicle_v3/include/icicle/api/babybear.h b/icicle_v3/include/icicle/api/babybear.h index e496df187..1e83d9892 100644 --- a/icicle_v3/include/icicle/api/babybear.h +++ b/icicle_v3/include/icicle/api/babybear.h @@ -24,7 +24,7 @@ extern "C" eIcicleError babybear_ntt( const babybear::scalar_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, babybear::scalar_t* output); extern "C" eIcicleError babybear_ntt_release_domain(); @@ -43,9 +43,6 @@ extern "C" eIcicleError babybearextension_vector_add( const VecOpsConfig& config, babybear::extension_t* result); -// extern "C" eIcicleError babybear_extension_accumulate_cuda( -// const babybear::extension_t* vec_a, const babybear::extension_t* vec_b, uint64_t n, const VecOpsConfig& config); - extern "C" eIcicleError babybear_extension_vector_sub( const babybear::extension_t* vec_a, const babybear::extension_t* vec_b, @@ -67,7 +64,7 @@ extern "C" eIcicleError babybear_extension_ntt( const babybear::extension_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, babybear::extension_t* output); extern "C" void babybear_generate_scalars(babybear::scalar_t* scalars, int size); @@ -89,9 +86,6 @@ extern "C" eIcicleError babybearvector_add( const VecOpsConfig& config, babybear::scalar_t* result); -// extern "C" eIcicleError babybear_accumulate_cuda( -// const babybear::scalar_t* vec_a, const babybear::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config); - extern "C" eIcicleError babybear_vector_sub( const babybear::scalar_t* vec_a, const babybear::scalar_t* vec_b, diff --git a/icicle_v3/include/icicle/api/bls12_377.h b/icicle_v3/include/icicle/api/bls12_377.h index ef9c8fc85..b3298ec11 100644 --- a/icicle_v3/include/icicle/api/bls12_377.h +++ b/icicle_v3/include/icicle/api/bls12_377.h @@ -74,7 +74,7 @@ extern "C" eIcicleError bls12_377_ecntt( const bls12_377::projective_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, bls12_377::projective_t* output); extern "C" eIcicleError @@ -84,7 +84,7 @@ extern "C" eIcicleError bls12_377_ntt( const bls12_377::scalar_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, bls12_377::scalar_t* output); extern "C" eIcicleError bls12_377_ntt_release_domain(); @@ -112,9 +112,6 @@ extern "C" eIcicleError bls12_377vector_add( const VecOpsConfig& config, bls12_377::scalar_t* result); -// extern "C" eIcicleError bls12_377_accumulate_cuda( -// const bls12_377::scalar_t* vec_a, const bls12_377::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config); - extern "C" eIcicleError bls12_377_vector_sub( const bls12_377::scalar_t* vec_a, const bls12_377::scalar_t* vec_b, diff --git a/icicle_v3/include/icicle/api/bls12_381.h b/icicle_v3/include/icicle/api/bls12_381.h index 21d8fb59f..1c3b93241 100644 --- a/icicle_v3/include/icicle/api/bls12_381.h +++ b/icicle_v3/include/icicle/api/bls12_381.h @@ -74,7 +74,7 @@ extern "C" eIcicleError bls12_381_ecntt( const bls12_381::projective_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, bls12_381::projective_t* output); extern "C" eIcicleError @@ -84,7 +84,7 @@ extern "C" eIcicleError bls12_381_ntt( const bls12_381::scalar_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, bls12_381::scalar_t* output); extern "C" eIcicleError bls12_381_ntt_release_domain(); @@ -112,9 +112,6 @@ extern "C" eIcicleError bls12_381vector_add( const VecOpsConfig& config, bls12_381::scalar_t* result); -// extern "C" eIcicleError bls12_381_accumulate_cuda( -// const bls12_381::scalar_t* vec_a, const bls12_381::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config); - extern "C" eIcicleError bls12_381_vector_sub( const bls12_381::scalar_t* vec_a, const bls12_381::scalar_t* vec_b, diff --git a/icicle_v3/include/icicle/api/bn254.h b/icicle_v3/include/icicle/api/bn254.h index b881a8523..d81459825 100644 --- a/icicle_v3/include/icicle/api/bn254.h +++ b/icicle_v3/include/icicle/api/bn254.h @@ -66,13 +66,17 @@ extern "C" eIcicleError bn254_ecntt( const bn254::projective_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, bn254::projective_t* output); extern "C" eIcicleError bn254_ntt_init_domain(bn254::scalar_t* primitive_root, const NTTInitDomainConfig& config); extern "C" eIcicleError bn254_ntt( - const bn254::scalar_t* input, int size, NTTDir dir, NTTConfig& config, bn254::scalar_t* output); + const bn254::scalar_t* input, + int size, + NTTDir dir, + const NTTConfig& config, + bn254::scalar_t* output); extern "C" eIcicleError bn254_ntt_release_domain(); @@ -95,9 +99,6 @@ extern "C" eIcicleError bn254vector_add( const VecOpsConfig& config, bn254::scalar_t* result); -// extern "C" eIcicleError bn254_accumulate_cuda( -// const bn254::scalar_t* vec_a, const bn254::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config); - extern "C" eIcicleError bn254_vector_sub( const bn254::scalar_t* vec_a, const bn254::scalar_t* vec_b, diff --git a/icicle_v3/include/icicle/api/bw6_761.h b/icicle_v3/include/icicle/api/bw6_761.h index 7f5479558..67ecc1cdb 100644 --- a/icicle_v3/include/icicle/api/bw6_761.h +++ b/icicle_v3/include/icicle/api/bw6_761.h @@ -70,7 +70,7 @@ extern "C" eIcicleError bw6_761_ecntt( const bw6_761::projective_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, bw6_761::projective_t* output); extern "C" eIcicleError bw6_761_ntt_init_domain(bw6_761::scalar_t* primitive_root, const NTTInitDomainConfig& config); @@ -79,7 +79,7 @@ extern "C" eIcicleError bw6_761_ntt( const bw6_761::scalar_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, bw6_761::scalar_t* output); extern "C" eIcicleError bw6_761_ntt_release_domain(); @@ -103,9 +103,6 @@ extern "C" eIcicleError bw6_761vector_add( const VecOpsConfig& config, bw6_761::scalar_t* result); -// extern "C" eIcicleError bw6_761_accumulate_cuda( -// const bw6_761::scalar_t* vec_a, const bw6_761::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config); - extern "C" eIcicleError bw6_761_vector_sub( const bw6_761::scalar_t* vec_a, const bw6_761::scalar_t* vec_b, diff --git a/icicle_v3/include/icicle/api/grumpkin.h b/icicle_v3/include/icicle/api/grumpkin.h index a2930772d..4e67ecdf0 100644 --- a/icicle_v3/include/icicle/api/grumpkin.h +++ b/icicle_v3/include/icicle/api/grumpkin.h @@ -56,9 +56,6 @@ extern "C" eIcicleError grumpkinvector_add( const VecOpsConfig& config, grumpkin::scalar_t* result); -// extern "C" eIcicleError grumpkin_accumulate_cuda( -// const grumpkin::scalar_t* vec_a, const grumpkin::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config); - extern "C" eIcicleError grumpkin_vector_sub( const grumpkin::scalar_t* vec_a, const grumpkin::scalar_t* vec_b, diff --git a/icicle_v3/include/icicle/api/stark252.h b/icicle_v3/include/icicle/api/stark252.h index 4cbc82319..39ad5d7f9 100644 --- a/icicle_v3/include/icicle/api/stark252.h +++ b/icicle_v3/include/icicle/api/stark252.h @@ -15,7 +15,7 @@ extern "C" eIcicleError stark252_ntt( const stark252::scalar_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, stark252::scalar_t* output); extern "C" eIcicleError stark252_ntt_release_domain(); @@ -39,9 +39,6 @@ extern "C" eIcicleError stark252vector_add( const VecOpsConfig& config, stark252::scalar_t* result); -// extern "C" eIcicleError stark252_accumulate_cuda( -// const stark252::scalar_t* vec_a, const stark252::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config); - extern "C" eIcicleError stark252_vector_sub( const stark252::scalar_t* vec_a, const stark252::scalar_t* vec_b, diff --git a/icicle_v3/include/icicle/api/templates/curves/ecntt.template b/icicle_v3/include/icicle/api/templates/curves/ecntt.template index 2b81ce4a8..0f8ad12a5 100644 --- a/icicle_v3/include/icicle/api/templates/curves/ecntt.template +++ b/icicle_v3/include/icicle/api/templates/curves/ecntt.template @@ -1,2 +1,2 @@ extern "C" eIcicleError ${CURVE}_ecntt( - const ${CURVE}::projective_t* input, int size, NTTDir dir, NTTConfig<${CURVE}::scalar_t>& config, ${CURVE}::projective_t* output); \ No newline at end of file + const ${CURVE}::projective_t* input, int size, NTTDir dir, const NTTConfig<${CURVE}::scalar_t>& config, ${CURVE}::projective_t* output); \ No newline at end of file diff --git a/icicle_v3/include/icicle/api/templates/fields/ntt.template b/icicle_v3/include/icicle/api/templates/fields/ntt.template index 78b6beb70..eda083d50 100644 --- a/icicle_v3/include/icicle/api/templates/fields/ntt.template +++ b/icicle_v3/include/icicle/api/templates/fields/ntt.template @@ -2,6 +2,6 @@ extern "C" eIcicleError ${FIELD}_ntt_init_domain( ${FIELD}::scalar_t* primitive_root, const NTTInitDomainConfig& config); extern "C" eIcicleError ${FIELD}_ntt( - const ${FIELD}::scalar_t* input, int size, NTTDir dir, NTTConfig<${FIELD}::scalar_t>& config, ${FIELD}::scalar_t* output); + const ${FIELD}::scalar_t* input, int size, NTTDir dir, const NTTConfig<${FIELD}::scalar_t>& config, ${FIELD}::scalar_t* output); extern "C" eIcicleError ${FIELD}_ntt_release_domain(); \ No newline at end of file diff --git a/icicle_v3/include/icicle/api/templates/fields/ntt_ext.template b/icicle_v3/include/icicle/api/templates/fields/ntt_ext.template index e9e6e16be..f6fee3eeb 100644 --- a/icicle_v3/include/icicle/api/templates/fields/ntt_ext.template +++ b/icicle_v3/include/icicle/api/templates/fields/ntt_ext.template @@ -1,2 +1,2 @@ extern "C" eIcicleError ${FIELD}_extension_ntt( - const ${FIELD}::extension_t* input, int size, NTTDir dir, NTTConfig<${FIELD}::scalar_t>& config, ${FIELD}::extension_t* output); \ No newline at end of file + const ${FIELD}::extension_t* input, int size, NTTDir dir, const NTTConfig<${FIELD}::scalar_t>& config, ${FIELD}::extension_t* output); \ No newline at end of file diff --git a/icicle_v3/include/icicle/api/templates/fields/vec_ops.template b/icicle_v3/include/icicle/api/templates/fields/vec_ops.template index f6f3c327c..7bbd1c00a 100644 --- a/icicle_v3/include/icicle/api/templates/fields/vec_ops.template +++ b/icicle_v3/include/icicle/api/templates/fields/vec_ops.template @@ -4,9 +4,6 @@ extern "C" eIcicleError ${FIELD}_vector_mul( extern "C" eIcicleError ${FIELD}vector_add( const ${FIELD}::scalar_t* vec_a, const ${FIELD}::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config, ${FIELD}::scalar_t* result); -// extern "C" eIcicleError ${FIELD}_accumulate_cuda( -// const ${FIELD}::scalar_t* vec_a, const ${FIELD}::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config); - extern "C" eIcicleError ${FIELD}_vector_sub( const ${FIELD}::scalar_t* vec_a, const ${FIELD}::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config, ${FIELD}::scalar_t* result); diff --git a/icicle_v3/include/icicle/api/templates/fields/vec_ops_ext.template b/icicle_v3/include/icicle/api/templates/fields/vec_ops_ext.template index 7032fdc6a..948c4294c 100644 --- a/icicle_v3/include/icicle/api/templates/fields/vec_ops_ext.template +++ b/icicle_v3/include/icicle/api/templates/fields/vec_ops_ext.template @@ -4,9 +4,6 @@ extern "C" eIcicleError ${FIELD}_extension_vector_mul( extern "C" eIcicleError ${FIELD}extension_vector_add( const ${FIELD}::extension_t* vec_a, const ${FIELD}::extension_t* vec_b, uint64_t n, const VecOpsConfig& config, ${FIELD}::extension_t* result); -// extern "C" eIcicleError ${FIELD}_extension_accumulate_cuda( -// const ${FIELD}::extension_t* vec_a, const ${FIELD}::extension_t* vec_b, uint64_t n, const VecOpsConfig& config); - extern "C" eIcicleError ${FIELD}_extension_vector_sub( const ${FIELD}::extension_t* vec_a, const ${FIELD}::extension_t* vec_b, uint64_t n, const VecOpsConfig& config, ${FIELD}::extension_t* result); diff --git a/icicle_v3/include/icicle/backend/ecntt_backend.h b/icicle_v3/include/icicle/backend/ecntt_backend.h index 104691600..a56a128b6 100644 --- a/icicle_v3/include/icicle/backend/ecntt_backend.h +++ b/icicle_v3/include/icicle/backend/ecntt_backend.h @@ -18,7 +18,7 @@ namespace icicle { const projective_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, projective_t* output)>; void register_ecntt(const std::string& deviceType, ECNttFieldImpl impl); diff --git a/icicle_v3/include/icicle/backend/ntt_backend.h b/icicle_v3/include/icicle/backend/ntt_backend.h index 6a123bc73..9d61c56d2 100644 --- a/icicle_v3/include/icicle/backend/ntt_backend.h +++ b/icicle_v3/include/icicle/backend/ntt_backend.h @@ -11,7 +11,12 @@ namespace icicle { /*************************** NTT ***************************/ using NttImpl = std::function& config, scalar_t* output)>; + const Device& device, + const scalar_t* input, + int size, + NTTDir dir, + const NTTConfig& config, + scalar_t* output)>; void register_ntt(const std::string& deviceType, NttImpl impl); @@ -29,7 +34,7 @@ namespace icicle { const extension_t* input, int size, NTTDir dir, - NTTConfig& config, + const NTTConfig& config, extension_t* output)>; void register_extension_ntt(const std::string& deviceType, NttExtFieldImpl impl); diff --git a/icicle_v3/include/icicle/ntt.h b/icicle_v3/include/icicle/ntt.h index d745e934c..0d7a827c3 100644 --- a/icicle_v3/include/icicle/ntt.h +++ b/icicle_v3/include/icicle/ntt.h @@ -122,7 +122,7 @@ namespace icicle { * @return eIcicleError Error code indicating success or failure. */ template - eIcicleError ntt(const E* input, int size, NTTDir dir, NTTConfig& config, E* output); + eIcicleError ntt(const E* input, int size, NTTDir dir, const NTTConfig& config, E* output); /** * @brief Initializes the NTT domain. diff --git a/icicle_v3/src/ecntt.cpp b/icicle_v3/src/ecntt.cpp index dc3e9f49b..5132ba2e7 100644 --- a/icicle_v3/src/ecntt.cpp +++ b/icicle_v3/src/ecntt.cpp @@ -6,13 +6,14 @@ namespace icicle { ICICLE_DISPATCHER_INST(ECNttExtFieldDispatcher, ecntt, ECNttFieldImpl); extern "C" eIcicleError CONCAT_EXPAND(FIELD, ecntt)( - const projective_t* input, int size, NTTDir dir, NTTConfig& config, projective_t* output) + const projective_t* input, int size, NTTDir dir, const NTTConfig& config, projective_t* output) { return ECNttExtFieldDispatcher::execute(input, size, dir, config, output); } template <> - eIcicleError ntt(const projective_t* input, int size, NTTDir dir, NTTConfig& config, projective_t* output) + eIcicleError + ntt(const projective_t* input, int size, NTTDir dir, const NTTConfig& config, projective_t* output) { return CONCAT_EXPAND(FIELD, ecntt)(input, size, dir, config, output); } diff --git a/icicle_v3/src/ntt.cpp b/icicle_v3/src/ntt.cpp index 8e9362855..39c894e76 100644 --- a/icicle_v3/src/ntt.cpp +++ b/icicle_v3/src/ntt.cpp @@ -7,14 +7,14 @@ namespace icicle { /*************************** NTT ***************************/ ICICLE_DISPATCHER_INST(NttDispatcher, ntt, NttImpl); - extern "C" eIcicleError - CONCAT_EXPAND(FIELD, ntt)(const scalar_t* input, int size, NTTDir dir, NTTConfig& config, scalar_t* output) + extern "C" eIcicleError CONCAT_EXPAND(FIELD, ntt)( + const scalar_t* input, int size, NTTDir dir, const NTTConfig& config, scalar_t* output) { return NttDispatcher::execute(input, size, dir, config, output); } template <> - eIcicleError ntt(const scalar_t* input, int size, NTTDir dir, NTTConfig& config, scalar_t* output) + eIcicleError ntt(const scalar_t* input, int size, NTTDir dir, const NTTConfig& config, scalar_t* output) { return CONCAT_EXPAND(FIELD, ntt)(input, size, dir, config, output); } @@ -23,13 +23,14 @@ namespace icicle { ICICLE_DISPATCHER_INST(NttExtFieldDispatcher, extension_ntt, NttExtFieldImpl); extern "C" eIcicleError CONCAT_EXPAND(FIELD, extension_ntt)( - const extension_t* input, int size, NTTDir dir, NTTConfig& config, extension_t* output) + const extension_t* input, int size, NTTDir dir, const NTTConfig& config, extension_t* output) { return NttExtFieldDispatcher::execute(input, size, dir, config, output); } template <> - eIcicleError ntt(const extension_t* input, int size, NTTDir dir, NTTConfig& config, extension_t* output) + eIcicleError + ntt(const extension_t* input, int size, NTTDir dir, const NTTConfig& config, extension_t* output) { return CONCAT_EXPAND(FIELD, extension_ntt)(input, size, dir, config, output); }