Make Transformers compilable by C++17 (pytorch#90389)
The `register` keyword was removed in C++17. It is kept under an `#if`
for older standards, as I have not measured the perf implications on
older compilers, though there shouldn't be any: all modern compilers are
supposed to ignore it outright.
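
For illustration only, the same guard can be written once as a macro instead of repeating the `#if` at every declaration. This is a minimal sketch and not what this commit does: `MAYBE_REGISTER`, the simplified `Params` struct, and `rounded_key_end` are hypothetical names invented for the example, and it is plain host C++ rather than the CUDA kernel.

```cpp
// Hypothetical helper macro (not part of this commit): expands to `register`
// only when compiling for a language standard older than C++17.
#if __cplusplus < 201703L
#define MAYBE_REGISTER register
#else
#define MAYBE_REGISTER
#endif

// Simplified stand-in for the kernel's Params struct (assumption).
struct Params {
  int num_keys;
  int block_size;
};

// Copies the parameters into a local, as the kernel does, then rounds
// num_keys down to a multiple of block_size.
int rounded_key_end(const Params& p_) {
  MAYBE_REGISTER const Params p = p_;
  return p.num_keys / p.block_size * p.block_size;
}
```

The function behaves the same under either branch, since the keyword is only a hint; the macro just decides whether the compiler sees it.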

This code originates from facebookresearch/xformers#375; I will propose a similar PR to that repo to remove its use of the `register` keyword.

Yet another thing discovered while working on pytorch#85969

Pull Request resolved: pytorch#90389
Approved by: https://github.com/drisspg
malfet authored and kulinseth committed Dec 9, 2022
1 parent 974a144 commit d6f0a59
Showing 1 changed file with 8 additions and 0 deletions.
@@ -687,7 +687,11 @@ struct AttentionBackwardKernel {
static CUTLASS_DEVICE void kernel(Params& p_) {
// Hint to nvcc to store points & tensor shapes in registers
// as we use them a lot
#if __cplusplus < 201703L
register const Params p = p_;
#else
const Params p = p_;
#endif

extern __shared__ char smem_buffer[];
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
@@ -721,7 +725,11 @@ struct AttentionBackwardKernel {
__syncthreads();
}

#if __cplusplus < 201703L
OutputFragments register output_frags;
#else
OutputFragments output_frags;
#endif
int32_t key_start = 0;
int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ;
for (; key_start < key_end; key_start += kBlockSizeJ) {