Skip to content

Commit

Permalink
[FRONTEND] Allow tl.{u}int{width} annotations to bypass opportunist…
Browse files Browse the repository at this point in the history
…ic value-based JIT-specialization (#3102)

Triton's JIT currently specialize all integer arguments to `int32` if
they "fit", and `int64` otherwise. This PR gives users the possibility
to opt-out from this kind of specialization to avoid compiling too many
kernels.
  • Loading branch information
ptillet authored Feb 10, 2024
1 parent 00c144e commit b6e24b6
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 24 deletions.
35 changes: 32 additions & 3 deletions python/test/unit/language/test_annotations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,41 @@
from __future__ import annotations

import torch

import triton
import triton.language as tl
import pytest


def annotated_function(return_type=None, **arg_types):
"""A decorator to add annotations to a function."""

def decorator(func):
func.__annotations__ = {**arg_types, 'return': return_type}
return func

return decorator


# Test integer annotations
@pytest.mark.parametrize(("signed", "width"), [
(signed, width) for signed in [False, True]\
for width in [8, 16, 32, 64]
] + [(False, 1)]
)
def test_int_annotation(signed, width, device):

@triton.jit
@annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}")
def _kernel(X, v):
tl.store(X, v)

h = _kernel[(1, )](torch.empty(1, device=device), 3)
pfx = 'si' if signed else 'ui'
assert f'%arg1: i{width}' in h.asm["ttir"]
assert f'arith.{pfx}tofp' in h.asm["ttir"]


def test_annotations(device):
# Test that unknown annotations do not emit an error
def test_unknown_annotation(device):

@triton.jit
def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr):
Expand Down
1 change: 1 addition & 0 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,7 @@ def str_to_ty(name):
"i16": language.int16,
"i32": language.int32,
"i64": language.int64,
"u1": language.int1,
"u8": language.uint8,
"u16": language.uint16,
"u32": language.uint32,
Expand Down
15 changes: 8 additions & 7 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,13 @@ def signature_key(self):
annotation = self.param.annotation
if "Tensor" in annotation:
return self.value.dtype
elif annotation == "bool":
return "i1"
elif annotation == "float":
return "fp32"
else:
return JITFunction._key_of(self.value)
for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
width = annotation[annotation.find(ty1) + len(ty1):]
if width and ty1 in annotation:
return f"{ty2}{width}"
if annotation == "bool":
return "u1"
return JITFunction._key_of(self.value)

def specialization_key(self):
assert not self.param.do_not_specialize
Expand Down Expand Up @@ -375,7 +376,7 @@ def run(self, *args, grid, warmup, **kwargs):

# Build kernel signature -- doesn't include constexpr arguments.
signature = {
arg.param.num: self._type_of(self._key_of(arg.value))
arg.param.num: self._type_of(arg.signature_key())
for arg in args
if not arg.param.is_constexpr
}
Expand Down
24 changes: 10 additions & 14 deletions third_party/nvidia/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def ty_to_cpp(ty):
"i16": "int16_t",
"i32": "int32_t",
"i64": "int64_t",
"u1": "uint32_t",
"u8": "uint8_t",
"u16": "uint16_t",
"u32": "uint32_t",
"u64": "uint64_t",
"fp16": "float",
Expand All @@ -115,29 +118,22 @@ def make_launcher(constants, signature, ids):
def _extracted_type(ty):
if ty[0] == '*':
return "PyObject*"
return {
'i1': 'int32_t',
'i32': 'int32_t',
'i64': 'int64_t',
'u32': 'uint32_t',
'u64': 'uint64_t',
'fp16': 'float',
'bf16': 'float',
'fp32': 'float',
'f32': 'float',
'fp64': 'double',
}[ty]
return ty_to_cpp(ty)

def format_of(ty):
return {
"PyObject*": "O",
"float": "f",
"double": "d",
"long": "l",
"uint32_t": "I",
"int8_t": "b",
"int16_t": "h",
"int32_t": "i",
"int64_t": "l",
"uint8_t": "B",
"uint16_t": "H",
"uint32_t": "I",
"uint64_t": "K",
"int64_t": "L",
}[ty]

format = "iiiiiiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
Expand Down

0 comments on commit b6e24b6

Please sign in to comment.