diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 610f5d2733246e..25ff034ffbfc0f 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -87,6 +87,7 @@ subtest, TEST_WITH_ASAN, TEST_WITH_ROCM, + HAS_HIPCC, ) from torch.utils import _pytree as pytree from torch.utils._python_dispatch import TorchDispatchMode @@ -927,6 +928,7 @@ def test_aoti_eager_support_str(self): self.assertEqual(ref_value, res_value) @skipCUDAIf(not SM80OrLater, "Requires sm80") + @skipCUDAIf(TEST_WITH_ROCM and not HAS_HIPCC, "ROCm requires hipcc compiler") @skip_if_halide # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_cache_hit(self): @@ -970,6 +972,7 @@ def test_aoti_eager_cache_hit(self): self.assertEqual(ref_value, res_value) @skipCUDAIf(not SM80OrLater, "Requires sm80") + @skipCUDAIf(TEST_WITH_ROCM and not HAS_HIPCC, "ROCm requires hipcc compiler") @skip_if_halide # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_with_persistent_cache(self): diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 4fb13abb5b96d3..ed904a696ca36c 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -96,6 +96,7 @@ from torch.testing._comparison import not_close_error_metas from torch.testing._internal.common_dtype import get_all_dtypes from torch.utils._import_utils import _check_module_exists +from torch.utils.cpp_extension import ROCM_HOME import torch.utils._pytree as pytree try: import pytest @@ -106,6 +107,8 @@ MI300_ARCH = ("gfx940", "gfx941", "gfx942") +HAS_HIPCC = torch.version.hip is not None and ROCM_HOME is not None and shutil.which('hipcc') is not None + def freeze_rng_state(*args, **kwargs): return torch.testing._utils.freeze_rng_state(*args, **kwargs)