From 3e83820b7e9c1d5425801219acc010c332989eb5 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 3 Jul 2023 10:53:18 -0700 Subject: [PATCH] Refactor and fix performance regression with GPU runtime checks (#1292) --- dace/codegen/common.py | 34 +++++++---------------------- dace/codegen/compiled_sdfg.py | 4 ++-- dace/codegen/tools/gpu_runtime.py | 36 +++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 28 deletions(-) create mode 100644 dace/codegen/tools/gpu_runtime.py diff --git a/dace/codegen/common.py b/dace/codegen/common.py index e8f2972c63..5dafc696cf 100644 --- a/dace/codegen/common.py +++ b/dace/codegen/common.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import ast from copy import deepcopy import ctypes.util @@ -6,6 +6,7 @@ from dace.sdfg import SDFG from dace.properties import CodeBlock from dace.codegen import cppunparse +from dace.codegen.tools import gpu_runtime from functools import lru_cache from io import StringIO import os @@ -146,7 +147,11 @@ def _try_execute(cmd: str) -> bool: 'to either "cuda" or "hip".') -def get_gpu_runtime_library() -> ctypes.CDLL: +@lru_cache() +def get_gpu_runtime() -> gpu_runtime.GPURuntime: + """ + Returns the GPU runtime library (CUDA / HIP) if exists. The result is cached for performance. + """ backend = get_gpu_backend() if backend == 'cuda': libpath = ctypes.util.find_library('cudart') @@ -165,27 +170,4 @@ def get_gpu_runtime_library() -> ctypes.CDLL: raise RuntimeError(f'GPU runtime library for {backend} not found. Please set the {envname} ' 'environment variable to point to the libraries.') - return ctypes.CDLL(libpath) - - -def get_gpu_runtime_error_string(err: int) -> str: - lib = get_gpu_runtime_library() - - # Obtain the error string - geterrorstring = getattr(lib, f'{get_gpu_backend()}GetErrorString') - geterrorstring.restype = ctypes.c_char_p - return geterrorstring(err).decode('utf-8') - - -def get_gpu_runtime_last_error() -> str: - lib = get_gpu_runtime_library() - - getlasterror = getattr(lib, f'{get_gpu_backend()}GetLastError') - res: int = getlasterror() - if res == 0: - return None - - # Obtain the error string - geterrorstring = getattr(lib, f'{get_gpu_backend()}GetErrorString') - geterrorstring.restype = ctypes.c_char_p - return geterrorstring(res).decode('utf-8') + return gpu_runtime.GPURuntime(backend, libpath) diff --git a/dace/codegen/compiled_sdfg.py b/dace/codegen/compiled_sdfg.py index 4538d6d9b4..ea1b9e9cb8 100644 --- a/dace/codegen/compiled_sdfg.py +++ b/dace/codegen/compiled_sdfg.py @@ -319,7 +319,7 @@ def finalize(self): def _get_error_text(self, result: Union[str, int]) -> str: if self.has_gpu_code: if isinstance(result, int): - result = common.get_gpu_runtime_error_string(result) + result = common.get_gpu_runtime().get_error_string(result) return (f'{result}. Consider enabling synchronous debugging mode (environment variable: ' 'DACE_compiler_cuda_syncdebug=1) to see where the issue originates from.') else: @@ -345,7 +345,7 @@ def __call__(self, *args, **kwargs): if self.has_gpu_code: # Optionally get errors from call try: - lasterror = common.get_gpu_runtime_last_error() + lasterror = common.get_gpu_runtime().get_last_error_string() except RuntimeError as ex: warnings.warn(f'Could not get last error from GPU runtime: {ex}') lasterror = None diff --git a/dace/codegen/tools/gpu_runtime.py b/dace/codegen/tools/gpu_runtime.py new file mode 100644 index 0000000000..1a2c4abcef --- /dev/null +++ b/dace/codegen/tools/gpu_runtime.py @@ -0,0 +1,36 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" +GPU runtime testing functionality. Used for checking error codes after GPU-capable SDFG execution. +""" +import ctypes +from typing import Optional + + +class GPURuntime: + """ + GPU runtime object containing the library (CUDA / HIP) and some functions to query errors. + """ + + def __init__(self, backend_name: str, path: str) -> None: + self.backend = backend_name + self.library = ctypes.CDLL(path) + + # Prefetch runtime functions + self._geterrorstring = getattr(self.library, f'{self.backend}GetErrorString') + self._geterrorstring.restype = ctypes.c_char_p + self._getlasterror = getattr(self.library, f'{self.backend}GetLastError') + + def get_error_string(self, err: int) -> str: + # Obtain the error string + return self._geterrorstring(err).decode('utf-8') + + def get_last_error(self) -> int: + return self._getlasterror() + + def get_last_error_string(self) -> Optional[str]: + res: int = self._getlasterror() + if res == 0: + return None + + # Obtain the error string + return self.get_error_string(res)