Use new Triton runtime #1338

Merged: 26 commits, Sep 28, 2022

Conversation

jansel (Contributor) commented Sep 24, 2022

@ptillet recently rewrote the Triton runtime in triton-lang/triton#644

This should dramatically reduce CPU overheads when cudagraphs is disabled.

This updates TorchInductor to use that new runtime. We now call triton.compile(), and no longer use triton.jit().

There is also some early support for parallel compiles, but that part still needs to be optimized.

Note this PR breaks support for Triton versions prior to 998fd5f9afe166247f441999c605dfe624ca9331.
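
For context, here is a minimal sketch of the @triton.jit launch path this PR moves away from. The kernel below is a made-up elementwise add, and the triton.compile() call that replaces this path is not shown since its arguments depend on the new runtime's API.

import torch
import triton
import triton.language as tl

# Hypothetical kernel, for contrast only. On the JIT path, compilation
# happens lazily at first launch; this PR instead calls triton.compile()
# ahead of time from TorchInductor.
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
# Indexing the jitted function with a grid triggers JIT compile + launch.
add_kernel[(triton.cdiv(1024, 256),)](x, y, out, 1024, BLOCK=256)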

Comment on lines +214 to +224
kernel = TritonCodeCache.load(source_code)

def task():
kernel.precompile()
return kernel

return self.submit(task)

def cpp(self, source_code):
def task():
return CppCodeCache.load(source_code).kernel
Reviewer (Contributor):

The cache load happens at subtly different times between these two: the Triton cache loads in the call to triton(), while the cpp one loads on the thread pool when the task is dispatched. The incongruent behavior may lead to surprises later, or force CppCodeCache.load to be thread safe.

jansel (Author):

The C++ cache load calls gcc, which is expensive (and also inherently thread safe).

Reviewer (Contributor):

Interesting issue with potential relevance: #1347
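
The submit()/task pattern in the excerpt above is essentially a thread-pool dispatch. A rough sketch of the idea using a plain ThreadPoolExecutor follows; the class name and worker count are invented for illustration and are not TorchInductor's actual implementation.

from concurrent.futures import Future, ThreadPoolExecutor

class AsyncCompileSketch:
    """Hypothetical stand-in for the async-compile helper shown above."""

    def __init__(self, workers: int = 4):
        self._pool = ThreadPoolExecutor(max_workers=workers)

    def submit(self, task) -> Future:
        # Run the compile task on a worker thread; callers block on
        # future.result() only when they actually need the kernel.
        return self._pool.submit(task)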

Comment on lines +37 to +40
Simplified version of Triton autotuner that has no invalidation
key and caches the best config to disk to improve cold start times.
Unlike the main triton Autotuner, this version can precompile all
configs, and does not rely on the Triton JIT.
Reviewer (Contributor):

Awesome
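
To illustrate the idea in that docstring, here is a heavily simplified, hypothetical sketch: benchmark every precompiled launcher once, persist the winning config index to disk, and reuse it on later runs. The class, file format, and timing loop are assumptions, not the actual torchinductor/triton_ops/autotune.py implementation.

import json
import os
import time

class CachedAutotunerSketch:
    def __init__(self, launchers, cache_path):
        self.launchers = launchers    # callables wrapping precompiled kernels
        self.cache_path = cache_path  # per-kernel file holding the best index

    def _bench(self, launcher, *args, reps=10):
        # Crude wall-clock timing; a real autotuner would synchronize CUDA.
        start = time.perf_counter()
        for _ in range(reps):
            launcher(*args)
        return (time.perf_counter() - start) / reps

    def pick(self, *args):
        if os.path.exists(self.cache_path):
            with open(self.cache_path) as f:
                return self.launchers[json.load(f)["best"]]
        timings = [self._bench(launcher, *args) for launcher in self.launchers]
        best = min(range(len(timings)), key=timings.__getitem__)
        with open(self.cache_path, "w") as f:
            json.dump({"best": best}, f)
        return self.launchers[best]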

@@ -91,7 +91,7 @@ def process(device_type):
"mvlgamma.mvlgamma_p_5": {f32, f64, i32, i64}, # flaky
"cumprod": {f32, f64}, # flaky
"_masked.prod": {f32, f64}, # flaky
"empty_like": {b8, f16, f32, f64}, # flaky
"empty_like": {b8, f16, f32, f64, i32, i64}, # flaky
Reviewer:

lol, we shouldn't be checking the results of empty_like (they are undefined), but that's for another PR

jansel (Author):

This test is flaky on master.
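
A quick illustration of why the comparison is inherently flaky: empty_like returns uninitialized memory, so its values are arbitrary.

import torch

x = torch.randn(4)
# y's contents are whatever happened to be in the allocation, so comparing
# them against an eager-mode reference is meaningless.
y = torch.empty_like(x)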

divisible_by_16 = [
    i
    for i, arg in enumerate(args)
    if isinstance(arg, TensorArg)
    or V.graph.sizevars.maybe_guard_multiple_of(arg.expr, JITFunction.divisibility)
]
Reviewer:

I'm a bit concerned about saying all tensor args are divisible by 16, e.g. if there is some slicing op inside the compiled graph and we end up calling codegen_reference with an odd offset:

return f"as_strided({self.get_name()}, {size}, {stride}, {offset})"

jansel (Author):

codegen_reference() is only used for calling extern kernels (and return values) I believe. For Triton kernels we apply the offset inside the indexing formula.
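
To make the concern concrete: divisible-by-16 is a compile-time specialization hint, so a misaligned input such as an odd-offset slice has to be detected (and cloned) at call time. A minimal sketch of such a runtime check, with a hypothetical helper name:

import torch

ALIGNMENT = 16  # bytes, matching the JITFunction.divisibility hint above

def is_16_byte_aligned(x: torch.Tensor) -> bool:
    # An odd-offset slice of an aligned tensor can yield a data pointer
    # that is not 16-byte divisible, violating the specialization.
    return x.data_ptr() % ALIGNMENT == 0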

return align_inputs(compiled_fn, example_inputs, range(num_fixed))


def clone_preserve_strides(x):
Reviewer:

If inputs are slices of the same tensor, this blows up the memory and also removes aliasing, but that's pre-existing.

jansel (Author):

You think we should switch to recompiling if alignments change? I did this based on your suggestion when we talked in person.

Reviewer:

Idk, maybe we should handle arg aliasing some other way? It's problematic now anyway (and we are already doing this for cuda graphs), so it's not a regression and definitely not blocking for this PR.
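
For reference, a stride-preserving clone along the lines discussed here can be sketched as below. It copies the full storage span reachable through the tensor's strides, which is exactly why slices of one large tensor each pay for the whole span (the memory blow-up noted above). This is an illustrative sketch, not necessarily the exact TorchInductor code.

import torch

def clone_preserve_strides_sketch(x: torch.Tensor) -> torch.Tensor:
    # Last reachable flat index + 1, i.e. the span the strided view can touch.
    needed_size = (
        sum((size - 1) * stride for size, stride in zip(x.size(), x.stride())) + 1
    )
    # Clone that whole flat span, then re-view it with the original layout.
    buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
    return torch.as_strided(buffer, x.size(), x.stride())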

torchinductor/triton_ops/autotune.py (resolved review thread)
voznesenskym (Contributor):

LGTM from first principles.

jansel (Author) commented Sep 27, 2022

This PR now has a single test failure:

CUDA_LAUNCH_BLOCKING=1 python benchmarks/timm_models.py --accuracy -d cuda --inductor --float32 --training --only swin_base_patch4_window7_224
cuda train swin_base_patch4_window7_224
...
  File "/tmp/torchinductor_jansel/nd/cndxz5mmbqajkiakvbzxdczaralpd7t57ajewnph4cqbiabmemjf.py", line 6011, in call
    kernel76.run(buf1055, _unsafe_view_12, getitem_19, reciprocal_8, buf1061, buf1063, 3328, 121, grid=grid(3328), stream=stream0)
  File "/home/jansel/torchdynamo/torchinductor/triton_ops/autotune.py", line 157, in run
    return launcher(
  File "<string>", line 4, in launcher
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

At first, I thought this was related to the alignment code in the new Triton runtime. However, if I disable that with:

diff --git a/torchinductor/codegen/triton.py b/torchinductor/codegen/triton.py
index 51123ce..439dfc4 100644
--- a/torchinductor/codegen/triton.py
+++ b/torchinductor/codegen/triton.py
@@ -50,10 +50,6 @@ def config_of(args):
     from triton.runtime.jit import JITFunction
 
     divisible_by_16 = [
-        i
-        for i, arg in enumerate(args)
-        if isinstance(arg, TensorArg)
-        or V.graph.sizevars.maybe_guard_multiple_of(arg.expr, JITFunction.divisibility)
     ]
     return instance_descriptor(tuple(divisible_by_16), ())
 
diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py
index 56c7d5c..efb6463 100644
--- a/torchinductor/compile_fx.py
+++ b/torchinductor/compile_fx.py
@@ -106,7 +106,7 @@ def compile_fx_inner(
             elif complex_memory_overlap_inputs:
                 log.warning("skipping cudagraphs due to complex input striding")
 
-    return align_inputs(compiled_fn, example_inputs, range(num_fixed))
+    return compiled_fn
 
 
 def clone_preserve_strides(x):
diff --git a/torchinductor/config.py b/torchinductor/config.py
index 7d9c38e..fd336b3 100644
--- a/torchinductor/config.py
+++ b/torchinductor/config.py
@@ -73,7 +73,7 @@ class cpp:
 class triton:
 
     # Use cudagraphs on output code
-    cudagraphs = True
+    cudagraphs = False
 
     # choose conv backend, "aten" or "triton" or "autotune"
     convolution = "aten"

It still gets an illegal memory access.

I tried setting TORCHDYNAMO_REPRO_AFTER to aot and to dynamo, but both of them fail with IMAs (illegal memory accesses) inside the repro code:

  File "/home/jansel/torchdynamo/torchdynamo/debug_utils.py", line 501, in dump_backend_repro_as_tarfile
    gm.to_folder(gm_dir, "Repro")
  File "/home/jansel/conda/envs/torchdynamo/lib/python3.8/site-packages/torch/fx/graph_module.py", line 423, in to_folder
    torch.save(self.state_dict(), folder / 'state_dict.pt')
  File "/home/jansel/conda/envs/torchdynamo/lib/python3.8/site-packages/torch/serialization.py", line 422, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/jansel/conda/envs/torchdynamo/lib/python3.8/site-packages/torch/serialization.py", line 646, in _save
    storage = storage.cpu()
  File "/home/jansel/conda/envs/torchdynamo/lib/python3.8/site-packages/torch/storage.py", line 120, in cpu
    return torch.UntypedStorage(self.size()).copy_(self, False)
RuntimeError: CUDA error: an illegal memory access was encountered

and

  File "/home/jansel/conda/envs/torchdynamo/lib/python3.8/site-packages/torch/_tensor_str.py", line 115, in __init__
    tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
RuntimeError: CUDA error: an illegal memory access was encountered

and

  File "/home/jansel/conda/envs/torchdynamo/lib/python3.8/site-packages/torch/cuda/random.py", line 62, in cb
    default_generator.set_state(new_state_copy)
RuntimeError: CUDA error: an illegal memory access was encountered

It feels almost like there is some CUDA corruption causing random failures everywhere. I am stumped...

@ngimel any ideas here? I think you mentioned seeing similar stuff elsewhere.

I'm not sure if this is related to this PR. Perhaps we should just skip this test and turn it into an issue.

soumith (Member) commented Sep 27, 2022

@voznesenskym was also mentioning today that (without this PR) running the OpInfo tests on main but reordering them causes failures in strange ways, including IMAs. Voz said he'll start looking into that under compute-sanitizer.
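
For anyone reproducing this, a typical way to run the failing benchmark under compute-sanitizer (memcheck is the default tool) would be something like the following; the flags shown are illustrative:

CUDA_LAUNCH_BLOCKING=1 compute-sanitizer --tool memcheck \
    python benchmarks/timm_models.py --accuracy -d cuda --inductor \
    --float32 --training --only swin_base_patch4_window7_224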

jansel (Author) commented Sep 27, 2022

We see this same error in #1362 so I am going to skip that test and spin this off into a new issue: #1365

desertfire mentioned this pull request on Sep 27, 2022