Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Easy][FRONTEND] Add pre run hooks to JITFunction #3314

Merged
merged 1 commit into from
Mar 7, 2024

Conversation

oulgen
Copy link
Contributor

@oulgen oulgen commented Mar 7, 2024

This PR adds pre run hooks to JITFunction in order for user to do analysis/error detection or any other pre hook operation they would like to do prior to execution of the kernel.

For PyTorch, the specific use case we have is being able to provide better error messages when a FakeTensor (our tracing tensors) are passed into the a Triton Kernel.

In addition our use case, these pre run hooks should provide generic functionality to users.

@oulgen oulgen requested a review from ptillet as a code owner March 7, 2024 19:09
Copy link
Collaborator

@jlebar jlebar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sgtm

@Chillee
Copy link
Contributor

Chillee commented Mar 7, 2024

I think it would be nice if the interpreter was modified to use these hooks too :) (doesn't need to be done in this PR)

For example, something that we can't do today is

run kernel in interpreter
run kernel normally
compare results

@oulgen
Copy link
Contributor Author

oulgen commented Mar 7, 2024

@jlebar I don't have merge/write permissions, can you click on the auto-merge?

@Chillee Lets do that as follow up

@jlebar jlebar merged commit 5a7bf72 into triton-lang:main Mar 7, 2024
4 checks passed
@jlebar
Copy link
Collaborator

jlebar commented Mar 7, 2024

Merged for you!

htyu pushed a commit to htyu/triton that referenced this pull request Mar 20, 2024
This PR adds pre run hooks to JITFunction in order for user to do
analysis/error detection or any other pre hook operation they would like
to do prior to execution of the kernel.

For PyTorch, the specific use case we have is being able to provide
better error messages when a FakeTensor (our tracing tensors) are passed
into the a Triton Kernel.

In addition our use case, these pre run hooks should provide generic
functionality to users.
binarman pushed a commit to binarman/triton that referenced this pull request Apr 2, 2024
This PR adds pre run hooks to JITFunction in order for user to do
analysis/error detection or any other pre hook operation they would like
to do prior to execution of the kernel.

For PyTorch, the specific use case we have is being able to provide
better error messages when a FakeTensor (our tracing tensors) are passed
into the a Triton Kernel.

In addition our use case, these pre run hooks should provide generic
functionality to users.
zou3519 added a commit to zou3519/triton that referenced this pull request Jul 17, 2024
Follow-up to triton-lang#3314

The context is that the PyTorch team wants to be able to override the
behavior of triton kernels under various contexts. @chilee discussed this
ask with triton folks previously and they seemed amenable to it.

Some examples include:
- PyTorch has FakeTensors, tensor subclasses that don't have storage.
  These can be sent through models, but today they error out when passed
  to a triton kernel.
- PyTorch has symbolic tracing (torch.fx). torch.fx replaces Tensors
  with Proxy objects that then get passed to a function that may contain
  a triton kernel. Today that errors out.

Test Plan:
- added new tests
zou3519 added a commit to zou3519/triton that referenced this pull request Jul 17, 2024
Follow-up to triton-lang#3314

The context is that the PyTorch team wants to be able to override the
behavior of triton kernels under various contexts. @Chillee discussed this
ask with triton folks previously and they seemed amenable to it.

Some examples include:
- PyTorch has FakeTensors, tensor subclasses that don't have storage.
  These can be sent through models, but today they error out when passed
  to a triton kernel.
- PyTorch has symbolic tracing (torch.fx). torch.fx replaces Tensors
  with Proxy objects that then get passed to a function that may contain
  a triton kernel. Today that errors out.

Test Plan:
- added new tests
@zou3519 zou3519 mentioned this pull request Jul 17, 2024
7 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants