Skip to content

Commit

Permalink
Allow CUDA source inputs compiled to LTOIR, and enable pynvjitlinker …
Browse files Browse the repository at this point in the history
…to link inputs that contains LTOIR (#62)

Adds functionality supporting kernel and FFI functions being JIT-compiled to LTOIR and link with LTO, allowing better optimization when foreign functions are used in Numba-cuda.

---------

Co-authored-by: Graham Markall <gmarkall@nvidia.com>
  • Loading branch information
isVoid and gmarkall authored Dec 6, 2024
1 parent 21bb1da commit 779782d
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 39 deletions.
2 changes: 1 addition & 1 deletion ci/test_conda_pynvjitlink.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ set -euo pipefail
if [ "${CUDA_VER%.*.*}" = "11" ]; then
CTK_PACKAGES="cudatoolkit"
else
CTK_PACKAGES="cuda-nvcc-impl cuda-nvrtc"
CTK_PACKAGES="cuda-nvcc-impl cuda-nvrtc cuda-cuobjdump"
fi

rapids-logger "Install testing dependencies"
Expand Down
50 changes: 36 additions & 14 deletions numba_cuda/numba/cuda/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import subprocess
import tempfile


CUDA_TRIPLE = 'nvptx64-nvidia-cuda'


Expand Down Expand Up @@ -181,17 +180,7 @@ def get_ltoir(self, cc=None):

return ltoir

def get_cubin(self, cc=None):
cc = self._ensure_cc(cc)

cubin = self._cubin_cache.get(cc, None)
if cubin:
return cubin

linker = driver.Linker.new(
max_registers=self._max_registers, cc=cc, lto=self._lto
)

def _link_all(self, linker, cc, ignore_nonlto=False):
if linker.lto:
ltoir = self.get_ltoir(cc=cc)
linker.add_ltoir(ltoir)
Expand All @@ -200,11 +189,44 @@ def get_cubin(self, cc=None):
linker.add_ptx(ptx.encode())

for path in self._linking_files:
linker.add_file_guess_ext(path)
linker.add_file_guess_ext(path, ignore_nonlto)
if self.needs_cudadevrt:
linker.add_file_guess_ext(get_cudalib('cudadevrt', static=True))
linker.add_file_guess_ext(
get_cudalib('cudadevrt', static=True), ignore_nonlto
)

def get_cubin(self, cc=None):
cc = self._ensure_cc(cc)

cubin = self._cubin_cache.get(cc, None)
if cubin:
return cubin

if self._lto and config.DUMP_ASSEMBLY:
linker = driver.Linker.new(
max_registers=self._max_registers,
cc=cc,
additional_flags=["-ptx"],
lto=self._lto
)
# `-ptx` flag is meant to view the optimized PTX for LTO objects.
# Non-LTO objects are not passed to linker.
self._link_all(linker, cc, ignore_nonlto=True)

ptx = linker.get_linked_ptx().decode('utf-8')

print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, '-'))
print(ptx)
print('=' * 80)

linker = driver.Linker.new(
max_registers=self._max_registers,
cc=cc,
lto=self._lto
)
self._link_all(linker, cc, ignore_nonlto=False)
cubin = linker.complete()

self._cubin_cache[cc] = cubin
self._linkerinfo_cache[cc] = linker.info_log

