Commit: Mixed-radix ntt coset support
yshekel committed Feb 12, 2024
1 parent c2c7b60, commit 5d5aa8f
Showing 4 changed files with 61 additions and 15 deletions.
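
Background for the change: evaluating a polynomial a(X) on a coset g*H of the evaluation domain H = <w> reduces to a plain NTT, since a(g*w^j) = sum_i (a_i * g^i) * w^(i*j). The forward direction therefore scales element i by g^i before the transform, and the inverse direction scales by g^(-i) after the INTT; these are exactly the two BatchMulKernelDigReverse launches added in kernel_ntt.cu below. A minimal host-side sketch of the identity (C++; the field type F and the ntt/intt helpers are hypothetical stand-ins, not the icicle API):

    #include <vector>

    template <typename F> std::vector<F> ntt(std::vector<F> a);  // assumed provided
    template <typename F> std::vector<F> intt(std::vector<F> a); // assumed provided

    // Evaluate on the coset g*H: scale a_i by g^i, then run a plain NTT.
    template <typename F>
    std::vector<F> coset_ntt(std::vector<F> a, const F& g)
    {
      F pow = F::one();
      for (size_t i = 0; i < a.size(); ++i) {
        a[i] = a[i] * pow; // a_i <- a_i * g^i
        pow = pow * g;
      }
      return ntt(a);
    }

    // Interpolate from evaluations on g*H: plain INTT, then undo the scaling.
    template <typename F>
    std::vector<F> coset_intt(std::vector<F> evals, const F& g)
    {
      std::vector<F> a = intt(evals);
      const F g_inv = F::inverse(g);
      F pow = F::one();
      for (size_t i = 0; i < a.size(); ++i) {
        a[i] = a[i] * pow; // a_i <- a_i * g^(-i)
        pow = pow * g_inv;
      }
      return a;
    }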
icicle/appUtils/ntt/kernel_ntt.cu (47 additions, 6 deletions)

@@ -32,7 +32,7 @@ namespace ntt {
   // Note: the following reorder kernels are fused with normalization for INTT
   template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
   static __global__ void
-  reorder_digits_inplace_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
+  reorder_digits_inplace_and_normalize_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
   {
     // launch N threads (per batch element)
     // each thread starts from one index and calculates the corresponding group
@@ -66,15 +66,35 @@ namespace ntt {
   }
 
   template <typename E, typename S>
-  __launch_bounds__(64) __global__
-  void reorder_digits_kernel(E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
+  __launch_bounds__(64) __global__ void reorder_digits_and_normalize_kernel(
+    E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
   {
     uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
     uint32_t rd = tid;
     uint32_t wr = ((tid >> log_size) << log_size) + dig_rev(tid & ((1 << log_size) - 1), log_size, dit);
     arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
   }
 
+  template <typename E, typename S>
+  static __global__ void BatchMulKernelDigReverse(
+    E* in_vec,
+    int n_elements,
+    int batch_size,
+    S* scalar_vec,
+    int step,
+    int n_scalars,
+    int logn,
+    bool digit_rev,
+    bool dit,
+    E* out_vec)
+  {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    if (tid >= n_elements * batch_size) return;
+    int64_t scalar_id = tid % n_elements;
+    if (digit_rev) scalar_id = dig_rev(tid, logn, dit);
+    out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
+  }
+
   template <typename E, typename S>
   __launch_bounds__(64) __global__ void ntt64(
     E* in,
@@ -604,6 +624,8 @@
     int batch_size,
     bool is_inverse,
     Ordering ordering,
+    S* arbitrary_coset,
+    int coset_gen_index,
     cudaStream_t cuda_stream)
   {
     CHK_INIT_IF_RETURN();
@@ -624,15 +646,26 @@
     const bool reverse_input = ordering == Ordering::kNN;
     const bool is_dit = ordering == Ordering::kNN || ordering == Ordering::kRN;
     bool is_normalize = is_inverse;
+    const bool is_on_coset = (coset_gen_index != 0) || arbitrary_coset;
+    const int n_twiddles = 1 << max_logn;
+
+    // TODO: fuse BatchMulKernelDigReverse with input reorder (and normalize)?
+    if (is_on_coset && !is_inverse) {
+      BatchMulKernelDigReverse<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+        d_input, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles,
+        arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output);
+
+      d_input = d_output;
+    }
 
     if (reverse_input) {
-      // Note: fusing reorder with normalize for INTT
+      // Note: fused reorder and normalize (for INTT)
       const bool is_reverse_in_place = (d_input == d_output);
       if (is_reverse_in_place) {
-        reorder_digits_inplace_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+        reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
           d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
       } else {
-        reorder_digits_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+        reorder_digits_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
           d_input, d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
       }
       is_normalize = false;
@@ -643,6 +676,12 @@
       d_output, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
       is_normalize, is_dit, cuda_stream));
 
+    if (is_on_coset && is_inverse) {
+      BatchMulKernelDigReverse<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+        d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles,
+        arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output);
+    }
+
     return CHK_LAST();
   }

@@ -666,6 +705,8 @@
     int batch_size,
     bool is_inverse,
     Ordering ordering,
+    curve_config::scalar_t* arbitrary_coset,
+    int coset_gen_index,
     cudaStream_t cuda_stream);
 
 } // namespace ntt
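
A note on the twiddle indexing in the two coset launches above. When the generator is a power of the domain's root of unity, g = w^k with k = coset_gen_index, the scalars g^i already live in the twiddle table, so the forward pass reads external_twiddles with step = k, i.e. w^((i*k) mod n). The inverse pass gets external_twiddles + n_twiddles with step = -k: since C++ '%' keeps the sign of the dividend, the offset (i * -k) % n is non-positive and indexes back from the top of the table, yielding w^(n - (i*k mod n)) = g^(-i). For an arbitrary generator the kernel instead receives the precomputed arbitrary_coset table with step = 1. A small host-side model of that address computation (under the assumption, not verified here, that the table stores w^0 .. w^n inclusive, so offset n is valid and equals one):

    #include <cstdint>

    // Index read by BatchMulKernelDigReverse for element i, mirroring the
    // forward (base 0, step k) and inverse (base n, step -k) launch sites.
    int64_t coset_twiddle_index(int64_t i, int64_t k, int64_t n, bool inverse)
    {
      const int64_t step = inverse ? -k : k;
      const int64_t base = inverse ? n : 0;
      // For the inverse pass (i * step) % n lies in (-n, 0], so the result is
      // n - ((i * k) % n), i.e. w^(n - ik mod n) = g^(-i).
      return base + ((i * step) % n);
    }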
icicle/appUtils/ntt/ntt.cu (4 additions, 5 deletions)

@@ -503,9 +503,8 @@
 
     // (heuristic) cutoff point where mixed-radix is faster than radix-2
     const bool is_small_ntt = (logn < 16) && ((size_t)size * batch_size < (1 << 20));
-    const bool is_on_coset = (coset_index != 0) || coset; // coset not supported by mixed-radix algorithm yet
-    const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation
-    const bool is_radix2_algorithm = config.is_force_radix2 || is_small_ntt || is_on_coset || !is_NN;
+    const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation
+    const bool is_radix2_algorithm = config.is_force_radix2 || is_small_ntt || !is_NN;
 
     if (is_radix2_algorithm) {
       bool ct_butterfly = true;
@@ -529,16 +528,16 @@
         reverse_input ? d_output : d_input, size, Domain<S>::twiddles, Domain<S>::max_size, batch_size, logn,
         dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output));
 
-      if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
     } else { // mixed-radix algorithm
       CHK_IF_RETURN(ntt::mixed_radix_ntt(
         d_input, d_output, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles, size,
-        Domain<S>::max_log_size, batch_size, dir == NTTDir::kInverse, config.ordering, stream));
+        Domain<S>::max_log_size, batch_size, dir == NTTDir::kInverse, config.ordering, coset, coset_index, stream));
     }
 
     if (!are_outputs_on_device)
       CHK_IF_RETURN(cudaMemcpyAsync(output, d_output, input_size_bytes, cudaMemcpyDeviceToHost, stream));
 
+    if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
     if (!are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_input, stream));
     if (!are_outputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_output, stream));
     if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
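
Two things happen on the dispatch side: is_on_coset no longer forces the radix-2 path, since mixed_radix_ntt now accepts the coset arguments, and the cudaFreeAsync of the coset table moves below both branches because either algorithm may now consume it. Consistent with the step = 1 read in BatchMulKernelDigReverse, the coset pointer (arbitrary_coset on the kernel side) is presumably a device table of successive generator powers; a sketch of the layout such a table needs (hypothetical helper; S stands for the scalar field, assumed to provide one() and operator*):

    #include <vector>

    // Sketch: table[i] = gen^i, the layout required for step = 1 consumption.
    // The real code would copy this to the device; it is freed with
    // cudaFreeAsync in ntt.cu above.
    template <typename S>
    std::vector<S> build_coset_table(const S& gen, int n)
    {
      std::vector<S> table(n);
      S pow = S::one();
      for (int i = 0; i < n; ++i) {
        table[i] = pow;
        pow = pow * gen;
      }
      return table;
    }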
icicle/appUtils/ntt/ntt_impl.cuh (2 additions, 0 deletions)

@@ -28,6 +28,8 @@
     int batch_size,
     bool is_inverse,
     Ordering ordering,
+    S* arbitrary_coset,
+    int coset_gen_index,
     cudaStream_t cuda_stream);
 
 } // namespace ntt
icicle/appUtils/ntt/tests/verification.cu (8 additions, 4 deletions)

@@ -35,11 +35,12 @@
   cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
   float icicle_time, new_time;
 
-  int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19; // assuming second input is the log-size
+  int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 16; // assuming second input is the log-size
   int NTT_SIZE = 1 << NTT_LOG_SIZE;
   bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
   int INV = (argc > 3) ? atoi(argv[3]) : false;
   int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 32;
+  int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 1;
 
   const ntt::Ordering ordering = ntt::Ordering::kNN;
   const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
@@ -48,8 +49,8 @@
                                                            : "RR";
 
   printf(
-    "running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d, batch_size=%d\n", NTT_LOG_SIZE, ordering_str, INPLACE, INV,
-    BATCH_SIZE);
+    "running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d\n", NTT_LOG_SIZE, ordering_str,
+    INPLACE, INV, BATCH_SIZE, COSET_IDX);
 
   CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
 
@@ -95,7 +96,10 @@
       cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
   }
 
-  // run ntt
+  for (int coset_idx = 0; coset_idx < COSET_IDX; ++coset_idx) {
+    ntt_config.coset_gen = ntt_config.coset_gen * basic_root;
+  }
+
   auto benchmark = [&](bool is_print, int iterations) -> cudaError_t {
     // NEW
     CHK_IF_RETURN(cudaEventRecord(new_start, ntt_config.ctx.stream));
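
With the new fifth argument the test sets coset_gen = basic_root^COSET_IDX by repeated multiplication, so the default COSET_IDX = 1 already exercises a non-trivial coset, while passing 0 leaves coset_gen at the config default (no coset). For example, assuming the test binary is invoked as ./verification.out (name hypothetical), running ./verification.out 16 0 1 32 2 performs an inverse, batch-32, size-2^16 NTT on the coset generated by basic_root^2.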
