diff --git a/numba_cuda/numba/cuda/cudadrv/nvrtc.py b/numba_cuda/numba/cuda/cudadrv/nvrtc.py index 0178345..9a1ae74 100644 --- a/numba_cuda/numba/cuda/cudadrv/nvrtc.py +++ b/numba_cuda/numba/cuda/cudadrv/nvrtc.py @@ -3,6 +3,7 @@ from numba.core import config from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError, NvrtcSupportError) +from cuda.cuda.cudadrv.driver import get_version import functools import os @@ -62,6 +63,14 @@ class NVRTC: NVVM interface. Initialization is protected by a lock and uses the standard (for Numba) open_cudalib function to load the NVRTC library. """ + + _CU12ONLY_PROTOTYPES = { + # nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet); + "nvrtcGetLTOIRSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)), + # nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto); + "nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p) + } + _PROTOTYPES = { # nvrtcResult nvrtcVersion(int *major, int *minor) 'nvrtcVersion': (nvrtc_result, POINTER(c_int), POINTER(c_int)), @@ -84,10 +93,6 @@ class NVRTC: 'nvrtcGetPTXSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)), # nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx); 'nvrtcGetPTX': (nvrtc_result, nvrtc_program, c_char_p), - # nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet); - "nvrtcGetLTOIRSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)), - # nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto); - "nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p), # nvrtcResult nvrtcGetCUBINSize(nvrtcProgram prog, # size_t *cubinSizeRet); 'nvrtcGetCUBINSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)), @@ -101,6 +106,9 @@ class NVRTC: 'nvrtcGetProgramLog': (nvrtc_result, nvrtc_program, c_char_p), } + if get_version() >= (12, 0): + _PROTOTYPES |= _CU12ONLY_PROTOTYPES + # Singleton reference __INSTANCE = None