[FRONTEND] refactor compiler submodule (triton-lang#2701)
ptillet authored Nov 30, 2023
1 parent 1bce042 commit b5a9a63
Showing 19 changed files with 526 additions and 760 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/integration-tests.yml
@@ -34,7 +34,6 @@ jobs:
echo '::set-output name=matrix-optional::["ubuntu-latest"]'
fi
Integration-Tests:
needs: Runner-Preparation

@@ -49,7 +48,7 @@ jobs:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: 'true'
submodules: "true"
- name: Set CUDA ENV
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
run: |
1 change: 1 addition & 0 deletions python/setup.py
@@ -353,6 +353,7 @@ def build_extension(self, ext):
"triton/_C",
"triton/common",
"triton/compiler",
"triton/compiler/backends",
"triton/language",
"triton/language/extra",
"triton/ops",
12 changes: 6 additions & 6 deletions python/src/triton.cc
@@ -1685,9 +1685,9 @@ void init_triton_ir(py::module &&m) {
self.addPass(mlir::triton::createReorderBroadcastPass());
})
.def("add_rewrite_tensor_pointer_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(mlir::triton::createRewriteTensorPointerPass(
computeCapability));
[](mlir::PassManager &self, int capability) {
self.addPass(
mlir::triton::createRewriteTensorPointerPass(capability));
})
.def("add_tritongpu_ws_feasibility_checking_pass",
[](mlir::PassManager &self, int computeCapability) {
@@ -1761,9 +1761,9 @@ void init_triton_ir(py::module &&m) {
self.addPass(mlir::createTritonGPUReorderInstructionsPass());
})
.def("add_tritongpu_rewrite_tensor_pointer_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(mlir::createTritonGPURewriteTensorPointerPass(
computeCapability));
[](mlir::PassManager &self, int capability) {
self.addPass(
mlir::createTritonGPURewriteTensorPointerPass(capability));
})
.def("add_tritongpu_decompose_conversions_pass",
[](mlir::PassManager &self) {
@@ -900,11 +900,12 @@ def process_epilogue(d, bias, w, epilogue):
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
if NUM_CTAS > 1:
device = get_current_device()
null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
src = triton.compiler.ASTSource(fn=empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
null_kernel = triton.compile(src)
null_kernel._init_handles()
max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"]
num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS,
1, 1)
num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.function, max_shared_mem, NUM_CTAS, 1,
1)
num_SMs = num_clusters

def grid(META):
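Pieced together, the new occupancy flow exercised in this hunk looks like the following standalone sketch. The kernel signature is a guess at the test's empty kernel, and the `from triton.runtime import driver` import path is an assumption; the rest uses names taken from the hunk itself:

import triton
import triton.language as tl
from triton.runtime import driver  # assumed import path for `driver`


@triton.jit
def empty_kernel(X, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pass  # illustrative stand-in for the test's empty kernel


# Compilation now goes through an ASTSource object rather than kwargs.
src = triton.compiler.ASTSource(fn=empty_kernel, signature="i32",
                                constants={"BLOCK_M": 64, "BLOCK_N": 64})
null_kernel = triton.compile(src)
null_kernel._init_handles()  # materialize the device function handle

device = 0  # stand-in for get_current_device()
max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"]
# The kernel handle attribute is now `function` instead of `cu_function`.
num_clusters = driver.utils.cu_occupancy_max_active_clusters(
    null_kernel.function, max_shared_mem, 2, 1, 1)  # NUM_CTAS = 2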
38 changes: 13 additions & 25 deletions python/test/unit/runtime/test_subproc.py
@@ -1,12 +1,12 @@
import multiprocessing
import os
import shutil
from collections import namedtuple

import torch

import triton
import triton.language as tl
from triton.compiler import ASTSource

tmpdir = ".tmp"

@@ -17,32 +17,26 @@ def reset_tmp_dir():
shutil.rmtree(tmpdir, ignore_errors=True)


instance_descriptor = namedtuple("instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])


def compile_fn(config, cc):
def compile_fn(attrs, capability):

@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)

triton.compile(
src = ASTSource(
fn=kernel_sub,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
device=0,
constants={3: 32},
configs=[config],
warm_cache_only=True,
cc=cc,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
attrs=attrs,
)
triton.compile(src=src, target=("cuda", capability))


def test_compile_in_subproc() -> None:
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(4)), (), (), ())
config = triton.compiler.AttrsDescriptor(tuple(range(4)), (), (), ())

multiprocessing.set_start_method('fork')
proc = multiprocessing.Process(target=compile_fn, args=(config, cc))
@@ -51,7 +45,7 @@ def test_compile_in_subproc() -> None:
assert proc.exitcode == 0


def compile_fn_dot(config, cc):
def compile_fn_dot(attrs, capability):

@triton.jit
def kernel_dot(Z):
@@ -60,24 +54,18 @@ def kernel_dot(Z):
z = tl.dot(z, z)
tl.store(Z + offs, z)

triton.compile(
fn=kernel_dot,
signature={0: "*fp32"},
device=0,
configs=[config],
warm_cache_only=True,
cc=cc,
)
src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict())
triton.compile(src=src, target=("cuda", capability))


def test_compile_in_forked_subproc() -> None:
reset_tmp_dir()
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(1)), (), (), ())
capability = major * 10 + minor
config = triton.compiler.AttrsDescriptor(tuple(range(1)), (), (), ())

assert multiprocessing.get_start_method() == 'fork'
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, cc))
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability))
proc.start()
proc.join()
assert proc.exitcode == 0
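Taken together, the migration in this test file reduces to one pattern: build an ASTSource, describe argument attributes with an AttrsDescriptor, and hand both to triton.compile with a target tuple. A minimal sketch using only names that appear in this diff; the kernel body and hint values are illustrative:

import triton
import triton.language as tl
from triton.compiler import ASTSource, AttrsDescriptor


@triton.jit
def kernel_add(a, o, N: tl.constexpr):
    idx = tl.arange(0, N)
    tl.store(o + idx, tl.load(a + idx) + 1)


# AttrsDescriptor replaces the ad-hoc instance_descriptor namedtuple;
# its positional fields mirror the old layout:
# (divisible_by_16, equal_to_1, ids_of_folded_args, divisible_by_8).
attrs = AttrsDescriptor(tuple(range(2)), (), (), ())

src = ASTSource(
    fn=kernel_add,
    signature={0: "*fp32", 1: "*fp32"},  # arg index -> type string
    constants={2: 32},                   # N is bound at compile time
    attrs=attrs,
)
kernel = triton.compile(src, target=("cuda", 80))  # (backend, capability)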
2 changes: 1 addition & 1 deletion python/test/unit/tools/test_aot.py
@@ -446,7 +446,7 @@ def test_ttgir_to_ptx():
kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir")
with open(kernel_path, "w") as fp:
fp.write(src)
k = triton.compile(kernel_path, cc=80)
k = triton.compile(kernel_path, target=("cuda", 80))
ptx = k.asm["ptx"]
assert ".target sm_80" in ptx
assert ".address_size 64" in ptx
1 change: 1 addition & 0 deletions python/triton/__init__.py
@@ -21,6 +21,7 @@

from . import language
from . import testing
from . import tools

__all__ = [
"autotune",
5 changes: 2 additions & 3 deletions python/triton/compiler/__init__.py
@@ -1,8 +1,7 @@
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps,
instance_descriptor)
from .compiler import (CompiledKernel, ASTSource, compile, AttrsDescriptor)
from .errors import CompilationError

__all__ = [
"compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
"compile", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps",
"get_arch_default_num_stages"
]
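CompilationError remains exported alongside the new names, so front-end failures can still be caught at the call site. A sketch under the assumption that a broken kernel body surfaces as CompilationError rather than a raw Python error:

import triton
import triton.language as tl
from triton.compiler import ASTSource, CompilationError


@triton.jit
def bad_kernel(x_ptr):
    y = undefined_name  # deliberately broken: name does not exist


src = ASTSource(fn=bad_kernel, signature={0: "*fp32"}, constants=dict())
try:
    triton.compile(src)  # default target: the current device
except CompilationError as err:
    print("compilation failed:", err)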
File renamed without changes.