Expand Down
105 changes: 103 additions & 2 deletions numba_cuda/numba/cuda/cudadrv/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import traceback
import asyncio
import pathlib
import subprocess
import tempfile
import re
from itertools import product
from abc import ABCMeta, abstractmethod
from ctypes import (c_int, byref, c_size_t, c_char, c_char_p, addressof,
Expand All @@ -36,7 +39,7 @@
from .drvapi import API_PROTOTYPES
from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
from .mappings import FILE_EXTENSION_MAP
from .linkable_code import LinkableCode
from .linkable_code import LinkableCode, LTOIR, Fatbin, Object
from numba.cuda.cudadrv import enums, drvapi, nvrtc

USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
Expand Down Expand Up @@ -2683,12 +2686,18 @@ def add_cu_file(self, path):
cu = f.read()
self.add_cu(cu, os.path.basename(path))

def add_file_guess_ext(self, path_or_code):
def add_file_guess_ext(self, path_or_code, ignore_nonlto=False):
"""
Add a file or LinkableCode object to the link. If a file is
passed, the type will be inferred from the extension. A LinkableCode
object represents a file already in memory.
When `ignore_nonlto` is set to true, do not add code that will not
be LTO-ed in the linking process. This is useful in inspecting the
LTO-ed portion of the PTX when linker is added with objects that can be
both LTO-ed and not LTO-ed.
"""

if isinstance(path_or_code, str):
ext = pathlib.Path(path_or_code).suffix
if ext == '':
Expand All @@ -2704,6 +2713,26 @@ def add_file_guess_ext(self, path_or_code):
"Don't know how to link file with extension "
f"{ext}"
)

if ignore_nonlto:
warn_and_return = False
if kind in (
FILE_EXTENSION_MAP["fatbin"], FILE_EXTENSION_MAP["o"]
):
entry_types = inspect_obj_content(path_or_code)
if "nvvm" not in entry_types:
warn_and_return = True
elif kind != FILE_EXTENSION_MAP["ltoir"]:
warn_and_return = True

if warn_and_return:
warnings.warn(
f"Not adding {path_or_code} as it is not "
"optimizable at link time, and `ignore_nonlto == "
"True`."
)
return

self.add_file(path_or_code, kind)
return
else:
Expand All @@ -2716,6 +2745,25 @@ def add_file_guess_ext(self, path_or_code):
if path_or_code.kind == "cu":
self.add_cu(path_or_code.data, path_or_code.name)
else:
if ignore_nonlto:
warn_and_return = False
if isinstance(path_or_code, (Fatbin, Object)):
with tempfile.NamedTemporaryFile("w") as fp:
fp.write(path_or_code.data)
entry_types = inspect_obj_content(fp.name)
if "nvvm" not in entry_types:
warn_and_return = True
elif not isinstance(path_or_code, LTOIR):
warn_and_return = True

if warn_and_return:
warnings.warn(
f"Not adding {path_or_code.name} as it is not "
"optimizable at link time, and `ignore_nonlto == "
"True`."
)
return

self.add_data(
path_or_code.data, path_or_code.kind, path_or_code.name
)
Expand Down Expand Up @@ -3065,6 +3113,28 @@ def add_file(self, path, kind):
name = pathlib.Path(path).name
self.add_data(data, kind, name)

def add_cu(self, cu, name):
"""Add CUDA source in a string to the link. The name of the source
file should be specified in `name`."""
with driver.get_active_context() as ac:
dev = driver.get_device(ac.devnum)
cc = dev.compute_capability

program, log = nvrtc.compile(cu, name, cc, ltoir=self.lto)

if not self.lto and config.DUMP_ASSEMBLY:
print(("ASSEMBLY %s" % name).center(80, "-"))
print(program)
print("=" * 80)

suffix = ".ltoir" if self.lto else ".ptx"
program_name = os.path.splitext(name)[0] + suffix
# Link the program's PTX or LTOIR using the normal linker mechanism
if self.lto:
self.add_ltoir(program, program_name)
else:
self.add_ptx(program.encode(), program_name)

def add_data(self, data, kind, name):
if kind == FILE_EXTENSION_MAP["cubin"]:
fn = self._linker.add_cubin
Expand All @@ -3086,6 +3156,12 @@ def add_data(self, data, kind, name):
except NvJitLinkError as e:
raise LinkerError from e

def get_linked_ptx(self):
try:
return self._linker.get_linked_ptx()
except NvJitLinkError as e:
raise LinkerError from e

