From 37539377d4b1351dc1128aa525cd7f3dfc9056b5 Mon Sep 17 00:00:00 2001
From: Yuval Shekel
Date: Tue, 6 Feb 2024 16:06:08 +0000
Subject: [PATCH] Mixed-radix NTT coset support

---
 icicle/appUtils/ntt/kernel_ntt.cu         | 53 ++++++++++++++++++++---
 icicle/appUtils/ntt/ntt.cu                |  9 ++--
 icicle/appUtils/ntt/ntt_impl.cuh          |  2 +
 icicle/appUtils/ntt/tests/verification.cu | 12 +++--
 4 files changed, 61 insertions(+), 15 deletions(-)

diff --git a/icicle/appUtils/ntt/kernel_ntt.cu b/icicle/appUtils/ntt/kernel_ntt.cu
index b261a3f9f..2a4985e4a 100644
--- a/icicle/appUtils/ntt/kernel_ntt.cu
+++ b/icicle/appUtils/ntt/kernel_ntt.cu
@@ -32,7 +32,7 @@ namespace ntt {
   // Note: the following reorder kernels are fused with normalization for INTT
   template <typename E, typename S>
   static __global__ void
-  reorder_digits_inplace_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
+  reorder_digits_inplace_and_normalize_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
   {
     // launch N threads (per batch element)
     // each thread starts from one index and calculates the corresponding group
@@ -66,8 +66,8 @@ namespace ntt {
   }
 
   template <typename E, typename S>
-  __launch_bounds__(64) __global__
-    void reorder_digits_kernel(E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
+  __launch_bounds__(64) __global__ void reorder_digits_and_normalize_kernel(
+    E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
   {
     uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
     uint32_t rd = tid;
@@ -75,6 +75,26 @@ namespace ntt {
     arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
   }
 
+  template <typename E, typename S>
+  static __global__ void BatchMulKernelDigReverse(
+    E* in_vec,
+    int n_elements,
+    int batch_size,
+    S* scalar_vec,
+    int step,
+    int n_scalars,
+    int logn,
+    bool digit_rev,
+    bool dit,
+    E* out_vec)
+  {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    if (tid >= n_elements * batch_size) return;
+    int64_t scalar_id = tid % n_elements;
+    if (digit_rev) scalar_id = dig_rev(tid, logn, dit);
+    out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
+  }
+
   template <typename E, typename S>
   __launch_bounds__(64) __global__ void ntt64(
     E* in,
@@ -604,6 +624,8 @@
     int batch_size,
     bool is_inverse,
     Ordering ordering,
+    S* arbitrary_coset,
+    int coset_gen_index,
     cudaStream_t cuda_stream)
   {
     CHK_INIT_IF_RETURN();
@@ -624,15 +646,26 @@
     const bool reverse_input = ordering == Ordering::kNN;
     const bool is_dit = ordering == Ordering::kNN || ordering == Ordering::kRN;
     bool is_normalize = is_inverse;
+    const bool is_on_coset = (coset_gen_index != 0) || arbitrary_coset;
+    const int n_twiddles = 1 << max_logn;
+
+    // TODO: fuse BatchMulKernelDigReverse with input reorder (and normalize)?
+    if (is_on_coset && !is_inverse) {
+      BatchMulKernelDigReverse<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+        d_input, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles,
+        arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output);
+
+      d_input = d_output;
+    }
 
     if (reverse_input) {
-      // Note: fusing reorder with normalize for INTT
+      // Note: fused reorder and normalize (for INTT)
      const bool is_reverse_in_place = (d_input == d_output);
      if (is_reverse_in_place) {
-        reorder_digits_inplace_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+        reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
           d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
       } else {
-        reorder_digits_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+        reorder_digits_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
           d_input, d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
       }
       is_normalize = false;
@@ -643,6 +676,12 @@
       d_output, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size,
       is_inverse, is_normalize, is_dit, cuda_stream));
 
+    if (is_on_coset && is_inverse) {
+      BatchMulKernelDigReverse<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+        d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles,
+        arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output);
+    }
+
     return CHK_LAST();
   }
 
@@ -666,6 +705,8 @@
     int batch_size,
     bool is_inverse,
     Ordering ordering,
+    curve_config::scalar_t* arbitrary_coset,
+    int coset_gen_index,
     cudaStream_t cuda_stream);
 
 } // namespace ntt
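
Note (editor, not part of the patch): the twiddle indexing is the subtle part of BatchMulKernelDigReverse, so here is a minimal host-side sketch of the arithmetic; all names are illustrative, and it assumes only that the table holds consecutive powers of the domain generator w. For a coset generated by w^k (step = coset_gen_index = k), element i is scaled by the twiddle at offset (i*k) mod n from external_twiddles. The inverse launch passes the base pointer external_twiddles + n_twiddles and step = -k, so the C++ remainder, which lies in (-n, 0] here, selects the entry at exponent n - (i*k mod n), i.e. w^(-i*k), without an extra branch.

    #include <cassert>
    #include <cstdint>

    // Mirrors `(scalar_id * step) % n_scalars` in BatchMulKernelDigReverse:
    // the offset into the twiddle table, relative to the base pointer handed to the kernel.
    int64_t twiddle_offset(int64_t scalar_id, int step, int n) { return (scalar_id * step) % n; }

    int main()
    {
      const int n = 1 << 4; // n_twiddles
      const int k = 3;      // coset_gen_index
      for (int64_t i = 0; i < n; ++i) {
        // forward pass: base = external_twiddles, offset selects w^(i*k mod n)
        assert(twiddle_offset(i, k, n) == (i * k) % n);

        // inverse pass: base = external_twiddles + n, step = -k;
        // the remainder is in (-n, 0], so base + offset holds exponent n + offset == -i*k (mod n)
        int64_t off = twiddle_offset(i, -k, n);
        assert(off <= 0 && off > -n);
        assert((n + off) % n == ((-i * k) % n + n) % n);
      }
      return 0;
    }
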
diff --git a/icicle/appUtils/ntt/ntt.cu b/icicle/appUtils/ntt/ntt.cu
index 781cfbd2d..755165b56 100644
--- a/icicle/appUtils/ntt/ntt.cu
+++ b/icicle/appUtils/ntt/ntt.cu
@@ -503,9 +503,8 @@ namespace ntt {
 
     // (heuristic) cutoff point where mixed-radix is faster than radix-2
     const bool is_small_ntt = (logn < 16) && ((size_t)size * batch_size < (1 << 20));
-    const bool is_on_coset = (coset_index != 0) || coset; // coset not supported by mixed-radix algorithm yet
-    const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation
-    const bool is_radix2_algorithm = config.is_force_radix2 || is_small_ntt || is_on_coset || !is_NN;
+    const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation
+    const bool is_radix2_algorithm = config.is_force_radix2 || is_small_ntt || !is_NN;
 
     if (is_radix2_algorithm) {
       bool ct_butterfly = true;
@@ -529,16 +528,16 @@
         reverse_input ? d_output : d_input, size, Domain::twiddles, Domain::max_size, batch_size, logn,
         dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output));
-      if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
 
     } else { // mixed-radix algorithm
       CHK_IF_RETURN(ntt::mixed_radix_ntt(
         d_input, d_output, Domain::twiddles, Domain::internal_twiddles, Domain::basic_twiddles, size,
-        Domain::max_log_size, batch_size, dir == NTTDir::kInverse, config.ordering, stream));
+        Domain::max_log_size, batch_size, dir == NTTDir::kInverse, config.ordering, coset, coset_index, stream));
     }
 
     if (!are_outputs_on_device)
       CHK_IF_RETURN(cudaMemcpyAsync(output, d_output, input_size_bytes, cudaMemcpyDeviceToHost, stream));
 
+    if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
     if (!are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_input, stream));
     if (!are_outputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_output, stream));
     if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
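
Note (editor, not part of the patch): the dispatcher change above works because coset evaluation reduces to a plain NTT. Evaluating a(X) on the coset {g*w^i} equals evaluating the elementwise-scaled coefficients a_j*g^j on the plain domain {w^i}, which is what the forward BatchMulKernelDigReverse step computes before the transform. A small CPU reference over the toy field F_257, illustrative only; the library's field types and transforms are not used:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    constexpr int64_t P = 257; // small NTT-friendly prime: 2^8 + 1

    int64_t modpow(int64_t b, int64_t e)
    {
      int64_t r = 1;
      for (b %= P; e > 0; e >>= 1, b = b * b % P)
        if (e & 1) r = r * b % P;
      return r;
    }

    // naive evaluation of polynomial a at each point in xs (Horner's rule)
    std::vector<int64_t> eval_at(const std::vector<int64_t>& a, const std::vector<int64_t>& xs)
    {
      std::vector<int64_t> out(xs.size(), 0);
      for (size_t i = 0; i < xs.size(); ++i)
        for (size_t j = a.size(); j-- > 0;) out[i] = (out[i] * xs[i] + a[j]) % P;
      return out;
    }

    int main()
    {
      const size_t n = 8;
      const int64_t w = modpow(3, (P - 1) / n); // 3 generates F_257^*, so w has order 8
      const int64_t g = 5;                      // coset generator
      std::vector<int64_t> a = {1, 2, 3, 4, 5, 6, 7, 8};

      std::vector<int64_t> domain(n), coset(n);
      for (size_t i = 0; i < n; ++i) {
        domain[i] = modpow(w, i);
        coset[i] = g * domain[i] % P;
      }

      // pre-scale a_j by g^j (the BatchMul step), then evaluate on the plain domain:
      // identical to evaluating the original coefficients directly on the coset
      std::vector<int64_t> scaled(n);
      for (size_t j = 0; j < n; ++j) scaled[j] = a[j] * modpow(g, j) % P;
      assert(eval_at(a, coset) == eval_at(scaled, domain));
      return 0;
    }
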
diff --git a/icicle/appUtils/ntt/ntt_impl.cuh b/icicle/appUtils/ntt/ntt_impl.cuh
index 190f8d8a6..9e82a336e 100644
--- a/icicle/appUtils/ntt/ntt_impl.cuh
+++ b/icicle/appUtils/ntt/ntt_impl.cuh
@@ -28,6 +28,8 @@ namespace ntt {
     int batch_size,
     bool is_inverse,
     Ordering ordering,
+    S* arbitrary_coset,
+    int coset_gen_index,
     cudaStream_t cuda_stream);
 
 } // namespace ntt
diff --git a/icicle/appUtils/ntt/tests/verification.cu b/icicle/appUtils/ntt/tests/verification.cu
index cff741fb5..141ea1cab 100644
--- a/icicle/appUtils/ntt/tests/verification.cu
+++ b/icicle/appUtils/ntt/tests/verification.cu
@@ -35,11 +35,12 @@ int main(int argc, char** argv)
   cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
   float icicle_time, new_time;
 
-  int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19; // assuming the first argument is the log-size
+  int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 16; // assuming the first argument is the log-size
   int NTT_SIZE = 1 << NTT_LOG_SIZE;
   bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
   int INV = (argc > 3) ? atoi(argv[3]) : false;
   int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 32;
+  int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 1;
 
   const ntt::Ordering ordering = ntt::Ordering::kNN;
   const char* ordering_str = ordering == ntt::Ordering::kNN   ? "NN"
                              : ordering == ntt::Ordering::kNR ? "NR"
                              : ordering == ntt::Ordering::kRN ? "RN"
                                                               : "RR";
 
   printf(
-    "running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d, batch_size=%d\n", NTT_LOG_SIZE, ordering_str, INPLACE, INV,
-    BATCH_SIZE);
+    "running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d\n", NTT_LOG_SIZE, ordering_str,
+    INPLACE, INV, BATCH_SIZE, COSET_IDX);
 
   CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
@@ -95,7 +96,10 @@
       cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
   }
 
-  // run ntt
+  for (int coset_idx = 0; coset_idx < COSET_IDX; ++coset_idx) {
+    ntt_config.coset_gen = ntt_config.coset_gen * basic_root;
+  }
+
   auto benchmark = [&](bool is_print, int iterations) -> cudaError_t {
     // NEW
     CHK_IF_RETURN(cudaEventRecord(new_start, ntt_config.ctx.stream));
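
Note (editor, not part of the patch): the INV and COSET_IDX arguments above exercise the inverse branch added in kernel_ntt.cu. A matching CPU round-trip sketch, again over the toy field F_257 with a naive O(n^2) transform: the forward coset NTT pre-multiplies by g^j before the transform, and the inverse coset NTT runs a plain INTT (root w^-1, scaling by 1/n) and then post-multiplies by the inverse coset powers g^-j, recovering the original coefficients.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    constexpr int64_t P = 257; // 2^8 + 1; multiplicative group of order 256

    int64_t modpow(int64_t b, int64_t e)
    {
      int64_t r = 1;
      for (b %= P; e > 0; e >>= 1, b = b * b % P)
        if (e & 1) r = r * b % P;
      return r;
    }
    int64_t inv(int64_t x) { return modpow(x, P - 2); } // Fermat inverse

    // naive O(n^2) DFT: out[i] = sum_j a[j] * root^(i*j)
    std::vector<int64_t> dft(const std::vector<int64_t>& a, int64_t root)
    {
      size_t n = a.size();
      std::vector<int64_t> out(n, 0);
      for (size_t i = 0; i < n; ++i)
        for (size_t j = 0; j < n; ++j)
          out[i] = (out[i] + a[j] * modpow(root, i * j)) % P;
      return out;
    }

    int main()
    {
      const size_t n = 8;
      const int64_t w = modpow(3, (P - 1) / n); // order-8 root of unity
      const int64_t g = 5;                      // coset generator
      std::vector<int64_t> a = {3, 1, 4, 1, 5, 9, 2, 6};

      // forward coset NTT, as in the patch: pre-multiply by g^j, then a plain NTT
      std::vector<int64_t> t(n);
      for (size_t j = 0; j < n; ++j) t[j] = a[j] * modpow(g, j) % P;
      std::vector<int64_t> evals = dft(t, w);

      // inverse coset NTT, as in the patch: plain INTT first (root w^-1, scale 1/n),
      // then post-multiply by the inverse coset powers g^-j
      std::vector<int64_t> back = dft(evals, inv(w));
      for (size_t j = 0; j < n; ++j) back[j] = back[j] * inv(n) % P * inv(modpow(g, j)) % P;
      assert(back == a);
      return 0;
    }
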