Use new Triton runtime #1338
Conversation
        kernel = TritonCodeCache.load(source_code)

        def task():
            kernel.precompile()
            return kernel

        return self.submit(task)

    def cpp(self, source_code):
        def task():
            return CppCodeCache.load(source_code).kernel
The cache load happens at subtly different times between these two. Triton's loads in the call to triton(), while the cpp one loads on the thread pool as it gets dispatched as a task. The incongruent behavior may lead to surprises later, or force CppCodeCache.load to be thread safe.
The C++ cache load calls gcc, which is expensive (and also inherently thread safe).
Interesting issue with potential relevance: #1347
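To make the timing difference concrete, here is a rough sketch of the two paths. The ThreadPoolExecutor stands in for the PR's async-compile helper, and the torchinductor.codecache import reflects where these caches lived at the time; treat both as assumptions.

```python
# Sketch only: `pool` is a stand-in for the async-compile executor in this PR.
from concurrent.futures import ThreadPoolExecutor

from torchinductor.codecache import CppCodeCache, TritonCodeCache  # assumed module path

pool = ThreadPoolExecutor()


def triton_async(source_code):
    # The Triton cache load runs eagerly on the calling thread...
    kernel = TritonCodeCache.load(source_code)

    def task():
        # ...and only the precompile step is deferred to a worker thread.
        kernel.precompile()
        return kernel

    return pool.submit(task)


def cpp_async(source_code):
    def task():
        # The C++ cache load (which shells out to gcc) runs entirely on the
        # worker thread, so it has to tolerate concurrent callers.
        return CppCodeCache.load(source_code).kernel

    return pool.submit(task)
```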
    Simplified version of Triton autotuner that has no invalidation
    key and caches the best config to disk to improve cold start times.
    Unlike the main triton Autotuner, this version can precompile all
    configs, and does not rely on the Triton JIT.
Awesome
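As a rough illustration of the disk-caching idea in that docstring (this is not the PR's actual autotuner; the function name, cache path, and run_config callback are all hypothetical), the cold-start savings come from keying the best config on a hash of the kernel source:

```python
# Hypothetical sketch of "cache the best config to disk"; note there is
# deliberately no invalidation key, matching the docstring above.
import hashlib
import json
import os
import time


def autotune_with_disk_cache(kernel_source, configs, run_config, cache_dir="/tmp/kernel_autotune"):
    os.makedirs(cache_dir, exist_ok=True)
    key = hashlib.sha256(kernel_source.encode()).hexdigest()
    path = os.path.join(cache_dir, key + ".json")

    if os.path.exists(path):
        # Warm start: reuse the previously benchmarked best config.
        with open(path) as f:
            return json.load(f)

    # Cold start: time every candidate config (each cfg is a plain dict here).
    timings = {}
    for cfg in configs:
        start = time.perf_counter()
        run_config(cfg)
        timings[json.dumps(cfg, sort_keys=True)] = time.perf_counter() - start

    best = json.loads(min(timings, key=timings.get))
    with open(path, "w") as f:
        json.dump(best, f)
    return best
```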
test/test_torchinductor_opinfo.py (outdated)

@@ -91,7 +91,7 @@ def process(device_type):
     "mvlgamma.mvlgamma_p_5": {f32, f64, i32, i64}, # flaky
     "cumprod": {f32, f64}, # flaky
     "_masked.prod": {f32, f64}, # flaky
-    "empty_like": {b8, f16, f32, f64}, # flaky
+    "empty_like": {b8, f16, f32, f64, i32, i64}, # flaky
lol we shouldn't be checking results of empty_like, they are undefined, but that's for another PR
This test is flakey on master.
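To spell that out: only the metadata of an empty_like result is defined, so a stable check would look something like the hypothetical helper below rather than a value comparison.

```python
# Hypothetical helper: empty_like returns uninitialized memory, so a
# deterministic test can only compare metadata, never the contents.
import torch


def assert_empty_like_matches(ref: torch.Tensor, res: torch.Tensor) -> None:
    assert res.shape == ref.shape
    assert res.dtype == ref.dtype
    assert res.device == ref.device
    assert res.stride() == ref.stride()
    # Deliberately no torch.testing.assert_close on the values.
```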
    divisible_by_16 = [
        i
        for i, arg in enumerate(args)
        if isinstance(arg, TensorArg)
I'm a bit concerned by saying all tensor args are divisible by 16, e.g. if there are some slicing ops inside the compiled graph and we end up calling codegen_reference with an odd offset:

torchdynamo/torchinductor/ir.py, line 1336 (at dedca39):

    return f"as_strided({self.get_name()}, {size}, {stride}, {offset})"
codegen_reference() is only used for calling extern kernels (and return values), I believe. For Triton kernels we apply the offset inside the indexing formula.
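A minimal illustration of that last point (this kernel is made up for the example, not generated by TorchInductor): the kernel receives the buffer's base pointer, which is what the divisible-by-16 hint describes, while any view offset is added inside the index expression instead of being baked into the pointer the way as_strided would.

```python
# Illustrative Triton kernel: the slice offset lives in the indexing math,
# so the pointer argument itself can stay 16-byte aligned.
import triton
import triton.language as tl


@triton.jit
def copy_with_offset(in_ptr, out_ptr, offset, numel, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    idx = pid * BLOCK + tl.arange(0, BLOCK)
    mask = idx < numel
    # The view's storage offset shows up here, not in in_ptr itself.
    val = tl.load(in_ptr + offset + idx, mask=mask)
    tl.store(out_ptr + idx, val, mask=mask)
```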
    return align_inputs(compiled_fn, example_inputs, range(num_fixed))


def clone_preserve_strides(x):
if inputs are slices of the same tensor, this blows up the memory, and also removes aliasing but that's pre-existing
You think we should switch to recompiling if alignments change? I did this based on your suggestion when we talked in person.
Idk, maybe we should handle arg aliasing some other way? It's problematic now anyways (and we are already doing this for cuda graphs) so it's not a regression and definitely not blocking for this PR.
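For readers following along, here is a sketch (written from memory as an illustration, not copied from the PR) of what a clone-preserving-strides helper typically does: it clones the view's whole addressable extent of the shared storage, which is why slices of one large tensor each pay for their full extent and stop aliasing each other.

```python
# Illustrative clone-preserve-strides helper; the memory blow-up discussed
# above comes from cloning the full addressable extent of each input view.
import torch


def clone_preserve_strides_sketch(x: torch.Tensor) -> torch.Tensor:
    # Largest flat index the view can reach in its storage, plus one.
    needed_size = sum((size - 1) * stride for size, stride in zip(x.size(), x.stride())) + 1
    # Clone that flat extent (starting at x's storage offset)...
    buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
    # ...then reinterpret the copy with the original size/stride.
    return torch.as_strided(buffer, x.size(), x.stride())
```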
LGTM from first principles.
This PR now has a single test failure: […]

At first, I thought this was related to the alignment code in the new Triton runtime. However, even if I disable that with:

diff --git a/torchinductor/codegen/triton.py b/torchinductor/codegen/triton.py
index 51123ce..439dfc4 100644
--- a/torchinductor/codegen/triton.py
+++ b/torchinductor/codegen/triton.py
@@ -50,10 +50,6 @@ def config_of(args):
     from triton.runtime.jit import JITFunction
 
     divisible_by_16 = [
-        i
-        for i, arg in enumerate(args)
-        if isinstance(arg, TensorArg)
-        or V.graph.sizevars.maybe_guard_multiple_of(arg.expr, JITFunction.divisibility)
     ]
     return instance_descriptor(tuple(divisible_by_16), ())
 
diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py
index 56c7d5c..efb6463 100644
--- a/torchinductor/compile_fx.py
+++ b/torchinductor/compile_fx.py
@@ -106,7 +106,7 @@ def compile_fx_inner(
     elif complex_memory_overlap_inputs:
         log.warning("skipping cudagraphs due to complex input striding")
 
-    return align_inputs(compiled_fn, example_inputs, range(num_fixed))
+    return compiled_fn
 
 
 def clone_preserve_strides(x):
diff --git a/torchinductor/config.py b/torchinductor/config.py
index 7d9c38e..fd336b3 100644
--- a/torchinductor/config.py
+++ b/torchinductor/config.py
@@ -73,7 +73,7 @@ class cpp:
 
 class triton:
     # Use cudagraphs on output code
-    cudagraphs = True
+    cudagraphs = False
 
     # choose conv backend, "aten" or "triton" or "autotune"
     convolution = "aten"

it still gets an illegal memory access. I tried setting […] and […] and […] and […].

It feels almost like there is some CUDA corruption causing random failures everywhere. I am stumped... @ngimel any ideas here? I think you mentioned seeing similar stuff elsewhere. I'm not sure if this is related to this PR. Perhaps we should just skip this test and turn it into an issue.
@voznesenskym was also mentioning today that (without this PR), running OpInfo tests on […]
This reverts commit cbd938e.
@ptillet recently rewrote the Triton runtime in triton-lang/triton#644. This should dramatically reduce CPU overheads when cudagraphs is disabled.

This updates TorchInductor to use that new runtime. We now call triton.compile(), and no longer use triton.jit(). There is also some early support for parallel compiles, but we still need to optimize that part.

Note this PR breaks support for Triton versions prior to 998fd5f9afe166247f441999c605dfe624ca9331.