[ROCm] Prevent accidental enablement of efficient attention. (#134531)
[ROCm] Prevent accidental enablement of efficient attention. (#133331)

Currently, Efficient Attention and Flash Attention share the same set of GPU
kernels on ROCm and therefore have common limitations on head sizes.
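
A rough sketch of the user-visible behavior this guards, not part of the change itself: it assumes a ROCm build, PyTorch's torch.nn.attention API (2.3+), and treats 257 as an illustrative head size above the shared flash limit. With this fix, forcing the efficient-attention backend on such an input should be rejected rather than silently dispatched.

# Sketch only: assumes a ROCm build and PyTorch >= 2.3 (torch.nn.attention);
# 257 is an illustrative head size above the shared flash head-dim limit.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.rand(2, 2, 3, 257, device="cuda", dtype=torch.float16)

with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as err:
        # With this change, ROCm applies the flash head-dim check to the
        # efficient-attention path, so the call errors instead of dispatching.
        print("rejected:", err)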

Fixes #132004

Pull Request resolved: #133331
Approved by: https://github.com/malfet, https://github.com/jithunnair-amd

(cherry picked from commit 46ecc67)

Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com>
pytorchbot and xinyazhang committed Aug 27, 2024
1 parent e0ddbff commit 6a79d4a
Showing 2 changed files with 11 additions and 3 deletions.
7 changes: 6 additions & 1 deletion aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -620,7 +620,12 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
       check_all_tensors_on_device,
       check_mem_efficient_hardware_support,
       check_tensor_shapes,
-      check_head_dim_size_mem_efficient);
+#ifdef USE_ROCM
+      check_head_dim_size_flash
+#else
+      check_head_dim_size_mem_efficient
+#endif
+      );
   for (auto& constraint : general_constraints) {
     if (!constraint(params, debug)) {
       return false;
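
For readers unfamiliar with sdp_utils.cpp: can_use_mem_efficient_attention runs the SDPA parameters through a list of constraint predicates and bails out on the first failure, and the hunk above only swaps which head-dim predicate sits in that list on ROCm. Below is a loose Python analogue of that predicate-list pattern; the check bodies and the 256 limit are illustrative stand-ins, not the actual ATen logic.

# Loose analogue of the constraint-list dispatch; not the real implementation.
IS_ROCM = True  # assumption: pretend this is a ROCm build

def check_tensor_shapes(params, debug):
    return True  # placeholder for the real shape checks

def check_head_dim_size_flash(params, debug):
    ok = params["head_dim"] <= 256  # illustrative shared flash limit
    if debug and not ok:
        print("head_dim too large for the shared flash/efficient kernels")
    return ok

def check_head_dim_size_mem_efficient(params, debug):
    return params["head_dim"] % 8 == 0  # placeholder alignment-style check

def can_use_mem_efficient_attention(params, debug=False):
    constraints = [
        check_tensor_shapes,
        # On ROCm, EA reuses flash's head-dim constraint (the point of this PR).
        check_head_dim_size_flash if IS_ROCM else check_head_dim_size_mem_efficient,
    ]
    return all(c(params, debug) for c in constraints)

print(can_use_mem_efficient_attention({"head_dim": 257}, debug=True))  # False on "ROCm"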
7 changes: 5 additions & 2 deletions test/test_transformers.py
@@ -1457,6 +1457,8 @@ def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
         dtype = torch.float16
         make_tensor = partial(torch.rand, device=device, dtype=dtype)
         size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257)
+        if TEST_WITH_ROCM:  # On ROCM, FA and EA share the backend GPU kernels
+            size = SdpaShape(2, 2, 3, 257)
         q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
         self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
             q, k, v, None, 0.0, False))
@@ -1499,8 +1501,9 @@ def test_unaligned_tensors(self, device):
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
         with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
-            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                q, k, v, None, 0.0, False))
+            ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
+            with ctxmgr:
+                torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
 
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
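
The second hunk's trick of choosing the expectation at runtime (assertRaises where the call should fail, contextlib.nullcontext() where it should succeed) generalizes to any platform-dependent failure mode. Here is a self-contained sketch of the same pattern, with a made-up operation and a boolean standing in for TEST_WITH_ROCM.

# Standalone sketch of the conditional-expectation pattern from the diff above.
import contextlib
import unittest

STRICT_PLATFORM = True  # stand-in for a platform check such as TEST_WITH_ROCM

def maybe_unsupported_op(head_dim, strict):
    # Stand-in op that rejects large inputs only on the "strict" platform.
    if strict and head_dim > 256:
        raise ValueError("head_dim too large on this platform")
    return head_dim

class ConditionalExpectationTest(unittest.TestCase):
    def test_maybe_raises(self):
        ctx = self.assertRaises(ValueError) if STRICT_PLATFORM else contextlib.nullcontext()
        with ctx:
            maybe_unsupported_op(257, STRICT_PLATFORM)

if __name__ == "__main__":
    unittest.main()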
