Skip to content

Commit

Permalink
update ntt api to accept config by const
Browse files Browse the repository at this point in the history
  • Loading branch information
yshekel committed Aug 1, 2024
1 parent 7b833af commit 73dd80b
Show file tree
Hide file tree
Showing 20 changed files with 44 additions and 61 deletions.
9 changes: 5 additions & 4 deletions icicle_v3/backend/cpu/include/cpu_ntt.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ namespace ntt_cpu {

template <typename U, typename E>
eIcicleError
cpu_ntt_ref(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig<S>& config, E* output);
cpu_ntt_ref(const Device& device, const E* input, uint64_t size, NTTDir dir, const NTTConfig<S>& config, E* output);

template <typename U, typename E>
eIcicleError
cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig<S>& config, E* output);
cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, const NTTConfig<S>& config, E* output);

const S* get_twiddles() const { return twiddles.get(); }
const int get_max_size() const { return max_size; }
Expand Down Expand Up @@ -233,7 +233,7 @@ namespace ntt_cpu {

template <typename S = scalar_t, typename E = scalar_t>
eIcicleError
cpu_ntt_ref(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig<S>& config, E* output)
cpu_ntt_ref(const Device& device, const E* input, uint64_t size, NTTDir dir, const NTTConfig<S>& config, E* output)
{
if (size & (size - 1)) {
ICICLE_LOG_ERROR << "Size must be a power of 2. Size = " << size;
Expand Down Expand Up @@ -350,7 +350,8 @@ namespace ntt_cpu {
}

template <typename S = scalar_t, typename E = scalar_t>
eIcicleError cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig<S>& config, E* output)
eIcicleError
cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, const NTTConfig<S>& config, E* output)
{
return cpu_ntt_ref(device, input, size, dir, config, output);
}
Expand Down
2 changes: 1 addition & 1 deletion icicle_v3/backend/cpu/src/curve/cpu_ecntt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using namespace curve_config;
using namespace icicle;

template <typename S, typename E>
eIcicleError cpu_ntt(const Device& device, const E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
eIcicleError cpu_ntt(const Device& device, const E* input, int size, NTTDir dir, const NTTConfig<S>& config, E* output)
{
auto err = ntt_cpu::cpu_ntt<S, E>(device, input, size, dir, config, output);
return err;
Expand Down
3 changes: 2 additions & 1 deletion icicle_v3/backend/cpu/src/field/cpu_ntt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ eIcicleError cpu_get_root_of_unity_from_domain(const Device& device, uint64_t lo
}

template <typename S, typename E>
eIcicleError cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig<S>& config, E* output)
eIcicleError
cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, const NTTConfig<S>& config, E* output)
{
auto err = ntt_cpu::cpu_ntt<S, E>(device, input, size, dir, config, output);
return err;
Expand Down
10 changes: 2 additions & 8 deletions icicle_v3/include/icicle/api/babybear.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ extern "C" eIcicleError babybear_ntt(
const babybear::scalar_t* input,
int size,
NTTDir dir,
NTTConfig<babybear::scalar_t>& config,
const NTTConfig<babybear::scalar_t>& config,
babybear::scalar_t* output);

extern "C" eIcicleError babybear_ntt_release_domain();
Expand All @@ -43,9 +43,6 @@ extern "C" eIcicleError babybearextension_vector_add(
const VecOpsConfig& config,
babybear::extension_t* result);

// extern "C" eIcicleError babybear_extension_accumulate_cuda(
// const babybear::extension_t* vec_a, const babybear::extension_t* vec_b, uint64_t n, const VecOpsConfig& config);

extern "C" eIcicleError babybear_extension_vector_sub(
const babybear::extension_t* vec_a,
const babybear::extension_t* vec_b,
Expand All @@ -67,7 +64,7 @@ extern "C" eIcicleError babybear_extension_ntt(
const babybear::extension_t* input,
int size,
NTTDir dir,
NTTConfig<babybear::scalar_t>& config,
const NTTConfig<babybear::scalar_t>& config,
babybear::extension_t* output);

extern "C" void babybear_generate_scalars(babybear::scalar_t* scalars, int size);
Expand All @@ -89,9 +86,6 @@ extern "C" eIcicleError babybearvector_add(
const VecOpsConfig& config,
babybear::scalar_t* result);

// extern "C" eIcicleError babybear_accumulate_cuda(
// const babybear::scalar_t* vec_a, const babybear::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config);

extern "C" eIcicleError babybear_vector_sub(
const babybear::scalar_t* vec_a,
const babybear::scalar_t* vec_b,
Expand Down
7 changes: 2 additions & 5 deletions icicle_v3/include/icicle/api/bls12_377.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ extern "C" eIcicleError bls12_377_ecntt(
const bls12_377::projective_t* input,
int size,
NTTDir dir,
NTTConfig<bls12_377::scalar_t>& config,
const NTTConfig<bls12_377::scalar_t>& config,
bls12_377::projective_t* output);

extern "C" eIcicleError
Expand All @@ -84,7 +84,7 @@ extern "C" eIcicleError bls12_377_ntt(
const bls12_377::scalar_t* input,
int size,
NTTDir dir,
NTTConfig<bls12_377::scalar_t>& config,
const NTTConfig<bls12_377::scalar_t>& config,
bls12_377::scalar_t* output);

extern "C" eIcicleError bls12_377_ntt_release_domain();
Expand Down Expand Up @@ -112,9 +112,6 @@ extern "C" eIcicleError bls12_377vector_add(
const VecOpsConfig& config,
bls12_377::scalar_t* result);

// extern "C" eIcicleError bls12_377_accumulate_cuda(
// const bls12_377::scalar_t* vec_a, const bls12_377::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config);

extern "C" eIcicleError bls12_377_vector_sub(
const bls12_377::scalar_t* vec_a,
const bls12_377::scalar_t* vec_b,
Expand Down
7 changes: 2 additions & 5 deletions icicle_v3/include/icicle/api/bls12_381.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ extern "C" eIcicleError bls12_381_ecntt(
const bls12_381::projective_t* input,
int size,
NTTDir dir,
NTTConfig<bls12_381::scalar_t>& config,
const NTTConfig<bls12_381::scalar_t>& config,
bls12_381::projective_t* output);

extern "C" eIcicleError
Expand All @@ -84,7 +84,7 @@ extern "C" eIcicleError bls12_381_ntt(
const bls12_381::scalar_t* input,
int size,
NTTDir dir,
NTTConfig<bls12_381::scalar_t>& config,
const NTTConfig<bls12_381::scalar_t>& config,
bls12_381::scalar_t* output);

extern "C" eIcicleError bls12_381_ntt_release_domain();
Expand Down Expand Up @@ -112,9 +112,6 @@ extern "C" eIcicleError bls12_381vector_add(
const VecOpsConfig& config,
bls12_381::scalar_t* result);

// extern "C" eIcicleError bls12_381_accumulate_cuda(
// const bls12_381::scalar_t* vec_a, const bls12_381::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config);

extern "C" eIcicleError bls12_381_vector_sub(
const bls12_381::scalar_t* vec_a,
const bls12_381::scalar_t* vec_b,
Expand Down
11 changes: 6 additions & 5 deletions icicle_v3/include/icicle/api/bn254.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,17 @@ extern "C" eIcicleError bn254_ecntt(
const bn254::projective_t* input,
int size,
NTTDir dir,
NTTConfig<bn254::scalar_t>& config,
const NTTConfig<bn254::scalar_t>& config,
bn254::projective_t* output);

extern "C" eIcicleError bn254_ntt_init_domain(bn254::scalar_t* primitive_root, const NTTInitDomainConfig& config);

extern "C" eIcicleError bn254_ntt(
const bn254::scalar_t* input, int size, NTTDir dir, NTTConfig<bn254::scalar_t>& config, bn254::scalar_t* output);
const bn254::scalar_t* input,
int size,
NTTDir dir,
const NTTConfig<bn254::scalar_t>& config,
bn254::scalar_t* output);

extern "C" eIcicleError bn254_ntt_release_domain();

Expand All @@ -95,9 +99,6 @@ extern "C" eIcicleError bn254vector_add(
const VecOpsConfig& config,
bn254::scalar_t* result);

// extern "C" eIcicleError bn254_accumulate_cuda(
// const bn254::scalar_t* vec_a, const bn254::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config);

extern "C" eIcicleError bn254_vector_sub(
const bn254::scalar_t* vec_a,
const bn254::scalar_t* vec_b,
Expand Down
7 changes: 2 additions & 5 deletions icicle_v3/include/icicle/api/bw6_761.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ extern "C" eIcicleError bw6_761_ecntt(
const bw6_761::projective_t* input,
int size,
NTTDir dir,
NTTConfig<bw6_761::scalar_t>& config,
const NTTConfig<bw6_761::scalar_t>& config,
bw6_761::projective_t* output);

extern "C" eIcicleError bw6_761_ntt_init_domain(bw6_761::scalar_t* primitive_root, const NTTInitDomainConfig& config);
Expand All @@ -79,7 +79,7 @@ extern "C" eIcicleError bw6_761_ntt(
const bw6_761::scalar_t* input,
int size,
NTTDir dir,
NTTConfig<bw6_761::scalar_t>& config,
const NTTConfig<bw6_761::scalar_t>& config,
bw6_761::scalar_t* output);

extern "C" eIcicleError bw6_761_ntt_release_domain();
Expand All @@ -103,9 +103,6 @@ extern "C" eIcicleError bw6_761vector_add(
const VecOpsConfig& config,
bw6_761::scalar_t* result);

// extern "C" eIcicleError bw6_761_accumulate_cuda(
// const bw6_761::scalar_t* vec_a, const bw6_761::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config);

extern "C" eIcicleError bw6_761_vector_sub(
const bw6_761::scalar_t* vec_a,
const bw6_761::scalar_t* vec_b,
Expand Down
3 changes: 0 additions & 3 deletions icicle_v3/include/icicle/api/grumpkin.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ extern "C" eIcicleError grumpkinvector_add(
const VecOpsConfig& config,
grumpkin::scalar_t* result);

// extern "C" eIcicleError grumpkin_accumulate_cuda(
// const grumpkin::scalar_t* vec_a, const grumpkin::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config);

extern "C" eIcicleError grumpkin_vector_sub(
const grumpkin::scalar_t* vec_a,
const grumpkin::scalar_t* vec_b,
Expand Down
5 changes: 1 addition & 4 deletions icicle_v3/include/icicle/api/stark252.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ extern "C" eIcicleError stark252_ntt(
const stark252::scalar_t* input,
int size,
NTTDir dir,
NTTConfig<stark252::scalar_t>& config,
const NTTConfig<stark252::scalar_t>& config,
stark252::scalar_t* output);

extern "C" eIcicleError stark252_ntt_release_domain();
Expand All @@ -39,9 +39,6 @@ extern "C" eIcicleError stark252vector_add(
const VecOpsConfig& config,
stark252::scalar_t* result);

// extern "C" eIcicleError stark252_accumulate_cuda(
// const stark252::scalar_t* vec_a, const stark252::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config);

extern "C" eIcicleError stark252_vector_sub(
const stark252::scalar_t* vec_a,
const stark252::scalar_t* vec_b,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
extern "C" eIcicleError ${CURVE}_ecntt(
const ${CURVE}::projective_t* input, int size, NTTDir dir, NTTConfig<${CURVE}::scalar_t>& config, ${CURVE}::projective_t* output);
const ${CURVE}::projective_t* input, int size, NTTDir dir, const NTTConfig<${CURVE}::scalar_t>& config, ${CURVE}::projective_t* output);
2 changes: 1 addition & 1 deletion icicle_v3/include/icicle/api/templates/fields/ntt.template
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ extern "C" eIcicleError ${FIELD}_ntt_init_domain(
${FIELD}::scalar_t* primitive_root, const NTTInitDomainConfig& config);

extern "C" eIcicleError ${FIELD}_ntt(
const ${FIELD}::scalar_t* input, int size, NTTDir dir, NTTConfig<${FIELD}::scalar_t>& config, ${FIELD}::scalar_t* output);
const ${FIELD}::scalar_t* input, int size, NTTDir dir, const NTTConfig<${FIELD}::scalar_t>& config, ${FIELD}::scalar_t* output);

extern "C" eIcicleError ${FIELD}_ntt_release_domain();
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
extern "C" eIcicleError ${FIELD}_extension_ntt(
const ${FIELD}::extension_t* input, int size, NTTDir dir, NTTConfig<${FIELD}::scalar_t>& config, ${FIELD}::extension_t* output);
const ${FIELD}::extension_t* input, int size, NTTDir dir, const NTTConfig<${FIELD}::scalar_t>& config, ${FIELD}::extension_t* output);
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ extern "C" eIcicleError ${FIELD}_vector_mul(
extern "C" eIcicleError ${FIELD}vector_add(
const ${FIELD}::scalar_t* vec_a, const ${FIELD}::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config, ${FIELD}::scalar_t* result);

// extern "C" eIcicleError ${FIELD}_accumulate_cuda(
// const ${FIELD}::scalar_t* vec_a, const ${FIELD}::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config);

extern "C" eIcicleError ${FIELD}_vector_sub(
const ${FIELD}::scalar_t* vec_a, const ${FIELD}::scalar_t* vec_b, uint64_t n, const VecOpsConfig& config, ${FIELD}::scalar_t* result);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ extern "C" eIcicleError ${FIELD}_extension_vector_mul(
extern "C" eIcicleError ${FIELD}extension_vector_add(
const ${FIELD}::extension_t* vec_a, const ${FIELD}::extension_t* vec_b, uint64_t n, const VecOpsConfig& config, ${FIELD}::extension_t* result);

// extern "C" eIcicleError ${FIELD}_extension_accumulate_cuda(
// const ${FIELD}::extension_t* vec_a, const ${FIELD}::extension_t* vec_b, uint64_t n, const VecOpsConfig& config);

extern "C" eIcicleError ${FIELD}_extension_vector_sub(
const ${FIELD}::extension_t* vec_a, const ${FIELD}::extension_t* vec_b, uint64_t n, const VecOpsConfig& config, ${FIELD}::extension_t* result);

Expand Down
2 changes: 1 addition & 1 deletion icicle_v3/include/icicle/backend/ecntt_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace icicle {
const projective_t* input,
int size,
NTTDir dir,
NTTConfig<scalar_t>& config,
const NTTConfig<scalar_t>& config,
projective_t* output)>;

void register_ecntt(const std::string& deviceType, ECNttFieldImpl impl);
Expand Down
9 changes: 7 additions & 2 deletions icicle_v3/include/icicle/backend/ntt_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ namespace icicle {

/*************************** NTT ***************************/
using NttImpl = std::function<eIcicleError(
const Device& device, const scalar_t* input, int size, NTTDir dir, NTTConfig<scalar_t>& config, scalar_t* output)>;
const Device& device,
const scalar_t* input,
int size,
NTTDir dir,
const NTTConfig<scalar_t>& config,
scalar_t* output)>;

void register_ntt(const std::string& deviceType, NttImpl impl);

Expand All @@ -29,7 +34,7 @@ namespace icicle {
const extension_t* input,
int size,
NTTDir dir,
NTTConfig<scalar_t>& config,
const NTTConfig<scalar_t>& config,
extension_t* output)>;

void register_extension_ntt(const std::string& deviceType, NttExtFieldImpl impl);
Expand Down
2 changes: 1 addition & 1 deletion icicle_v3/include/icicle/ntt.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ namespace icicle {
* @return eIcicleError Error code indicating success or failure.
*/
template <typename S, typename E>
eIcicleError ntt(const E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output);
eIcicleError ntt(const E* input, int size, NTTDir dir, const NTTConfig<S>& config, E* output);

/**
* @brief Initializes the NTT domain.
Expand Down
5 changes: 3 additions & 2 deletions icicle_v3/src/ecntt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ namespace icicle {
ICICLE_DISPATCHER_INST(ECNttExtFieldDispatcher, ecntt, ECNttFieldImpl);

extern "C" eIcicleError CONCAT_EXPAND(FIELD, ecntt)(
const projective_t* input, int size, NTTDir dir, NTTConfig<scalar_t>& config, projective_t* output)
const projective_t* input, int size, NTTDir dir, const NTTConfig<scalar_t>& config, projective_t* output)
{
return ECNttExtFieldDispatcher::execute(input, size, dir, config, output);
}

template <>
eIcicleError ntt(const projective_t* input, int size, NTTDir dir, NTTConfig<scalar_t>& config, projective_t* output)
eIcicleError
ntt(const projective_t* input, int size, NTTDir dir, const NTTConfig<scalar_t>& config, projective_t* output)
{
return CONCAT_EXPAND(FIELD, ecntt)(input, size, dir, config, output);
}
Expand Down
11 changes: 6 additions & 5 deletions icicle_v3/src/ntt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ namespace icicle {
/*************************** NTT ***************************/
ICICLE_DISPATCHER_INST(NttDispatcher, ntt, NttImpl);

extern "C" eIcicleError
CONCAT_EXPAND(FIELD, ntt)(const scalar_t* input, int size, NTTDir dir, NTTConfig<scalar_t>& config, scalar_t* output)
extern "C" eIcicleError CONCAT_EXPAND(FIELD, ntt)(
const scalar_t* input, int size, NTTDir dir, const NTTConfig<scalar_t>& config, scalar_t* output)
{
return NttDispatcher::execute(input, size, dir, config, output);
}

template <>
eIcicleError ntt(const scalar_t* input, int size, NTTDir dir, NTTConfig<scalar_t>& config, scalar_t* output)
eIcicleError ntt(const scalar_t* input, int size, NTTDir dir, const NTTConfig<scalar_t>& config, scalar_t* output)
{
return CONCAT_EXPAND(FIELD, ntt)(input, size, dir, config, output);
}
Expand All @@ -23,13 +23,14 @@ namespace icicle {
ICICLE_DISPATCHER_INST(NttExtFieldDispatcher, extension_ntt, NttExtFieldImpl);

extern "C" eIcicleError CONCAT_EXPAND(FIELD, extension_ntt)(
const extension_t* input, int size, NTTDir dir, NTTConfig<scalar_t>& config, extension_t* output)
const extension_t* input, int size, NTTDir dir, const NTTConfig<scalar_t>& config, extension_t* output)
{
return NttExtFieldDispatcher::execute(input, size, dir, config, output);
}

template <>
eIcicleError ntt(const extension_t* input, int size, NTTDir dir, NTTConfig<scalar_t>& config, extension_t* output)
eIcicleError
ntt(const extension_t* input, int size, NTTDir dir, const NTTConfig<scalar_t>& config, extension_t* output)
{
return CONCAT_EXPAND(FIELD, extension_ntt)(input, size, dir, config, output);
}
Expand Down

0 comments on commit 73dd80b

Please sign in to comment.