Add new annotation based on triton.compile to replace the use of jit #106
Labels
performance
make things faster, always
Comments
pommedeterresautee changed the title from "Use the new API of Triton to reduce CPU overhead" to "Use the new API of Triton to reduce CPU overhead and speed up warmup step" on Oct 16, 2022
ayoub-louati changed the title from "Use the new API of Triton to reduce CPU overhead and speed up warmup step" to "Add new annotation based on triton.compile to replace the use of jit" on Feb 6, 2023
Did some simple experiments and got a 33% CPU overhead reduction. Better, but still an order of magnitude slower than calling the CUDA code directly.
Before / After: [benchmark results]
jit code:
from __future__ import annotations, division
import ast
import functools
import hashlib
import inspect
import os
import subprocess
import textwrap
from collections import defaultdict, namedtuple
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
import torch
import triton
from triton.utils import MockTensor
try:
    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
except ImportError:
    get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
T = TypeVar('T')
# -----------------------------------------------------------------------------
# Dependencies Finder
# -----------------------------------------------------------------------------
class DependenciesFinder(ast.NodeVisitor):
"""
This AST visitor is used to find dependencies of a JITFunction. This can
be used to invalidate a JITFunction's hash when its source code -- or
that of its dependencies -- changes.
"""
def __init__(self, globals, src) -> None:
super().__init__()
self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
self.globals = globals
def visit_Name(self, node):
return self.globals.get(node.id, None)
def visit_Attribute(self, node):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or lhs is triton:
return None
return getattr(lhs, node.attr)
def visit_Call(self, node):
func = self.visit(node.func)
if func is None:
return
if inspect.isbuiltin(func):
return
if func.__module__ and func.__module__.startswith('triton.'):
return
assert isinstance(func, JITFunction)
if func.hash is None:
tree = ast.parse(func.src)
finder = DependenciesFinder(func.__globals__, func.src)
finder.visit(tree)
func.hash = finder.ret
self.ret = (self.ret + func.hash).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest()
# -----------------------------------------------------------------------------
# JITFunction
# -----------------------------------------------------------------------------
@functools.lru_cache()
def version_key():
    import pkgutil
    contents = []
    # frontend
    with open(__file__, "rb") as f:
        contents += [hashlib.md5(f.read()).hexdigest()]
    with open(triton.compiler.__file__, "rb") as f:
        contents += [hashlib.md5(f.read()).hexdigest()]
    # backend
    with open(triton._C.libtriton.__file__, "rb") as f:
        contents += [hashlib.md5(f.read()).hexdigest()]
    # language
    language_path = os.path.join(*triton.__path__, 'language')
    for lib in pkgutil.iter_modules([language_path]):
        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
            contents += [hashlib.md5(f.read()).hexdigest()]
    # ptxas version
    try:
        ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
    except Exception:
        ptxas_version = ''
    return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
class KernelInterface(Generic[T]):
    run: T

    def __getitem__(self, grid) -> T:
        """
        A JIT function is launched with: fn[grid](*args, **kwargs).
        Hence JITFunction.__getitem__ returns a callable proxy that
        memorizes the grid.
        """
        return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
class JITFunction(KernelInterface[T]):
    # Hook for inspecting compiled functions and modules
    cache_hook = None
    divisibility = 16
    @staticmethod
    def _key_of(arg):
        if hasattr(arg, "dtype"):
            return arg.dtype
        elif isinstance(arg, bool):
            return "i1"
        elif isinstance(arg, int):
            if -2**31 <= arg and arg <= 2**31 - 1:
                return "i32"
            elif 2**31 <= arg and arg <= 2**32 - 1:
                return "u32"
            elif 2**63 <= arg and arg <= 2**64 - 1:
                return "u64"
            else:
                return "i64"
        elif isinstance(arg, float):
            return 'fp32'
        elif arg is None:
            return None
        else:
            raise TypeError(f'Unsupported type {type(arg)} for {arg}')
    @staticmethod
    def _spec_of(arg):
        if hasattr(arg, "data_ptr"):
            return (arg.data_ptr() % JITFunction.divisibility == 0)
        elif isinstance(arg, int):
            return (arg % 16 == 0, arg == 1)
        return (arg is None, )
    def _get_config(self, *args):
        def is_divisible_by_16(x):
            if hasattr(x, "data_ptr"):
                return x.data_ptr() % JITFunction.divisibility == 0
            elif isinstance(x, int):
                return x % JITFunction.divisibility == 0
            if x is None:
                return True
            return False
        divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
        equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
        return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
        # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
    @staticmethod
    def _type_of(key):
        if isinstance(key, (torch.dtype, triton.language.dtype)):
            ty = {
                torch.bool: 'i1',
                torch.float16: 'fp16',
                torch.bfloat16: 'bf16',
                torch.float32: 'fp32',
                torch.float64: 'fp64',
                torch.uint8: 'u8',
                torch.int8: 'i8',
                torch.int16: 'i16',
                torch.int32: 'i32',
                torch.int64: 'i64',
                triton.language.uint8: 'u8',
                triton.language.uint16: 'u16',
                triton.language.uint32: 'u32',
                triton.language.uint64: 'u64',
                triton.language.float8: 'fp8',
                triton.language.float16: 'fp16',
                triton.language.bfloat16: 'bf16',
                triton.language.float32: 'fp32',
            }[key]
            return f'*{ty}'
        if key is None:
            return '*i8'
        assert isinstance(key, str)
        return key
    def _make_signature(self, sig_key):
        signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
        return signature

    def _make_constants(self, constexpr_key):
        constants = dict(zip(self.constexprs, constexpr_key))
        return constants
    def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
        if JITFunction.cache_hook is None:
            return False
        name = self.fn.__name__
        module = self.fn.__module__
        arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
        repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
        key = str(key)

        class LegacyCompiler:
            def __init__(self, module, name):
                self.module = module
                self.name = name
                pass

        kwargs = dict(signature=signature, device=device, constants=constants,
                      num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
                      configs=configs)
        return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
    def _make_launcher(self):
        regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
        constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
        args = ', '.join(regular_args)
        # cache key for regular argument type
        sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args])
        # cache key for constexpr argument values
        constexpr_keys = ', '.join(constexpr_args)
        # cache key for argument specialization
        specializations = []
        for i, arg in enumerate(regular_args):
            if i in self.do_not_specialize:
                continue
            specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
                                f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
                                f'else (False,)']
        spec_keys = ', '.join(specializations)
        grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
        src = f"""
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
    # sig_key = {sig_keys},
    # constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
    # spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
    # key = (version_key, sig_key, constexpr_key, spec_key)
    # if not extern_libs is None:
    # key = (key, tuple(extern_libs.items()))
    key = self.fn.__name__
    # assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
    # if callable(grid):
    # grid = grid({{{grid_args}}})
    grid_size = len(grid)
    grid_0 = grid[0]
    grid_1 = grid[1] #if grid_size > 1 else 1
    grid_2 = grid[2] #if grid_size > 2 else 1
    # device = 0
    # torch.cuda.set_device(device)
    # if stream is None and not warmup:
    stream = get_cuda_stream(device)
    try:
        bin = cache[device][key]
        if not warmup:
            bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, None, None, None, {args})
        return bin
    # kernel not cached -- compile
    except KeyError:
        constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
        # build dict of constant values
        args = [{args}]
        all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
        configs = self._get_config(*all_args),
        constants = self._make_constants(constexpr_key)
        constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
        constants.update({{i: 1 for i in configs[0].equal_to_1}})
        # build kernel signature -- doesn't include specialized arguments
        signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
        # build stub signature -- includes arguments that are specialized
        for i, arg in constants.items():
            if callable(arg):
                raise TypeError(f"Callable constexpr at index {{i}} is not supported")
        if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
            bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
            if not warmup:
                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
            self.cache[device][key] = bin
            return bin
        return None
"""
        torch.cuda.set_device(0)
        stream = get_cuda_stream(0)
        scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
                 "self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
                 "cache": self.cache, "triton": triton, "torch": torch, "device": 0}
        exec(src, scope)
        return scope[self.fn.__name__]
    def __init__(self, fn, version=None, do_not_specialize=None):
        self.fn = fn
        self.module = fn.__module__
        self.version = version
        # function signature information
        signature = inspect.signature(fn)
        self.arg_names = [v.name for v in signature.parameters.values()]
        self.has_defaults = any(v.default != inspect._empty for v in signature.parameters.values())
        # specialization hints
        self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
        self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
        # function source code (without decorators)
        self.src = textwrap.dedent(inspect.getsource(fn))
        self.src = self.src[self.src.find("def"):]
        # cache of just-in-time compiled kernels
        self.cache = defaultdict(dict)
        self.hash = None
        # JITFunction can be instantiated as kernel
        # when called with a grid using __getitem__
        self.kernel_decorators = []
        self.kernel = None
        # annotations
        self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
        self.__annotations__ = fn.__annotations__
        # index of constexprs
        self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
        # launcher
        self.run = self._make_launcher()
        # re-use docs of wrapped function
        self.__doc__ = fn.__doc__
        self.__name__ = fn.__name__
        self.__globals__ = fn.__globals__
        self.__module__ = fn.__module__
    @property
    def cache_key(self):
        # TODO : hash should be attribute of `self`
        if self.hash is None:
            dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
            dependencies_finder.visit(self.parse())
            self.hash = dependencies_finder.ret + version_key()
        return self.hash

    def warmup(self, *args, **kwargs):
        return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)

    # we do not parse `src` in the constructor because
    # the user might want to monkey-patch self.src dynamically.
    # Our unit tests do this, for example.
    def parse(self):
        tree = ast.parse(self.src)
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1
        assert isinstance(tree.body[0], ast.FunctionDef)
        return tree

    def __call__(self, *args, **kwargs):
        raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")

    def __setattr__(self, name, value):
        # - when kernel decorators change, cached kernel
        #   needs to be cleared
        if name == 'kernel_decorators':
            self.kernel = None
        super(JITFunction, self).__setattr__(name, value)
        # - when `.src` attribute is set, cache path needs
        #   to be reinitialized
        if name == 'src':
            self.hash = None

    def __repr__(self):
        return f"JITFunction({self.module}:{self.fn.__name__})"
# -----------------------------------------------------------------------------
# `jit` decorator
# -----------------------------------------------------------------------------
@overload
def jit(fn: T) -> JITFunction[T]:
    ...


@overload
def jit(
    *,
    version=None,
    do_not_specialize: Optional[Iterable[int]] = None,
) -> Callable[[T], JITFunction[T]]:
    ...
def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    do_not_specialize: Optional[Iterable[int]] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    :note: When a jit'd function is called, :code:`torch.tensor` arguments are
        implicitly converted to pointers using the :code:`.data_ptr()` method.

    :note: This function will be compiled and run on the GPU. It will only have access to:

        * python primitives,
        * builtins within the triton package,
        * arguments to this function,
        * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """

    def decorator(fn: T) -> JITFunction[T]:
        assert callable(fn)
        return JITFunction(
            fn,
            version=version,
            do_not_specialize=do_not_specialize,
        )

    if fn is not None:
        return decorator(fn)
    else:
        return decorator
class TensorWrapper:
    def __init__(self, base, dtype):
        self.dtype = dtype
        self.base = base
        self.is_cuda = base.is_cuda
        self.device = base.device

    def data_ptr(self):
        return self.base.data_ptr()

    def __str__(self) -> str:
        return f'TensorWrapper[{self.dtype}]({self.base})'
def reinterpret(tensor, dtype):
    if isinstance(tensor, TensorWrapper):
        if dtype == tensor.base.dtype:
            # Reinterpreting to the original interpretation; return the base.
            return tensor.base
        else:
            # Reinterpreting a wrapped tensor to a different type.
            return TensorWrapper(tensor.base, dtype)
    elif isinstance(tensor, torch.Tensor):
        # A new wrapper is needed around an unwrapped tensor.
        return TensorWrapper(tensor, dtype)
    else:
        raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
test:
# benchmark comparing the classic jit launch path with a direct call to the generated launcher
# (uses the `benchmark` fixture, presumably from pytest-benchmark, and the kernel123 kernel defined further down)
import pytest
import torch

triton_implementaitons = {
"classic": lambda args: kernel123[(1, 1, 1)](*args),
"direct": lambda args: kernel123.run(*args, grid=(1, 1, 1))
}
@pytest.mark.parametrize("implementation", ["classic", "direct"])
def test_triton(
    benchmark,
    implementation: str,
):
    X = torch.ones((10,), dtype=torch.float32, device="cuda")
    Z = torch.zeros((10,), dtype=torch.float32, device="cuda")
    fn = triton_implementaitons[implementation]
    benchmark(fn, [X, 1, Z, 1, 16, 16])
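For context on the two variants: "classic" launches through KernelInterface.__getitem__, which wraps self.run in a new functools.partial on every call, while "direct" calls the generated launcher (self.run) straight away with the grid passed as a keyword argument. A hedged sketch of how the comparison might be invoked, assuming the benchmark fixture comes from pytest-benchmark and using a hypothetical file name:

# Hypothetical invocation of the benchmark above (the file name is an assumption).
# --benchmark-only is a pytest-benchmark flag that skips non-benchmark tests.
import pytest

pytest.main(["test_triton_cpu_overhead.py", "--benchmark-only"])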
@pommedeterresautee can you please add your kernel code?
Sorry, I failed the copy-paste:
import time
import torch
import triton
import triton.language as tl
from kernl.autotune import autotune
# triton kernel
@triton.jit
def kernel123(X, stride_xm,
              Z, stride_zn,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    off_m = tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, BLOCK_N)
    Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
    Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
    tl.store(Zs, tl.load(Xs))
# ...
Since triton-lang/triton#644, a new API is available which is supposed to provide significantly lower overhead.
Replace the triton.jit() annotation with a simplified call to triton.compile().
This has already been done in the inductor project in this PR: pytorch/torchdynamo#1338
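For illustration, a hedged sketch of what the proposed triton.compile-based path could look like for the kernel123 example from this thread. The keyword arguments mirror the triton.compile call used in the patched launcher above; the exact keyword set, defaults, and the c_wrapper launch signature depend on the Triton version, so treat this as a sketch rather than the final API.

# Sketch only: compile kernel123 once with triton.compile, then launch the cached
# binary directly through its c_wrapper, bypassing the @triton.jit Python launcher.
# Indices in `signature`/`constants` follow kernel123(X, stride_xm, Z, stride_zn, BLOCK_M, BLOCK_N).
import torch
import triton

X = torch.ones((10,), dtype=torch.float32, device="cuda")
Z = torch.zeros((10,), dtype=torch.float32, device="cuda")

binary = triton.compile(
    kernel123,  # the @triton.jit kernel defined earlier in the thread
    signature={0: "*fp32", 1: "i32", 2: "*fp32", 3: "i32"},  # X, stride_xm, Z, stride_zn
    constants={4: 16, 5: 16},  # BLOCK_M, BLOCK_N (tl.constexpr)
    device=0,
    num_warps=4,
    num_stages=3,
)

# Launch: mirrors bin.c_wrapper(...) in the patched launcher above
# (grid, num_warps, shared memory, raw CUDA stream, CUfunction handle,
# launch hooks / kernel handle (None here), then the kernel arguments).
stream = torch.cuda.current_stream().cuda_stream
binary.c_wrapper(1, 1, 1,
                 binary.num_warps, binary.shared, stream, binary.cu_function,
                 None, None, None,
                 X, 1, Z, 1)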