Commit 20d6242
Disable sp24 tests on H100 (facebookresearch#1073)
danthe3rd authored Apr 3, 2024
1 parent d88b97e commit 20d6242
Showing 2 changed files with 37 additions and 26 deletions.
51 changes: 26 additions & 25 deletions tests/test_sparsity24.py
@@ -26,8 +26,9 @@
torch_compile_tests = pytest.mark.skipif(
torch.__version__ < "2.2.0.dev20231122", reason="requires PyTorch 2.2+"
)
-cuda_sm80_only = pytest.mark.skipif(
-    compute_capability < (8, 0), reason="requires sm80+"
+requires_sp24 = pytest.mark.skipif(compute_capability < (8, 0), reason="requires sm80+")
+requires_sp24_gemm = pytest.mark.skipif(
+    compute_capability != (8, 0), reason="requires sm80"
)
parametrize_dtype = pytest.mark.parametrize(
"dtype", [torch.float16, torch.bfloat16], ids=["f16", "bf16"]
@@ -67,7 +68,7 @@ def test_sparse24_largest_mask_2d() -> None:
]


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
@parametrize_backend
def test_autocast(dtype, backend: str) -> None:
@@ -89,7 +90,7 @@ def test_autocast(dtype, backend: str) -> None:
assert_allclose(y, y_ac, "gemm", **atol_rtol_kw[dtype])


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
def test_sparse24_causal1122(dtype) -> None:
inp = torch.tensor(
@@ -109,7 +110,7 @@ def test_sparse24_causal1122(dtype) -> None:
]


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
@parametrize_backend
def test_sparse24_largest_abs_values_greedy(dtype, backend) -> None:
@@ -286,7 +287,7 @@ def test_pack_tensor_according_to_mask() -> None:
assert line_packed == line_packed_expected


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
def test_sp24_gemm(dtype) -> None:
M, N, K = 32, 32, 64
@@ -361,7 +362,7 @@ def expect_match(i, j, line):
expect_match(1, 1, 3) # T5


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
@parametrize_backend
def test_pack_both_ways_meta_correctness(dtype, backend) -> None:
@@ -386,7 +387,7 @@ def test_pack_both_ways_meta_correctness(dtype, backend) -> None:
assert_allclose(ref_gemm, pack_gemm, msg="sp24 GEMM", **atol_rtol_kw[dtype])


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
def test_pack_both_ways_id(dtype) -> None:
N = 512
@@ -502,7 +503,7 @@ def test_sp24_api_different_pattern_transposed(dtype) -> None:
assert torch.allclose(sxt2.packed_t, sxt.packed_t)


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
@parametrize_backend
def test_sp24_transpose_invariant(dtype, backend) -> None:
@@ -543,7 +544,7 @@ def gen4x4():
assert_allclose(a_t_s._sp24_to_dense().t(), a)


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
def test_sp24_matmuls(dtype) -> None:
M, N, K = 64, 256, 1024
@@ -564,7 +565,7 @@ def test_sp24_matmuls(dtype) -> None:
)


-@cuda_sm80_only
+@requires_sp24
def test_sp24_matmuls_mat_vec() -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([128], device="cuda", dtype=torch.float16)
@@ -575,7 +576,7 @@ def test_sp24_matmuls_mat_vec() -> None:
assert_allclose(a_s @ b, (a * a_m) @ b, msg="sp@dense", **atol_rtol_kw[a.dtype])


-@cuda_sm80_only
+@requires_sp24
def test_sp24_matmuls_bmm() -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
@@ -591,7 +592,7 @@ def sparsify24_dense(tensor: torch.Tensor):
return m * tensor


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
@pytest.mark.parametrize("act", [F.gelu, F.relu])
def test_sp24_api_mlp_act24_correctness(dtype, act) -> None:
@@ -640,7 +641,7 @@ def test_sp24_api_mlp_act24_correctness(dtype, act) -> None:
assert_allclose(grad_calc, grad_ref, msg=grad_name, **atol_rtol_kw[dtype])


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
def test_sp24_api_swiglu_correctness(dtype) -> None:
B, in_ft, hid_ft, out_ft = 256, 2048, 6144 // 2, 2048
@@ -695,7 +696,7 @@ def test_sp24_api_swiglu_correctness(dtype) -> None:
assert_allclose(grad_calc, grad_ref, msg=grad_name, **atol_rtol_kw[dtype])


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
@pytest.mark.parametrize("M", [1, 8, 26, 31, 32, 48, 63])
def test_not_aligned(dtype, M):
@@ -708,7 +709,7 @@ def test_not_aligned(dtype, M):
assert_allclose(As @ B, A @ B, msg="not aligned", **atol_rtol_kw[dtype])


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
@pytest.mark.parametrize("input_rowmajor", [True, False])
def test_sparsify24_like_dense(dtype, input_rowmajor):
@@ -724,7 +725,7 @@ def test_sparsify24_like_dense(dtype, input_rowmajor):
)


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_dtype
@parametrize_backend
def test_sparsify24_weights(dtype, backend):
@@ -769,7 +770,7 @@ def _workaround_cusparselt_internal_error() -> None:
out.backward(out)


-@cuda_sm80_only
+@requires_sp24
@parametrize_dtype
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
@pytest.mark.parametrize("bias", [False, True], ids=["", "bias"])
@@ -837,7 +838,7 @@ def test_linearw24(dtype, bias: bool, aligned: bool, amp: bool) -> None:
)


-@cuda_sm80_only
+@requires_sp24
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
def test_wrong_alignment_error_message() -> None:
A = torch.randn([128, 128], device="cuda", dtype=torch.float16)
@@ -847,7 +848,7 @@ def test_wrong_alignment_error_message() -> None:
A @ B


-@cuda_sm80_only
+@requires_sp24
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
def test_min_alignment() -> None:
A = torch.randn([128, 128], device="cuda", dtype=torch.float16)
@@ -856,7 +857,7 @@ def test_min_alignment() -> None:
assert_allclose(A @ B, A._sp24_to_dense() @ B, "output", **atol_rtol_kw[A.dtype])


-@cuda_sm80_only
+@requires_sp24
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
def test_wrong_dtype_error_message() -> None:
A = torch.randn([128, 128], device="cuda", dtype=torch.float16)
@@ -866,7 +867,7 @@ def test_wrong_dtype_error_message() -> None:
A @ B


-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_backend
@pytest.mark.parametrize("with_bias", [False, True])
def test_linear_dispatch_inference_mode(backend: str, with_bias: bool) -> None:
@@ -906,7 +907,7 @@ def test_sp24_meta() -> None:


@torch_compile_tests
-@cuda_sm80_only
+@requires_sp24_gemm
@parametrize_backend
def test_sp24_compile(backend) -> None:
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
@@ -949,7 +950,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


-@cuda_sm80_only
+@requires_sp24_gemm
@torch_compile_tests
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
def test_linearw24_block_compile() -> None:
@@ -986,7 +987,7 @@ def test_linearw24_block_compile() -> None:
assert_allclose(param_c.grad, param_ref.grad, param_name, **atol_rtol_kw[dtype])


-@cuda_sm80_only
+@requires_sp24
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
def test_sp24_ste():
x = torch.randn([512, 512], dtype=torch.float16, device="cuda", requires_grad=True)
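A note on the new markers, for readers skimming the diff: requires_sp24 gates tests that only need 2:4 sparsification (any sm80+ GPU, so A100 and H100 both qualify), while requires_sp24_gemm now requires exactly sm80, which is what skips the GEMM-path tests on H100 (sm90). Below is a minimal, self-contained sketch of the pattern; the names _cap and test_example_sp24_gemm are illustrative and not part of the test file.

import pytest
import torch

# Compute capability of the active GPU, or (0, 0) when CUDA is unavailable,
# mirroring how the test module derives `compute_capability`.
_cap = (
    torch.cuda.get_device_capability("cuda") if torch.cuda.is_available() else (0, 0)
)

# Tests that only sparsify tensors run on any sm80+ GPU (A100, H100, ...).
requires_sp24 = pytest.mark.skipif(_cap < (8, 0), reason="requires sm80+")

# Tests that exercise the sp24 GEMM kernels are restricted to sm80 exactly,
# so this commit makes them skip on H100 (sm90).
requires_sp24_gemm = pytest.mark.skipif(_cap != (8, 0), reason="requires sm80")


@requires_sp24_gemm
def test_example_sp24_gemm() -> None:
    # Hypothetical test body, shown only to illustrate the marker in use.
    a = torch.randn([32, 64], device="cuda", dtype=torch.float16)
    assert a.shape == (32, 64)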
12 changes: 11 additions & 1 deletion xformers/ops/sp24.py
@@ -85,14 +85,24 @@ class Sp24GemmCusplt(BaseOperator):

def _has_cusparseLt() -> bool:
available = _cusplt_version >= (0, 4, 0)
-    if available and _cusplt_version < (0, 5, 0):
+    if not available:
+        return False
+    if _cusplt_version < (0, 5, 0):
# Version 0.5.0 has much better perf because it can fuse the
# transpose within the GEMM epilogue
warnings.warn(
f"You have cusparseLt version {_cusplt_version_str} "
f"but you get better performance with v0.5.0+ if "
f"you replace the .so file ({_get_cusparselt_lib()})"
)
+
+    # Sm90 added in 6.0
+    compute_capability = (0, 0)
+    if torch.cuda.is_available():
+        compute_capability = torch.cuda.get_device_capability("cuda")
+    if _cusplt_version < (6, 0, 0):
+        if compute_capability >= (9, 0):
+            return False
return available


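For reference, a self-contained sketch of the decision _has_cusparseLt() makes after this change. The helper name _cusparselt_supported and the explicit version parameter are illustrative stand-ins (the real module reads _cusplt_version from the loaded library), and the sketch keeps the same ordering: bail out below 0.4.0, warn below 0.5.0, and report cusparseLt as unavailable on sm90 unless the library is 6.0+.

import warnings
from typing import Tuple

import torch


def _cusparselt_supported(version: Tuple[int, int, int]) -> bool:
    if version < (0, 4, 0):
        # Too old to be usable at all.
        return False
    if version < (0, 5, 0):
        # 0.5.0+ can fuse the transpose into the GEMM epilogue, so older
        # versions still work but are slower.
        warnings.warn(f"cusparseLt {version} found; 0.5.0+ is faster")
    capability = (0, 0)
    if torch.cuda.is_available():
        capability = torch.cuda.get_device_capability("cuda")
    # Per the comment in the diff, cusparseLt only gained sm90 (Hopper)
    # support in 6.0, so older versions are reported unavailable on H100.
    if version < (6, 0, 0) and capability >= (9, 0):
        return False
    return True


# Example: on an H100 (sm90), _cusparselt_supported((0, 5, 2)) is False,
# while _cusparselt_supported((6, 1, 0)) is True.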
