Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND] Fix arg name conflict bug #3383

Merged
merged 8 commits into from
Mar 22, 2024

Conversation

Sarbojit2019
Copy link
Contributor

There is a bug introduced in latest triton code where AttributeError is getting thrown if kernel function args has same name as global variable. See the below example for more clarity

import torch
import triton
import triton.language as tl

@triton.jit
def kernel(
      out, adj, num_rows,
      WG_SIZE : tl.constexpr
  ):
    zero_indices = tl.zeros((WG_SIZE,), dtype=adj.dtype.element_ty)


if __name__ == '__main__':
    device = torch.device('cuda')

    num_rows, num_indices, heads = 16, 16, 16, 4
    adj = torch.zeros(num_rows + 1, dtype=torch.int32, device=device)
    out = torch.empty((num_indices,heads), dtype=e_src.dtype, device=e_src.device)
    # invoke triton kernel
    launch_grid = (num_rows * heads // 256,)
    kernel[launch_grid](out, adj, num_rows, 256) 

Error it throws something like below
return getattr(lhs, node.attr) ^^^^^^^^^^^^^^^^^^^^^^^ AttributeError: 'torch.dtype' object has no attribute 'element_ty'

With the proposed change I see test is passing. I have verified it on Nvidia and Intel platforms.

@Sarbojit2019 Sarbojit2019 requested a review from ptillet as a code owner March 15, 2024 05:26
@ThomasRaoux ThomasRaoux requested a review from Jokeren March 15, 2024 05:28
@yiakwy-xpu-ml-framework-team

@Sarbojit2019 I found python AST generator has bug in estimating SRAM used and tensor assignment, did you noticed that before ?

@Jokeren
Copy link
Contributor

Jokeren commented Mar 15, 2024

Taking a look now (the example itself seems buggy).

python/triton/runtime/jit.py Outdated Show resolved Hide resolved
python/triton/runtime/jit.py Outdated Show resolved Hide resolved
@@ -36,6 +36,9 @@ def ret(self):
return self.hasher.hexdigest()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Sarbojit2019 Please correct your reproducer so that we can use it without any changes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Jokeren, could you please help me understand the issue you see with the reproducer?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot run your reproducer directly. There are at least two problems IIRC. For example, e_src is not defined.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Take a quick look, another one should be this:

    num_rows, num_indices, heads = 16, 16, 16, 4

@@ -77,6 +80,10 @@ def is_triton_builtin(func):
key = func_cache_key + noinline
self.hasher.update(key.encode("utf-8"))

def visit_FunctionDef(self, node):
# Save the local name which may hide the global name.
self.local_name = [arg.arg for arg in node.args.args]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is incomplete because we can also define a local variable with assignment, not as function arguments.

For example:

@triton.jit
def kernel(...):
    ptr = tl.load(...)
    print(ptr.dtype.element_ty)

ptr = torch.tensor(...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Sarbojit2019 Do you plan to address the issue in this PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, proposed the fix in latest commit.

@Jokeren
Copy link
Contributor

Jokeren commented Mar 15, 2024

Thanks for proposing the fix.

This is a known problem previously reported by PT2 developers. We were sticking to the "implicit rule" that names cannot be same in the local and the global scopes. But I think it's a good try to fix it now.

@Jokeren Jokeren changed the title Fix for function args name conflict bug [FRONTEND] Fix for arg name conflict bug Mar 15, 2024
@Jokeren Jokeren changed the title [FRONTEND] Fix for arg name conflict bug [FRONTEND] Fix arg name conflict bug Mar 15, 2024
@Sarbojit2019
Copy link
Contributor Author

Sarbojit2019 commented Mar 20, 2024

@Jokeren, I have updated the sample that I had pasted earlier.

import torch
import triton
import triton.language as tl

@triton.jit
def kernel(
      out, adj, num_rows,
      WG_SIZE : tl.constexpr
  ):
    zero_indices = tl.zeros((WG_SIZE,), dtype=adj.dtype.element_ty)


if __name__ == '__main__':
    device = torch.device('cuda')

    num_rows, num_cols, num_indices, heads = 16, 16, 16, 4
    e_src = torch.randn((num_cols,heads), dtype=torch.float32, device=device) + 0.5
    adj = torch.zeros(num_rows + 1, dtype=torch.int32, device=device)
    out = torch.empty((num_indices,heads), dtype=e_src.dtype, device=e_src.device)
    # invoke triton kernel
    launch_grid = (num_rows * heads // 256,)
    kernel[launch_grid](out, adj, num_rows, 256)

python/triton/runtime/jit.py Outdated Show resolved Hide resolved
@Jokeren
Copy link
Contributor

Jokeren commented Mar 22, 2024

Hi @Sarbojit2019 , I had a deeper thought.

The naming conflict won't be fully resolved until all possible definitions in AST visitor has been covered. For example, the iteration variable of a for loop could also shadow a global variable.

So I think we need to have a thorough test under test/runtime to cover all scenarios. It could be done in the next PR. Would you like to contribute to this feature?

@Jokeren Jokeren merged commit dca2d07 into triton-lang:main Mar 22, 2024
5 checks passed
@Sarbojit2019
Copy link
Contributor Author

@Jokeren, I would be happy to work on it.
As I understood triton supports smaller set of operations compare to python e.g. simultaneous assignments is not supported in triton. Could you please let me know where can I find the grammar for triton? This will help fix possible scenarios better.

@Jokeren
Copy link
Contributor

Jokeren commented Mar 23, 2024

It would be a simpler version of CodeGenerator, you only need to care about where new variables are defined.

karupayun pushed a commit to openxla/triton that referenced this pull request Apr 3, 2024
There is a bug introduced in latest triton code where AttributeError is
getting thrown if kernel function args has same name as global variable.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants