Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND] Allow tl.{u}int{width} annotations to bypass opportunistic value-based JIT-specialization #3102

Merged
merged 2 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions python/test/unit/language/test_annotations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,41 @@
from __future__ import annotations

import torch

import triton
import triton.language as tl
import pytest


def annotated_function(return_type=None, **arg_types):
"""A decorator to add annotations to a function."""

def decorator(func):
func.__annotations__ = {**arg_types, 'return': return_type}
return func

return decorator


# Test integer annotations
@pytest.mark.parametrize(("signed", "width"), [
(signed, width) for signed in [False, True]\
for width in [8, 16, 32, 64]
] + [(False, 1)]
)
def test_int_annotation(signed, width, device):

@triton.jit
@annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}")
def _kernel(X, v):
tl.store(X, v)

h = _kernel[(1, )](torch.empty(1, device=device), 3)
pfx = 'si' if signed else 'ui'
assert f'%arg1: i{width}' in h.asm["ttir"]
assert f'arith.{pfx}tofp' in h.asm["ttir"]


def test_annotations(device):
# Test that unknown annotations do not emit an error
def test_unknown_annotation(device):

@triton.jit
def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr):
Expand Down
1 change: 1 addition & 0 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,7 @@ def str_to_ty(name):
"i16": language.int16,
"i32": language.int32,
"i64": language.int64,
"u1": language.int1,
"u8": language.uint8,
"u16": language.uint16,
"u32": language.uint32,
Expand Down
15 changes: 8 additions & 7 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,13 @@ def signature_key(self):
annotation = self.param.annotation
if "Tensor" in annotation:
return self.value.dtype
elif annotation == "bool":
return "i1"
elif annotation == "float":
return "fp32"
else:
return JITFunction._key_of(self.value)
for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
width = annotation[annotation.find(ty1) + len(ty1):]
if width and ty1 in annotation:
return f"{ty2}{width}"
if annotation == "bool":
return "u1"
return JITFunction._key_of(self.value)

def specialization_key(self):
assert not self.param.do_not_specialize
Expand Down Expand Up @@ -375,7 +376,7 @@ def run(self, *args, grid, warmup, **kwargs):

# Build kernel signature -- doesn't include constexpr arguments.
signature = {
arg.param.num: self._type_of(self._key_of(arg.value))
arg.param.num: self._type_of(arg.signature_key())
for arg in args
if not arg.param.is_constexpr
}
Expand Down
24 changes: 10 additions & 14 deletions third_party/nvidia/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def ty_to_cpp(ty):
"i16": "int16_t",
"i32": "int32_t",
"i64": "int64_t",
"u1": "uint32_t",
"u8": "uint8_t",
"u16": "uint16_t",
"u32": "uint32_t",
"u64": "uint64_t",
"fp16": "float",
Expand All @@ -115,29 +118,22 @@ def make_launcher(constants, signature, ids):
def _extracted_type(ty):
if ty[0] == '*':
return "PyObject*"
return {
'i1': 'int32_t',
'i32': 'int32_t',
'i64': 'int64_t',
'u32': 'uint32_t',
'u64': 'uint64_t',
'fp16': 'float',
'bf16': 'float',
'fp32': 'float',
'f32': 'float',
'fp64': 'double',
}[ty]
return ty_to_cpp(ty)

def format_of(ty):
return {
"PyObject*": "O",
"float": "f",
"double": "d",
"long": "l",
"uint32_t": "I",
"int8_t": "b",
"int16_t": "h",
"int32_t": "i",
"int64_t": "l",
"uint8_t": "B",
"uint16_t": "H",
"uint32_t": "I",
"uint64_t": "K",
"int64_t": "L",
}[ty]

format = "iiiiiiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
Expand Down
Loading