def complete(self):
try:
return self._linker.get_linked_cubin()
Expand Down Expand Up @@ -3361,3 +3437,28 @@ def get_version():
Return the driver version as a tuple of (major, minor)
"""
return driver.get_version()


def inspect_obj_content(objpath: str):
"""
Given path to a fatbin or object, use `cuobjdump` to examine its content
Return the set of entries in the object.
"""
code_types :set[str] = set()

try:
out = subprocess.run(["cuobjdump", objpath], check=True,
capture_output=True)
except FileNotFoundError as e:
msg = ("cuobjdump has not been found. You may need "
"to install the CUDA toolkit and ensure that "
"it is available on your PATH.\n")
raise RuntimeError(msg) from e

objtable = out.stdout.decode('utf-8')
entry_pattern = r"Fatbin (.*) code"
for line in objtable.split("\n"):
if match := re.match(entry_pattern, line):
code_types.add(match.group(1))

return code_types
41 changes: 37 additions & 4 deletions numba_cuda/numba/cuda/cudadrv/nvrtc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ class NVRTC:
NVVM interface. Initialization is protected by a lock and uses the standard
(for Numba) open_cudalib function to load the NVRTC library.
"""

_CU12ONLY_PROTOTYPES = {
# nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet);
"nvrtcGetLTOIRSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
# nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto);
"nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p)
}

_PROTOTYPES = {
# nvrtcResult nvrtcVersion(int *major, int *minor)
'nvrtcVersion': (nvrtc_result, POINTER(c_int), POINTER(c_int)),
Expand Down Expand Up @@ -110,6 +118,10 @@ def __new__(cls):
cls.__INSTANCE = None
raise NvrtcSupportError("NVRTC cannot be loaded") from e

from numba.cuda.cudadrv.runtime import get_version
if get_version() >= (12, 0):
inst._PROTOTYPES |= inst._CU12ONLY_PROTOTYPES

# Find & populate functions
for name, proto in inst._PROTOTYPES.items():
func = getattr(lib, name)
Expand Down Expand Up @@ -208,17 +220,31 @@ def get_ptx(self, program):

return ptx.value.decode()

def get_lto(self, program):
"""
Get the compiled LTOIR as a Python bytes object.
"""
lto_size = c_size_t()
self.nvrtcGetLTOIRSize(program.handle, byref(lto_size))

lto = b" " * lto_size.value
self.nvrtcGetLTOIR(program.handle, lto)

return lto

def compile(src, name, cc):

def compile(src, name, cc, ltoir=False):
"""
Compile a CUDA C/C++ source to PTX for a given compute capability.
Compile a CUDA C/C++ source to PTX or LTOIR for a given compute capability.
:param src: The source code to compile
:type src: str
:param name: The filename of the source (for information only)
:type name: str
:param cc: A tuple ``(major, minor)`` of the compute capability
:type cc: tuple
:param ltoir: Compile into LTOIR if True, otherwise into PTX
:type ltoir: bool
:return: The compiled PTX and compilation log
:rtype: tuple
"""
Expand All @@ -242,6 +268,9 @@ def compile(src, name, cc):
numba_include = f'-I{numba_cuda_path}'
options = [arch, *cuda_include, numba_include, '-rdc', 'true']

if ltoir:
options.append("-dlto")

if nvrtc.get_version() < (12, 0):
options += ["-std=c++17"]

Expand All @@ -261,5 +290,9 @@ def compile(src, name, cc):
msg = (f"NVRTC log messages whilst compiling {name}:\n\n{log}")
warnings.warn(msg)

ptx = nvrtc.get_ptx(program)
return ptx, log
if ltoir:
ltoir = nvrtc.get_lto(program)
return ltoir, log
else:
ptx = nvrtc.get_ptx(program)
return ptx, log
Loading

0 comments on commit 779782d

Please sign in to comment.