Back out "Add BF16 support in group_index_select_2d" (#2326)
Summary:
Pull Request resolved: #2326

Original commit changeset: a7bde01cef19

Original Phabricator Diff: D53445651

Reviewed By: yjhao

Differential Revision: D53622702

fbshipit-source-id: c95514aa3901b49c08c691481665cc181c5f8cb3
sryap authored and facebook-github-bot committed Feb 9, 2024
1 parent bc85d73 commit eb3c304
Showing 2 changed files with 5 additions and 29 deletions.
28 changes: 3 additions & 25 deletions fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -6,13 +6,6 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#if (defined(USE_ROCM))
-#include <hip/hip_bfloat16.h>
-#elif (                                                \
-    (defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
-    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
-#include <cuda_bf16.h>
-#endif
 #include "common.cuh"
 
 using Tensor = at::Tensor;
@@ -92,23 +85,8 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     if (USE_INDEX_SELECT) {
       output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
     } else {
-      if constexpr (std::is_same_v<scalar_t, at::BFloat16>) {
-#if !(                                                    \
-    defined(USE_ROCM) ||                                  \
-    ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) ||   \
-     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
-        atomicAdd(
-            reinterpret_cast<__nv_bfloat16*>(&output[idx * num_cols + i]),
-            *reinterpret_cast<const __nv_bfloat16*>(
-                &input[row * num_cols + i]));
-#else
-        CUDA_KERNEL_ASSERT(
-            false && "atomicAdd __nv_bfloat16 is not supported");
-#endif
-      } else {
-        gpuAtomicAddNoReturn(
-            &output[idx * num_cols + i], input[row * num_cols + i]);
-      }
+      gpuAtomicAddNoReturn(
+          &output[idx * num_cols + i], input[row * num_cols + i]);
     }
   }
 }
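For context on what the removed branch did: per the preprocessor guard in the reverted code, a native atomicAdd overload for __nv_bfloat16 is only available when building with CUDA 11+ for compute capability 8.0 or newer, and not under ROCm, so older targets fell through to a kernel-side assert. Below is a minimal standalone sketch of that guard pattern, not FBGEMM code; the kernel name, buffer names, and the use of plain assert() instead of c10's CUDA_KERNEL_ASSERT are illustrative assumptions.

// Illustrative sketch only (hypothetical kernel, not part of FBGEMM): shows the
// CUDA_VERSION / __CUDA_ARCH__ guard the reverted code wrapped around
// atomicAdd(__nv_bfloat16*, __nv_bfloat16).
#include <cassert>
#include <cuda_bf16.h>

__global__ void bf16_atomic_accumulate_sketch(
    __nv_bfloat16* out,
    const __nv_bfloat16* in,
    int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) {
    return;
  }
#if !(                                                 \
    defined(USE_ROCM) ||                               \
    (defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
  // CUDA 11+ building for SM 8.0+: cuda_bf16.h provides this atomicAdd overload.
  atomicAdd(&out[i], in[i]);
#else
  // Pre-SM80 / pre-CUDA-11 / ROCm builds: no native bf16 atomicAdd; the
  // reverted FBGEMM code asserted via CUDA_KERNEL_ASSERT here instead.
  assert(false && "atomicAdd __nv_bfloat16 is not supported");
#endif
}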
@@ -163,7 +141,7 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
 
   AT_DISPATCH_INDEX_TYPES(
       indices_scalar_type, "group_index_select_2d_wrapper_1", [&] {
-        FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
             input_scalar_type, "group_index_select_2d_wrapper_2", [&] {
               if (use_index_select) {
                 if (use_var_cols) {
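The dispatch swap above is what drives the test update below: PyTorch's AT_DISPATCH_FLOATING_TYPES_AND_HALF generates cases for double, float, and half, while the FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16 macro it replaces covered, as its name indicates, float, half, and bfloat16; backing it out therefore restores torch.double coverage and drops bf16. A minimal host-side sketch of the dispatch pattern, assuming a hypothetical launch_group_index_kernel helper that is not part of FBGEMM:

// Sketch of the AT_DISPATCH_FLOATING_TYPES_AND_HALF pattern used in the wrapper.
// launch_group_index_kernel is a hypothetical stand-in, not an FBGEMM function.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

template <typename scalar_t>
void launch_group_index_kernel(const at::Tensor& input) {
  // A real implementation would launch a CUDA kernel templated on scalar_t.
}

void dispatch_example(const at::Tensor& input) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "dispatch_example", [&] {
        // Inside the lambda, scalar_t is bound to double, float, or at::Half,
        // matching input.scalar_type(); other dtypes raise an error.
        launch_group_index_kernel<scalar_t>(input);
      });
}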
6 changes: 2 additions & 4 deletions fbgemm_gpu/test/sparse/index_select_test.py
@@ -97,8 +97,7 @@ def test_index_select_dim0(
         num_indices=st.integers(1, 32),
         max_num_input_rows=st.integers(1, 32),
         shape=st.lists(st.integers(1, 32), min_size=1, max_size=2),
-        # TODO: Add torch.bfloat16
-        dtype=st.sampled_from([torch.float, torch.half]),
+        dtype=st.sampled_from([torch.float, torch.half, torch.double]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
         num_groups=st.integers(1, 32),
         use_var_cols=st.booleans(),
@@ -216,7 +215,6 @@ def compare_tensor_groups(
                     f"FAILED: group {i} {tensor_type} ({dtype}), "
                     f"input shape {input_group[i].shape}, indices "
                     f"{indices_group[i]}, test {test}, ref {ref}"
-                    f"input {grad_group[i]}"
                 )
                 assert (
                     passed
@@ -231,7 +229,7 @@ def compare_tensor_groups(
                 # pyre-ignore [6]
                 [i.grad for i in input_ref_group],
                 "gradient",
-                {"rtol": 1e-02, "atol": 1e-02} if dtype != torch.float else {},
+                {"rtol": 1e-02, "atol": 1e-02} if dtype == torch.half else {},
             )
 
     @given(
