Mixed radix NTT coset support #368

Merged · 1 commit · Feb 12, 2024
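Adds coset support to the mixed-radix NTT path: forward NTTs pre-multiply the input by powers of the coset generator, inverse NTTs post-multiply by the inverse powers, and the radix-2 fallback that coset NTTs previously required is removed.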
53 changes: 47 additions & 6 deletions icicle/appUtils/ntt/kernel_ntt.cu
@@ -32,7 +32,7 @@ namespace ntt {
// Note: the following reorder kernels are fused with normalization for INTT
template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
static __global__ void
-reorder_digits_inplace_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
+reorder_digits_inplace_and_normalize_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
{
// launch N threads (per batch element)
// each thread starts from one index and calculates the corresponding group
@@ -66,15 +66,35 @@ namespace ntt {
}
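For context (not part of this diff): the in-place variant cannot simply write every element to its digit-reversed position, since destinations overlap sources; the comments above describe each thread walking a permutation group (cycle). A minimal host-side sketch of that cycle-following idea, using plain bit reversal in place of the file's mixed-radix dig_rev and with all names hypothetical:

#include <cstdint>
#include <utility>
#include <vector>

// Illustrative permutation: plain bit reversal (the kernel uses dig_rev()).
static uint32_t bit_rev(uint32_t i, uint32_t log_size)
{
  uint32_t r = 0;
  for (uint32_t b = 0; b < log_size; ++b)
    r |= ((i >> b) & 1u) << (log_size - 1 - b);
  return r;
}

// Follow each permutation cycle once, so every element is moved exactly once
// and no size-n scratch array of elements is needed.
template <typename E>
void reorder_inplace_host(std::vector<E>& arr, uint32_t log_size)
{
  const uint32_t n = 1u << log_size;
  std::vector<bool> visited(n, false);
  for (uint32_t start = 0; start < n; ++start) {
    if (visited[start]) continue;
    uint32_t i = start;
    E carry = arr[start];
    do {
      visited[i] = true;
      const uint32_t next = bit_rev(i, log_size); // element at i moves to next
      std::swap(carry, arr[next]);
      i = next;
    } while (i != start);
  }
}

The real kernel additionally fuses the INTT normalization (the multiply by inverse_N) into the same pass.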

template <typename E, typename S>
-__launch_bounds__(64) __global__
-void reorder_digits_kernel(E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
+__launch_bounds__(64) __global__ void reorder_digits_and_normalize_kernel(
+E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
{
uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
uint32_t rd = tid;
uint32_t wr = ((tid >> log_size) << log_size) + dig_rev(tid & ((1 << log_size) - 1), log_size, dit);
arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
}
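Here the write index splits tid into a batch offset (the bits above log_size, preserved by the shift pair) and an in-element index (the low log_size bits, digit-reversed), so each batch element is permuted independently. For illustration with plain bit reversal and log_size = 3: tid = 12 = 0b1100 keeps batch offset 8 and maps in-element index 0b100 to 0b001, giving wr = 9.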

+template <typename E, typename S>
+static __global__ void BatchMulKernelDigReverse(
+E* in_vec,
+int n_elements,
+int batch_size,
+S* scalar_vec,
+int step,
+int n_scalars,
+int logn,
+bool digit_rev,
+bool dit,
+E* out_vec)
+{
+int tid = blockDim.x * blockIdx.x + threadIdx.x;
+if (tid >= n_elements * batch_size) return;
+int64_t scalar_id = tid % n_elements;
+if (digit_rev) scalar_id = dig_rev(tid, logn, dit);
+out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
+}
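The new kernel scales element i of every batch element by scalar_vec[(i * step) % n_scalars] (scalar_id = tid % n_elements, so all batch elements reuse the same factors). With scalar_vec pointing at the twiddle table and step = coset_gen_index, that factor is w^(step·i) = g^i for the coset generator g = w^step; when an arbitrary_coset table is supplied, step is 1 and the table itself is read as the powers g^i. A hypothetical single-batch host reference (names are illustrative, not from the PR):

#include <cstdint>

// out[i] = scalars[(i * step) mod n_scalars] * in[i]; the explicit wrap keeps
// the index non-negative even for a negative step (inverse coset direction),
// which is valid because w^(n - x) == w^(-x) for an n-th root of unity w.
template <typename E, typename S>
void coset_mul_host(const E* in, int n, const S* scalars, int step, int n_scalars, E* out)
{
  for (int i = 0; i < n; ++i) {
    int64_t idx = ((int64_t)i * step) % n_scalars;
    if (idx < 0) idx += n_scalars;
    out[i] = scalars[idx] * in[i];
  }
}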

template <typename E, typename S>
__launch_bounds__(64) __global__ void ntt64(
E* in,
@@ -604,6 +624,8 @@ namespace ntt {
int batch_size,
bool is_inverse,
Ordering ordering,
+S* arbitrary_coset,
+int coset_gen_index,
cudaStream_t cuda_stream)
{
CHK_INIT_IF_RETURN();
@@ -624,15 +646,26 @@ namespace ntt {
const bool reverse_input = ordering == Ordering::kNN;
const bool is_dit = ordering == Ordering::kNN || ordering == Ordering::kRN;
bool is_normalize = is_inverse;
+const bool is_on_coset = (coset_gen_index != 0) || arbitrary_coset;
+const int n_twiddles = 1 << max_logn;
+
+// TODO: fuse BatchMulKernelDigReverse with input reorder (and normalize)?
+if (is_on_coset && !is_inverse) {
+BatchMulKernelDigReverse<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+d_input, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles,
+arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output);
+
+d_input = d_output;
+}
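// Why the pre-multiply (annotation, not in the diff): evaluating a(x) on the
// coset g*H gives X[k] = sum_i a[i] * (g * w^k)^i = sum_i (a[i] * g^i) * w^(i*k),
// so scaling element i by g^i and then running a plain forward NTT is exactly
// the coset evaluation.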

if (reverse_input) {
-// Note: fusing reorder with normalize for INTT
+// Note: fused reorder and normalize (for INTT)
const bool is_reverse_in_place = (d_input == d_output);
if (is_reverse_in_place) {
-reorder_digits_inplace_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
} else {
-reorder_digits_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+reorder_digits_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
}
is_normalize = false;
@@ -643,6 +676,12 @@ namespace ntt {
d_output, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
is_normalize, is_dit, cuda_stream));

+if (is_on_coset && is_inverse) {
+BatchMulKernelDigReverse<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
+d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles,
+arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output);
+}
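// Inverse direction (annotation, not in the diff): values on g*H interpolate
// to coefficients a[i] * g^i, so each output is scaled by g^(-i) after the
// plain INTT. The negative coset_gen_index against the base pointer
// 'external_twiddles + n_twiddles' presumably indexes the twiddle table
// backward, via w^(n - x) == w^(-x); that reading is inferred from this diff
// alone, not confirmed by the surrounding code.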

return CHK_LAST();
}

@@ -666,6 +705,8 @@ namespace ntt {
int batch_size,
bool is_inverse,
Ordering ordering,
+curve_config::scalar_t* arbitrary_coset,
+int coset_gen_index,
cudaStream_t cuda_stream);

} // namespace ntt
9 changes: 4 additions & 5 deletions icicle/appUtils/ntt/ntt.cu
@@ -503,9 +503,8 @@ namespace ntt {

// (heuristic) cutoff point where mixed-radix is faster than radix-2
const bool is_small_ntt = (logn < 16) && ((size_t)size * batch_size < (1 << 20));
-const bool is_on_coset = (coset_index != 0) || coset; // coset not supported by mixed-radix algorithm yet
const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation
-const bool is_radix2_algorithm = config.is_force_radix2 || is_small_ntt || is_on_coset || !is_NN;
+const bool is_radix2_algorithm = config.is_force_radix2 || is_small_ntt || !is_NN;

if (is_radix2_algorithm) {
bool ct_butterfly = true;
@@ -529,16 +528,16 @@
reverse_input ? d_output : d_input, size, Domain<S>::twiddles, Domain<S>::max_size, batch_size, logn,
dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output));

-if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
} else { // mixed-radix algorithm
CHK_IF_RETURN(ntt::mixed_radix_ntt(
d_input, d_output, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles, size,
-Domain<S>::max_log_size, batch_size, dir == NTTDir::kInverse, config.ordering, stream));
+Domain<S>::max_log_size, batch_size, dir == NTTDir::kInverse, config.ordering, coset, coset_index, stream));
}

if (!are_outputs_on_device)
CHK_IF_RETURN(cudaMemcpyAsync(output, d_output, input_size_bytes, cudaMemcpyDeviceToHost, stream));

+if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
if (!are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_input, stream));
if (!are_outputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_output, stream));
if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
2 changes: 2 additions & 0 deletions icicle/appUtils/ntt/ntt_impl.cuh
@@ -28,6 +28,8 @@ namespace ntt {
int batch_size,
bool is_inverse,
Ordering ordering,
+S* arbitrary_coset,
+int coset_gen_index,
cudaStream_t cuda_stream);

} // namespace ntt
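Taken together with the kernel changes above, the two new parameters support two conventions: a non-zero coset_gen_index selects the generator g = w^coset_gen_index out of the domain's twiddle table, while a non-null arbitrary_coset is read as a device array already holding the powers g^0, g^1, ... (the multiply kernel then uses step 1). Passing nullptr and 0 leaves the transform on the plain subgroup.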
12 changes: 8 additions & 4 deletions icicle/appUtils/ntt/tests/verification.cu
@@ -35,11 +35,12 @@ int main(int argc, char** argv)
cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
float icicle_time, new_time;

-int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19; // assuming second input is the log-size
+int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 16; // assuming second input is the log-size
int NTT_SIZE = 1 << NTT_LOG_SIZE;
bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
int INV = (argc > 3) ? atoi(argv[3]) : false;
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 32;
+int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 1;

const ntt::Ordering ordering = ntt::Ordering::kNN;
const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
@@ -48,8 +49,8 @@
: "RR";

printf(
"running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d, batch_size=%d\n", NTT_LOG_SIZE, ordering_str, INPLACE, INV,
BATCH_SIZE);
"running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d\n", NTT_LOG_SIZE, ordering_str,
INPLACE, INV, BATCH_SIZE, COSET_IDX);

CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)

@@ -95,7 +96,10 @@
cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
}

-// run ntt
+for (int coset_idx = 0; coset_idx < COSET_IDX; ++coset_idx) {
+ntt_config.coset_gen = ntt_config.coset_gen * basic_root;
+}
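This loop multiplies ntt_config.coset_gen by basic_root once per iteration, i.e. it sets the coset generator to basic_root^COSET_IDX (assuming coset_gen starts at one and basic_root is the primitive root the domain was initialized with), matching the coset_gen_index convention on the mixed-radix path.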

auto benchmark = [&](bool is_print, int iterations) -> cudaError_t {
// NEW
CHK_IF_RETURN(cudaEventRecord(new_start, ntt_config.ctx.stream));