diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 884723ed68..131c78008b 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1589,6 +1589,8 @@ def trt_export( """ Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript. Currently, this API only supports converting models whose inputs are all tensors. + Note: NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5. + Review the TensorRT Support Matrix for which GPUs are supported. There are two ways to export a model: 1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript. diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index a360f63dbd..d2d05fae22 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -505,7 +505,9 @@ def trt_compile( ) -> torch.nn.Module: """ Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. - Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x + Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x. + NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5. + Review the TensorRT Support Matrix for which GPUs are supported. Args: model: module to patch with TrtCompiler object. base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 79dc1f2304..8f2f400b5d 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -107,6 +107,7 @@ InvalidPyTorchVersionError, OptionalImportError, allow_missing_reference, + compute_capabilities_after, damerau_levenshtein_distance, exact_version, get_full_type_name, diff --git a/monai/utils/module.py b/monai/utils/module.py index 1f7f8aecfc..d3f2ff09f2 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -634,3 +634,44 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st if is_prerelease: return False return True + + +@functools.lru_cache(None) +def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool: + """ + Compute whether the current system GPU CUDA compute capability is after or equal to the specified version. + The current system GPU CUDA compute capability is determined by the first GPU in the system. + The compared version is a string in the form of "major.minor". + + Args: + major: major version number to be compared with. + minor: minor version number to be compared with. Defaults to 0. + current_ver_string: if None, the current system GPU CUDA compute capability will be used. + + Returns: + True if the current system GPU CUDA compute capability is greater than or equal to the specified version. + """ + if current_ver_string is None: + cuda_available = torch.cuda.is_available() + pynvml, has_pynvml = optional_import("pynvml") + if not has_pynvml: # assuming that the user has Ampere and later GPU + return True + if not cuda_available: + return False + else: + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) # get the first GPU + major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle) + current_ver_string = f"{major_c}.{minor_c}" + pynvml.nvmlShutdown() + + ver, has_ver = optional_import("packaging.version", name="parse") + if has_ver: + return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}") # type: ignore + parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2) + while len(parts) < 2: + parts += ["0"] + c_major, c_minor = parts[:2] + c_mn = int(c_major), int(c_minor) + mn = int(major), int(minor) + return c_mn > mn diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py index 833a0ca1dc..835c8e5c1d 100644 --- a/tests/test_bundle_trt_export.py +++ b/tests/test_bundle_trt_export.py @@ -22,7 +22,13 @@ from monai.data import load_net_with_metadata from monai.networks import save_state from monai.utils import optional_import -from tests.utils import command_line_tests, skip_if_no_cuda, skip_if_quick, skip_if_windows +from tests.utils import ( + SkipIfBeforeComputeCapabilityVersion, + command_line_tests, + skip_if_no_cuda, + skip_if_quick, + skip_if_windows, +) _, has_torchtrt = optional_import( "torch_tensorrt", @@ -47,6 +53,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick +@SkipIfBeforeComputeCapabilityVersion((7, 0)) class TestTRTExport(unittest.TestCase): def setUp(self): diff --git a/tests/test_convert_to_trt.py b/tests/test_convert_to_trt.py index 5579539764..712d887c3b 100644 --- a/tests/test_convert_to_trt.py +++ b/tests/test_convert_to_trt.py @@ -20,7 +20,7 @@ from monai.networks import convert_to_trt from monai.networks.nets import UNet from monai.utils import optional_import -from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows +from tests.utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows _, has_torchtrt = optional_import( "torch_tensorrt", @@ -38,6 +38,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick +@SkipIfBeforeComputeCapabilityVersion((7, 0)) class TestConvertToTRT(unittest.TestCase): def setUp(self): diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 6df5d520bd..49404fdbbe 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -21,7 +21,13 @@ from monai.networks import trt_compile from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 from monai.utils import min_version, optional_import -from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows +from tests.utils import ( + SkipIfAtLeastPyTorchVersion, + SkipIfBeforeComputeCapabilityVersion, + skip_if_no_cuda, + skip_if_quick, + skip_if_windows, +) trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) polygraphy, polygraphy_imported = optional_import("polygraphy") @@ -36,6 +42,7 @@ @skip_if_quick @unittest.skipUnless(trt_imported, "tensorrt is required") @unittest.skipUnless(polygraphy_imported, "polygraphy is required") +@SkipIfBeforeComputeCapabilityVersion((7, 0)) class TestTRTCompile(unittest.TestCase): def setUp(self): diff --git a/tests/test_pytorch_version_after.py b/tests/test_version_after.py similarity index 72% rename from tests/test_pytorch_version_after.py rename to tests/test_version_after.py index 147707d2c0..b6cb741382 100644 --- a/tests/test_pytorch_version_after.py +++ b/tests/test_version_after.py @@ -15,9 +15,9 @@ from parameterized import parameterized -from monai.utils import pytorch_after +from monai.utils import compute_capabilities_after, pytorch_after -TEST_CASES = ( +TEST_CASES_PT = ( (1, 5, 9, "1.6.0"), (1, 6, 0, "1.6.0"), (1, 6, 1, "1.6.0", False), @@ -36,14 +36,30 @@ (1, 6, 1, "1.6.0+cpu", False), ) +TEST_CASES_SM = [ + # (major, minor, sm, expected) + (6, 1, "6.1", True), + (6, 1, "6.0", False), + (6, 0, "8.6", True), + (7, 0, "8", True), + (8, 6, "8", False), +] + class TestPytorchVersionCompare(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TEST_CASES_PT) def test_compare(self, a, b, p, current, expected=True): """Test pytorch_after with a and b""" self.assertEqual(pytorch_after(a, b, p, current), expected) +class TestComputeCapabilitiesAfter(unittest.TestCase): + + @parameterized.expand(TEST_CASES_SM) + def test_compute_capabilities_after(self, major, minor, sm, expected): + self.assertEqual(compute_capabilities_after(major, minor, sm), expected) + + if __name__ == "__main__": unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 77b53cebb8..2a00af50e9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -47,7 +47,7 @@ from monai.networks import convert_to_onnx, convert_to_torchscript from monai.utils import optional_import from monai.utils.misc import MONAIEnvVars -from monai.utils.module import pytorch_after +from monai.utils.module import compute_capabilities_after, pytorch_after from monai.utils.tf32 import detect_default_tf32 from monai.utils.type_conversion import convert_data_type @@ -286,6 +286,20 @@ def __call__(self, obj): )(obj) +class SkipIfBeforeComputeCapabilityVersion: + """Decorator to be used if test should be skipped + with Compute Capability older than that given.""" + + def __init__(self, compute_capability_tuple): + self.min_version = compute_capability_tuple + self.version_too_old = not compute_capabilities_after(*compute_capability_tuple) + + def __call__(self, obj): + return unittest.skipIf( + self.version_too_old, f"Skipping tests that fail on Compute Capability versions before: {self.min_version}" + )(obj) + + def is_main_test_process(): ps = torch.multiprocessing.current_process() if not ps or not hasattr(ps, "name"):