-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FRONTEND] Fix arg name conflict bug #3383
Conversation
@Sarbojit2019 I found python AST generator has bug in estimating SRAM used and tensor assignment, did you noticed that before ? |
Taking a look now (the example itself seems buggy). |
@@ -36,6 +36,9 @@ def ret(self): | |||
return self.hasher.hexdigest() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Sarbojit2019 Please correct your reproducer so that we can use it without any changes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Jokeren, could you please help me understand the issue you see with the reproducer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We cannot run your reproducer directly. There are at least two problems IIRC. For example, e_src
is not defined.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Take a quick look, another one should be this:
num_rows, num_indices, heads = 16, 16, 16, 4
python/triton/runtime/jit.py
Outdated
@@ -77,6 +80,10 @@ def is_triton_builtin(func): | |||
key = func_cache_key + noinline | |||
self.hasher.update(key.encode("utf-8")) | |||
|
|||
def visit_FunctionDef(self, node): | |||
# Save the local name which may hide the global name. | |||
self.local_name = [arg.arg for arg in node.args.args] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is incomplete because we can also define a local variable with assignment, not as function arguments.
For example:
@triton.jit
def kernel(...):
ptr = tl.load(...)
print(ptr.dtype.element_ty)
ptr = torch.tensor(...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Sarbojit2019 Do you plan to address the issue in this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, proposed the fix in latest commit.
Thanks for proposing the fix. This is a known problem previously reported by PT2 developers. We were sticking to the "implicit rule" that names cannot be same in the local and the global scopes. But I think it's a good try to fix it now. |
@Jokeren, I have updated the sample that I had pasted earlier.
|
Hi @Sarbojit2019 , I had a deeper thought. The naming conflict won't be fully resolved until all possible definitions in AST visitor has been covered. For example, the iteration variable of a for loop could also shadow a global variable. So I think we need to have a thorough test under test/runtime to cover all scenarios. It could be done in the next PR. Would you like to contribute to this feature? |
@Jokeren, I would be happy to work on it. |
It would be a simpler version of |
There is a bug introduced in latest triton code where AttributeError is getting thrown if kernel function args has same name as global variable. Co-authored-by: Keren Zhou <kerenzhou@openai.com>
There is a bug introduced in latest triton code where AttributeError is getting thrown if kernel function args has same name as global variable. See the below example for more clarity
Error it throws something like below
return getattr(lhs, node.attr) ^^^^^^^^^^^^^^^^^^^^^^^ AttributeError: 'torch.dtype' object has no attribute 'element_ty'
With the proposed change I see test is passing. I have verified it on Nvidia and Intel platforms.