[FRONTEND] refactor compiler submodule #2701

Merged
67 commits merged from phil/compile-refactor into main on Nov 30, 2023

Commits (67)
adfdb35
some cleaning
ptillet Nov 13, 2023
3d3de11
.
ptillet Nov 13, 2023
2d2579f
.
ptillet Nov 13, 2023
769c08a
.
ptillet Nov 13, 2023
a00020d
cleanup target somewhat
ptillet Nov 13, 2023
0b98e0c
more cleaning
ptillet Nov 13, 2023
e6a9f8a
simplify hash
ptillet Nov 13, 2023
2a41431
semantic analysis no longer get target
ptillet Nov 14, 2023
a414f6d
more cleaning
ptillet Nov 23, 2023
b40ae97
.
ptillet Nov 23, 2023
3d7b773
optimize_ttir no longer depends on target
ptillet Nov 23, 2023
7274e2a
rewrite_tensor_pointer no longer depends on capability
ptillet Nov 23, 2023
4b44f8c
cleaning
ptillet Nov 23, 2023
5f729eb
cleaning
ptillet Nov 23, 2023
3692c4a
cleaning
ptillet Nov 24, 2023
756512e
remove parser from stage
ptillet Nov 24, 2023
44ba22d
cleaning
ptillet Nov 24, 2023
9f62ce6
more cleaning
ptillet Nov 24, 2023
6ee5cee
more cleaning
ptillet Nov 24, 2023
c0c1076
.
ptillet Nov 24, 2023
78f9670
removed more dead code
ptillet Nov 25, 2023
5c61a54
removed override-related code. Will re-add support for `ttgir` in `tr…
ptillet Nov 25, 2023
7a1f167
more cleaning
ptillet Nov 25, 2023
478dce1
.
ptillet Nov 25, 2023
fb731e7
Merge remote-tracking branch 'origin/main' into phil/compile-refactor
ptillet Nov 25, 2023
1f64f5b
partial support for ttgir input (not finished)
ptillet Nov 25, 2023
eb61e52
fixed bugs
ptillet Nov 25, 2023
c19a8f3
.
ptillet Nov 25, 2023
c87a324
some cleaning
ptillet Nov 26, 2023
d9f8e0c
more cleaning
ptillet Nov 26, 2023
2776a83
more cleaning
ptillet Nov 26, 2023
33f0480
cleaning
ptillet Nov 26, 2023
aeeae56
more cleaning
ptillet Nov 26, 2023
9eda8ac
temporary disable TMA tests to see if the rest works
ptillet Nov 26, 2023
bb849fb
.
ptillet Nov 26, 2023
9379678
.
ptillet Nov 26, 2023
3812aed
comment out interpreter tests
ptillet Nov 26, 2023
4686fcc
.
ptillet Nov 26, 2023
f1d2820
.
ptillet Nov 26, 2023
693cfe2
.
ptillet Nov 27, 2023
6a7dc51
.
ptillet Nov 27, 2023
1e00e4a
.
ptillet Nov 27, 2023
0c69111
.
ptillet Nov 28, 2023
7ee8e06
Merge branch 'main' into phil/compile-refactor
ptillet Nov 28, 2023
e5a7470
.
ptillet Nov 28, 2023
2331078
fixup
ptillet Nov 28, 2023
d00a7b3
.
ptillet Nov 28, 2023
0b983bd
.
ptillet Nov 28, 2023
2e9e902
.
ptillet Nov 28, 2023
2563c0d
.
ptillet Nov 28, 2023
4d53e08
.
ptillet Nov 29, 2023
ca5781c
.
ptillet Nov 29, 2023
84c5598
device_type -> target
ptillet Nov 29, 2023
2651cf8
.
ptillet Nov 29, 2023
f10d336
.
ptillet Nov 29, 2023
730cecd
fixup
ptillet Nov 29, 2023
5f84af4
.
ptillet Nov 29, 2023
7f4ff23
cleaning
ptillet Nov 29, 2023
15e8e42
more cleaning
ptillet Nov 29, 2023
f0dfa23
cleaning
ptillet Nov 29, 2023
2fd04d0
fix linker option bug
ptillet Nov 29, 2023
6c8b3f0
.
ptillet Nov 30, 2023
0a4f518
fixup
ptillet Nov 30, 2023
10f3c62
fixup
ptillet Nov 30, 2023
bac1284
.
ptillet Nov 30, 2023
ceb90dd
Merge branch 'main' into phil/compile-refactor
ptillet Nov 30, 2023
c525439
fixup
ptillet Nov 30, 2023
3 changes: 1 addition & 2 deletions .github/workflows/integration-tests.yml
@@ -34,7 +34,6 @@ jobs:
echo '::set-output name=matrix-optional::["ubuntu-latest"]'
fi


Integration-Tests:
needs: Runner-Preparation

@@ -49,7 +48,7 @@ jobs:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: 'true'
submodules: "true"
- name: Set CUDA ENV
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
run: |
1 change: 1 addition & 0 deletions python/setup.py
@@ -353,6 +353,7 @@ def build_extension(self, ext):
"triton/_C",
"triton/common",
"triton/compiler",
"triton/compiler/backends",
"triton/language",
"triton/language/extra",
"triton/ops",
12 changes: 6 additions & 6 deletions python/src/triton.cc
@@ -1685,9 +1685,9 @@ void init_triton_ir(py::module &&m) {
self.addPass(mlir::triton::createReorderBroadcastPass());
})
.def("add_rewrite_tensor_pointer_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(mlir::triton::createRewriteTensorPointerPass(
computeCapability));
[](mlir::PassManager &self, int capability) {
self.addPass(
mlir::triton::createRewriteTensorPointerPass(capability));
})
.def("add_tritongpu_ws_feasibility_checking_pass",
[](mlir::PassManager &self, int computeCapability) {
@@ -1761,9 +1761,9 @@
self.addPass(mlir::createTritonGPUReorderInstructionsPass());
})
.def("add_tritongpu_rewrite_tensor_pointer_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(mlir::createTritonGPURewriteTensorPointerPass(
computeCapability));
[](mlir::PassManager &self, int capability) {
self.addPass(
mlir::createTritonGPURewriteTensorPointerPass(capability));
})
.def("add_tritongpu_decompose_conversions_pass",
[](mlir::PassManager &self) {
@@ -900,11 +900,12 @@ def process_epilogue(d, bias, w, epilogue):
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
if NUM_CTAS > 1:
device = get_current_device()
null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
src = triton.compiler.ASTSource(fn=empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
null_kernel = triton.compile(src)
null_kernel._init_handles()
max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"]
num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS,
1, 1)
num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.function, max_shared_mem, NUM_CTAS, 1,
1)
num_SMs = num_clusters

def grid(META):
38 changes: 13 additions & 25 deletions python/test/unit/runtime/test_subproc.py
@@ -1,12 +1,12 @@
import multiprocessing
import os
import shutil
from collections import namedtuple

import torch

import triton
import triton.language as tl
from triton.compiler import ASTSource

tmpdir = ".tmp"

@@ -17,32 +17,26 @@ def reset_tmp_dir():
shutil.rmtree(tmpdir, ignore_errors=True)


instance_descriptor = namedtuple("instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])


def compile_fn(config, cc):
def compile_fn(attrs, capability):

@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)

triton.compile(
src = ASTSource(
fn=kernel_sub,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
device=0,
constants={3: 32},
configs=[config],
warm_cache_only=True,
cc=cc,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
attrs=attrs,
)
triton.compile(src=src, target=("cuda", capability))


def test_compile_in_subproc() -> None:
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(4)), (), (), ())
config = triton.compiler.AttrsDescriptor(tuple(range(4)), (), (), ())

multiprocessing.set_start_method('fork')
proc = multiprocessing.Process(target=compile_fn, args=(config, cc))
@@ -51,7 +45,7 @@ def test_compile_in_subproc() -> None:
assert proc.exitcode == 0


def compile_fn_dot(config, cc):
def compile_fn_dot(attrs, capability):

@triton.jit
def kernel_dot(Z):
@@ -60,24 +54,18 @@ def kernel_dot(Z):
z = tl.dot(z, z)
tl.store(Z + offs, z)

triton.compile(
fn=kernel_dot,
signature={0: "*fp32"},
device=0,
configs=[config],
warm_cache_only=True,
cc=cc,
)
src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict())
triton.compile(src=src, target=("cuda", capability))


def test_compile_in_forked_subproc() -> None:
reset_tmp_dir()
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(1)), (), (), ())
capability = major * 10 + minor
config = triton.compiler.AttrsDescriptor(tuple(range(1)), (), (), ())

assert multiprocessing.get_start_method() == 'fork'
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, cc))
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability))
proc.start()
proc.join()
assert proc.exitcode == 0
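
For reference, the two tests above exercise the refactored entry point end to end: the kernel source and its specialization are packaged into an ASTSource, and the compilation target is passed to triton.compile as an explicit (backend, capability) tuple instead of the old cc= keyword. A minimal standalone sketch of that flow follows; the kernel body, constant value, and attribute tuple are illustrative placeholders rather than part of this diff.

import torch

import triton
import triton.language as tl
from triton.compiler import ASTSource, AttrsDescriptor


@triton.jit
def add_one(x_ptr, N: tl.constexpr):
    # Toy kernel, used only to illustrate the new compile flow.
    idx = tl.arange(0, N)
    tl.store(x_ptr + idx, tl.load(x_ptr + idx) + 1)


major, minor = torch.cuda.get_device_capability(0)
capability = major * 10 + minor

# Describe the kernel and its specialization up front: positional signature,
# constexpr constants, and the attributes descriptor used by the tests above.
src = ASTSource(
    fn=add_one,
    signature={0: "*fp32"},
    constants={1: 32},
    attrs=AttrsDescriptor(tuple(range(1)), (), (), ()),
)

# Compile for an explicit (backend, capability) target rather than passing cc=.
kernel = triton.compile(src, target=("cuda", capability))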
2 changes: 1 addition & 1 deletion python/test/unit/tools/test_aot.py
@@ -446,7 +446,7 @@ def test_ttgir_to_ptx():
kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir")
with open(kernel_path, "w") as fp:
fp.write(src)
k = triton.compile(kernel_path, cc=80)
k = triton.compile(kernel_path, target=("cuda", 80))
ptx = k.asm["ptx"]
assert ".target sm_80" in ptx
assert ".address_size 64" in ptx
1 change: 1 addition & 0 deletions python/triton/__init__.py
@@ -21,6 +21,7 @@

from . import language
from . import testing
from . import tools

__all__ = [
"autotune",
5 changes: 2 additions & 3 deletions python/triton/compiler/__init__.py
@@ -1,8 +1,7 @@
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps,
instance_descriptor)
from .compiler import (CompiledKernel, ASTSource, compile, AttrsDescriptor)
from .errors import CompilationError

__all__ = [
"compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
"compile", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
"get_arch_default_num_stages"
]
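
For downstream code, the visible effect of this change is the new public surface: ASTSource and AttrsDescriptor are importable from triton.compiler, with AttrsDescriptor replacing the instance_descriptor export that the subprocess test previously re-declared as a namedtuple. A short import sketch under that assumption (the field meanings are inferred from the removed namedtuple definition, not stated in this diff):

from triton.compiler import ASTSource, AttrsDescriptor, CompiledKernel, compile

# Counterpart of the old instance_descriptor(tuple(range(4)), (), (), ()) used in
# test_subproc.py; the four positional fields keep the same roles
# (divisible_by_16, equal_to_1, ids_of_folded_args, divisible_by_8).
attrs = AttrsDescriptor(tuple(range(4)), (), (), ())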