Skip to content

Commit

Permalink
Refactor and fix performance regression with GPU runtime checks (#1292)
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun authored Jul 3, 2023
1 parent e66392f commit 3e83820
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 28 deletions.
34 changes: 8 additions & 26 deletions dace/codegen/common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import ast
from copy import deepcopy
import ctypes.util
from dace import config, data, dtypes, sdfg as sd, symbolic
from dace.sdfg import SDFG
from dace.properties import CodeBlock
from dace.codegen import cppunparse
from dace.codegen.tools import gpu_runtime
from functools import lru_cache
from io import StringIO
import os
Expand Down Expand Up @@ -146,7 +147,11 @@ def _try_execute(cmd: str) -> bool:
'to either "cuda" or "hip".')


def get_gpu_runtime_library() -> ctypes.CDLL:
@lru_cache()
def get_gpu_runtime() -> gpu_runtime.GPURuntime:
"""
Returns the GPU runtime library (CUDA / HIP) if exists. The result is cached for performance.
"""
backend = get_gpu_backend()
if backend == 'cuda':
libpath = ctypes.util.find_library('cudart')
Expand All @@ -165,27 +170,4 @@ def get_gpu_runtime_library() -> ctypes.CDLL:
raise RuntimeError(f'GPU runtime library for {backend} not found. Please set the {envname} '
'environment variable to point to the libraries.')

return ctypes.CDLL(libpath)


def get_gpu_runtime_error_string(err: int) -> str:
lib = get_gpu_runtime_library()

# Obtain the error string
geterrorstring = getattr(lib, f'{get_gpu_backend()}GetErrorString')
geterrorstring.restype = ctypes.c_char_p
return geterrorstring(err).decode('utf-8')


def get_gpu_runtime_last_error() -> str:
lib = get_gpu_runtime_library()

getlasterror = getattr(lib, f'{get_gpu_backend()}GetLastError')
res: int = getlasterror()
if res == 0:
return None

# Obtain the error string
geterrorstring = getattr(lib, f'{get_gpu_backend()}GetErrorString')
geterrorstring.restype = ctypes.c_char_p
return geterrorstring(res).decode('utf-8')
return gpu_runtime.GPURuntime(backend, libpath)
4 changes: 2 additions & 2 deletions dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def finalize(self):
def _get_error_text(self, result: Union[str, int]) -> str:
if self.has_gpu_code:
if isinstance(result, int):
result = common.get_gpu_runtime_error_string(result)
result = common.get_gpu_runtime().get_error_string(result)
return (f'{result}. Consider enabling synchronous debugging mode (environment variable: '
'DACE_compiler_cuda_syncdebug=1) to see where the issue originates from.')
else:
Expand All @@ -345,7 +345,7 @@ def __call__(self, *args, **kwargs):
if self.has_gpu_code:
# Optionally get errors from call
try:
lasterror = common.get_gpu_runtime_last_error()
lasterror = common.get_gpu_runtime().get_last_error_string()
except RuntimeError as ex:
warnings.warn(f'Could not get last error from GPU runtime: {ex}')
lasterror = None
Expand Down
36 changes: 36 additions & 0 deletions dace/codegen/tools/gpu_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
"""
GPU runtime testing functionality. Used for checking error codes after GPU-capable SDFG execution.
"""
import ctypes
from typing import Optional


class GPURuntime:
    """
    Thin wrapper around a GPU runtime shared library (CUDA / HIP).

    The library is loaded once at construction time and the two error-query
    entry points are looked up up front, so that later error checks only
    perform the foreign function call itself.
    """

    def __init__(self, backend_name: str, path: str) -> None:
        """
        Loads the runtime library and resolves its error-query functions.

        :param backend_name: Prefix of the runtime API symbols; the functions
                             ``<backend_name>GetErrorString`` and
                             ``<backend_name>GetLastError`` are looked up.
        :param path: Path (or name) of the shared library, passed to
                     ``ctypes.CDLL``.
        """
        self.backend = backend_name
        self.library = ctypes.CDLL(path)

        # Resolve the runtime functions once so that queries do not repeat the symbol lookup
        error_string_fn = getattr(self.library, f'{self.backend}GetErrorString')
        # The function returns a C string; have ctypes hand it back as ``bytes``
        error_string_fn.restype = ctypes.c_char_p
        self._geterrorstring = error_string_fn
        self._getlasterror = getattr(self.library, f'{self.backend}GetLastError')

    def get_error_string(self, err: int) -> str:
        """
        Returns the runtime's description of an error code.

        :param err: Error code as reported by the GPU runtime.
        :return: The UTF-8 decoded error description.
        """
        raw_message = self._geterrorstring(err)
        return raw_message.decode('utf-8')

    def get_last_error(self) -> int:
        """
        Queries the runtime's ``GetLastError`` function.

        :return: The error code returned by the runtime.
        """
        return self._getlasterror()

    def get_last_error_string(self) -> Optional[str]:
        """
        Queries the last error and, if there is one, converts it to text.

        :return: None if the runtime reports error code 0, otherwise the
                 description of the reported error code.
        """
        code: int = self._getlasterror()
        # An error code of 0 is treated as "no error"
        return None if code == 0 else self.get_error_string(code)

0 comments on commit 3e83820

Please sign in to comment.