Use new Triton runtime #1338

Merged: 26 commits (Sep 28, 2022)

Commits:
888678f  Use new Triton runtime (jansel, Sep 24, 2022)
9dc26e4  Don't require Triton for CPU backend (jansel, Sep 25, 2022)
69e0421  Review changes (jansel, Sep 25, 2022)
a1d78ee  Merge branch 'main' of github.com:pytorch/torchdynamo into newruntime… (jansel, Sep 25, 2022)
9e34b19  comment (jansel, Sep 25, 2022)
307d98e  Skip tests failing on master (jansel, Sep 25, 2022)
7e38b79  unskip (jansel, Sep 25, 2022)
12cc6e8  Merge branch 'main' of github.com:pytorch/torchdynamo into newruntime… (jansel, Sep 25, 2022)
86dd713  Bump CI pin (jansel, Sep 26, 2022)
4d8c87c  Merge branch 'main' of github.com:pytorch/torchdynamo into newruntime… (jansel, Sep 26, 2022)
c2d9b55  dev20220925 (jansel, Sep 26, 2022)
cd7ef4f  rm torchaudio (jansel, Sep 26, 2022)
5717afb  dev20220926 (jansel, Sep 26, 2022)
b1d57b2  Revert to dev20220921 (jansel, Sep 26, 2022)
ea4868e  Increase timeouts (jansel, Sep 26, 2022)
1c81cc0  Apply fix for #1362 (jansel, Sep 27, 2022)
24e5fe9  Show stack traces (jansel, Sep 27, 2022)
8534d1c  lint (jansel, Sep 27, 2022)
58d9023  Skip swin_base_patch4_window7_224 (jansel, Sep 27, 2022)
cbd938e  bernoulli (jansel, Sep 27, 2022)
1a3886e  Merge branch 'main' of github.com:pytorch/torchdynamo into newruntime… (jansel, Sep 27, 2022)
c8252f5  Revert "bernoulli" (jansel, Sep 27, 2022)
6184f20  Merge branch 'main' of github.com:pytorch/torchdynamo into newruntime… (jansel, Sep 28, 2022)
9470117  revert (jansel, Sep 28, 2022)
7fdf663  Merge branch 'main' of github.com:pytorch/torchdynamo into newruntime… (jansel, Sep 28, 2022)
f704f1f  Remove skip (jansel, Sep 28, 2022)
10 changes: 10 additions & 0 deletions .circleci/config.yml
@@ -403,6 +403,7 @@ jobs:
- install_deps
- run:
name: TIMM TorchInductor inference run
no_output_timeout: 30m
command: |
source .circleci/setup_env.sh
make develop
@@ -425,6 +426,7 @@ jobs:
- install_deps
- run:
name: TIMM TorchInductor inference run
no_output_timeout: 30m
command: |
source .circleci/setup_env.sh
make develop
@@ -447,6 +449,7 @@ jobs:
- install_deps
- run:
name: TIMM TorchInductor inference run
no_output_timeout: 30m
command: |
source .circleci/setup_env.sh
make develop
@@ -469,6 +472,7 @@ jobs:
- install_deps
- run:
name: TIMM TorchInductor inference run
no_output_timeout: 30m
command: |
source .circleci/setup_env.sh
make develop
@@ -491,6 +495,7 @@ jobs:
- install_deps
- run:
name: TIMM TorchInductor training run
no_output_timeout: 30m
command: |
source .circleci/setup_env.sh
make develop
@@ -513,6 +518,7 @@ jobs:
- install_deps
- run:
name: TIMM TorchInductor training run
no_output_timeout: 30m
command: |
source .circleci/setup_env.sh
make develop
@@ -535,6 +541,7 @@ jobs:
- install_deps
- run:
name: TIMM TorchInductor training run
no_output_timeout: 30m
command: |
source .circleci/setup_env.sh
make develop
@@ -557,6 +564,7 @@ jobs:
- install_deps
- run:
name: TIMM TorchInductor training run
no_output_timeout: 30m
command: |
source .circleci/setup_env.sh
make develop
@@ -580,6 +588,7 @@ jobs:
- install_deps
- run:
name: TIMM TorchInductor training run
no_output_timeout: 30m
command: |
source .circleci/setup_env.sh
make develop
@@ -602,6 +611,7 @@ jobs:
- install_deps
- run:
name: TIMM TorchInductor training run
no_output_timeout: 30m
command: |
source .circleci/setup_env.sh
make develop
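
All ten hunks in this file make the same change: each TIMM TorchInductor inference/training step gains no_output_timeout: 30m, raising CircleCI's default no-output timeout (10 minutes) so runs are not killed during long stretches of compilation that print nothing.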
2 changes: 1 addition & 1 deletion Makefile
@@ -10,7 +10,7 @@ PIP ?= python -m pip

# versions used in CI
PYTORCH_VERSION ?= dev20220928
- TRITON_VERSION ?= 889d9e34a114b1fe2e8871d21e713794344d12d3
+ TRITON_VERSION ?= 998fd5f9afe166247f441999c605dfe624ca9331


default: develop
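
Note that both pins use make's ?= operator, which assigns only when the variable is not already set; a one-off build against a different commit can therefore override them from the command line, e.g. make develop TRITON_VERSION=<sha>.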
4 changes: 3 additions & 1 deletion test/test_torchinductor.py
@@ -4,6 +4,7 @@
import functools
import importlib
import random
import sys
import unittest
from unittest.mock import patch

@@ -37,7 +38,8 @@
assert get_decompositions([torch.ops.aten.trace])
# Requires functorch
from torchinductor.compile_fx import compile_fx_inner
- except (ImportError, ModuleNotFoundError, AssertionError):
+ except (ImportError, ModuleNotFoundError, AssertionError) as e:
+     sys.stderr.write(f"{type(e)}: {e}\n")
raise unittest.SkipTest("requires sympy/functorch")


58 changes: 55 additions & 3 deletions torchinductor/codecache.py
@@ -9,7 +9,11 @@
import sysconfig
import tempfile
import types
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from ctypes import cdll
from typing import Any
from typing import Dict

from torch.utils import cpp_extension

@@ -160,9 +164,10 @@ def load(cls, source_code):
code = compile(f.read(), path, "exec")
mod = types.ModuleType(f"{__name__}.{key}")
mod.__file__ = path
+ mod.key = key
exec(code, mod.__dict__, mod.__dict__)
- cls.cache[key] = mod
- cls.cache[key].key = key
+ # another thread might set this first
+ cls.cache.setdefault(key, mod)
return cls.cache[key]
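
The setdefault change closes a small race: two threads can both miss the cache and exec the same source, and with a plain assignment the second thread would replace a module the first had already handed out. A minimal standalone sketch of the pattern (illustrative, not TorchDynamo code; dict.setdefault on plain string keys is effectively atomic under CPython's GIL):

    import threading
    import types

    cache = {}

    def load(key, source):
        if key not in cache:
            mod = types.ModuleType(key)
            exec(source, mod.__dict__)
            # Publish exactly one module per key: if another thread won
            # the race, setdefault is a no-op and we return its module.
            cache.setdefault(key, mod)
        return cache[key]

    results = []
    threads = [
        threading.Thread(target=lambda: results.append(load("k", "x = 1")))
        for _ in range(4)
    ]
    [t.start() for t in threads]
    [t.join() for t in threads]
    assert all(m is results[0] for m in results)  # all callers share one module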


@@ -174,7 +179,54 @@ def patch_triton_dir():


class TritonCodeCache:
@staticmethod
def get_name(mod):
(name,) = [n for n in dir(mod) if n.startswith("kernel")]
return name

@classmethod
def load(cls, source_code):
patch_triton_dir()
- return PyCodeCache.load(source_code)
+ mod = PyCodeCache.load(source_code)
+ return getattr(mod, cls.get_name(mod))


class AsyncCompile:
@staticmethod
@functools.lru_cache(1)
def pool():
assert config.compile_threads > 1
return ThreadPoolExecutor(config.compile_threads)

@classmethod
def submit(cls, task):
if config.compile_threads <= 1:
return task()
return cls.pool().submit(task)

@classmethod
def map(cls, fn, seq):
if config.compile_threads <= 1 or len(seq) <= 1:
return list(map(fn, seq))
return [t.result() for t in [cls.pool().submit(fn, x) for x in seq]]

def triton(self, source_code):
kernel = TritonCodeCache.load(source_code)

def task():
kernel.precompile()
return kernel

return self.submit(task)

def cpp(self, source_code):
def task():
return CppCodeCache.load(source_code).kernel
return self.submit(task)

Review thread on lines +214 to +224:

Contributor: The cache load happens at subtly different times between these two. Triton's loads in the call to triton(), and the cpp one loads on the thread pool as it gets dispatched on a task. The incongruent behavior may lead to surprises later, or force CppCodeCache.load to be thread safe.

Author (jansel): The C++ cache load calls gcc, which is expensive (and also inherently thread safe).

Contributor: Interesting issue with potential relevance: #1347

def wait(self, scope: Dict[str, Any]):
if config.compile_threads > 1:
for key, result in list(scope.items()):
if isinstance(result, Future):
scope[key] = result.result()
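
Putting the pieces together, a generated wrapper module under the new scheme submits its kernels for compilation at import time and resolves them before the first call. A hedged sketch (the kernel name, source, and call shape are illustrative placeholders, not real Inductor output):

    from torchinductor.codecache import AsyncCompile

    async_compile = AsyncCompile()

    # With config.compile_threads > 1 this returns a Future immediately;
    # otherwise the task runs inline and the compiled kernel comes back.
    kernel0 = async_compile.cpp('''
    /* generated C++ kernel source would appear here */
    ''')

    # Swap any Future bindings in this module for their compiled results.
    async_compile.wait(globals())

The triton() path follows the same shape, except that the expensive step on the pool is kernel.precompile() rather than a gcc invocation, which is the asymmetry the review thread above is discussing.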
14 changes: 13 additions & 1 deletion torchinductor/codegen/common.py
@@ -6,6 +6,7 @@
import re
import textwrap
import typing
from collections import namedtuple
from io import StringIO
from itertools import chain

@@ -22,6 +23,9 @@

log = logging.getLogger(__name__)

TensorArg = namedtuple("TensorArg", ["name", "dtype"])
SizeArg = namedtuple("SizeArg", ["name", "expr"])


def index_prevent_reordering(index: typing.List[sympy.Expr], index_vars, sizes):
from ..ir import FlexibleLayout
@@ -358,20 +362,28 @@ def cpp_argdefs(self):
def python_argdefs(self):
arg_defs = []
call_args = []
precompile_args = []
for inplaced in unique(self.inplace_buffers.values()):
arg_defs.append(inplaced.inner_name)
call_args.append(inplaced.other_names[-1])
precompile_args.append(
TensorArg(
inplaced.inner_name, V.graph.get_dtype(inplaced.other_names[-1])
)
)
for outer, inner in chain(
self.input_buffers.items(), self.output_buffers.items()
):
if outer in self.inplace_buffers or inner == "REMOVED":
continue
arg_defs.append(inner)
call_args.append(outer)
precompile_args.append(TensorArg(inner, V.graph.get_dtype(outer)))
for outer, inner in self.sizevars.items():
arg_defs.append(inner)
call_args.append(outer)
- return arg_defs, call_args
+ precompile_args.append(SizeArg(inner, sympy.expand(outer)))
+ return arg_defs, call_args, precompile_args

def aliases(self):
for inplaced in unique(self.inplace_buffers.values()):
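
python_argdefs now returns a third list describing each argument for ahead-of-time compilation: a TensorArg pairs an argument name with its buffer's dtype, and a SizeArg pairs a name with its sympy size expression. A small standalone consumer, purely illustrative (the names and printing are hypothetical, not Inductor code):

    from collections import namedtuple

    import sympy

    TensorArg = namedtuple("TensorArg", ["name", "dtype"])
    SizeArg = namedtuple("SizeArg", ["name", "expr"])

    def summarize(precompile_args):
        # Tensor args need a dtype for the kernel signature; size args
        # carry symbolic shapes that are resolved at call time.
        for arg in precompile_args:
            if isinstance(arg, TensorArg):
                print(f"tensor arg {arg.name}: dtype={arg.dtype}")
            else:
                print(f"size arg {arg.name} = {arg.expr}")

    summarize([
        TensorArg("in_ptr0", "float32"),
        SizeArg("ks0", sympy.expand("s0 * s1")),
    ])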
4 changes: 2 additions & 2 deletions torchinductor/codegen/cpp.py
@@ -569,9 +569,9 @@ def codegen_define_and_call(self, wrapper):
code.splice(self.loops_code)

codecache_def = IndentedBuffer()
codecache_def.writeline("CppCodeCache.load('''")
codecache_def.writeline("async_compile.cpp('''")
codecache_def.splice(code)
codecache_def.writeline("''').kernel")
codecache_def.writeline("''')")

kernel_name = wrapper.next_kernel_name()
codecache_str = codecache_def.getvalue()
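
With this change the wrapper no longer blocks on gcc while kernels are being defined: async_compile.cpp() submits the source for compilation and the .kernel attribute access moves into the compile task (see the codecache.py diff above), so the trailing ''').kernel suffix disappears from the generated code.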