-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[frontend] Improve function def parsing with stacked decorators. #3564
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well this is embarrassing. Thank you for the patch.
python/triton/runtime/jit.py
Outdated
@@ -478,7 +479,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinlin | |||
|
|||
# function source code (without decorators) | |||
self.src = textwrap.dedent(inspect.getsource(fn)) | |||
self.src = self.src[self.src.find("def"):] | |||
self.src = self.src[re.search(r"def (\w+)\s*\((.*?)\):", self.src).start():] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we remove the parens around (.*?)
?
Also if we're not in multiline mode I think this does not work, because .
only matches on the current line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah you're right, sorry I forgot the multiline case.
Addressed comments to minimize the test. Tweak the regex a little bit so that we always match against the beginning of a line, which is simpler to reason about. |
Under certain conditions where we apply another decorator on top of triton.jit, there is a small chance that in the decorator body there's a "def" substring in it. This is problematic when JITFunction parses function def using str.find('def'). This causes an issue compiling some pytorch model when we happened to decorate a jit function with a hash string argument which has "def" in it. This diff tries a bit harder to locate the real "def" position using a more complex regex which should further reduce the chance to hit this error.
Head branch was pushed to by a user without write access
Re-pushed this to fix yapf lint error. |
By patching runtime/jit.py with the contents of triton-lang/triton#3564 Should fix intermittent ``` SyntaxError: unterminated string literal (detected at line 1) ``` for good TODO: Delete `scripts/patch_trition.py` as well as changes to `install_requirements.sh` for good
By patching runtime/jit.py with the contents of triton-lang/triton#3564 Should fix intermittent ``` SyntaxError: unterminated string literal (detected at line 1) ``` for good TODO: Delete `scripts/patch_trition.py` as well as changes to `install_requirements.sh` for good
By patching runtime/jit.py with the contents of triton-lang/triton#3564 Should fix intermittent ``` SyntaxError: unterminated string literal (detected at line 1) ``` for good TODO: Delete `scripts/patch_trition.py` as well as changes to `install_requirements.sh` for good
By patching runtime/jit.py with the contents of triton-lang/triton#3564 Should fix intermittent ``` SyntaxError: unterminated string literal (detected at line 1) ``` for good TODO: Delete `scripts/patch_trition.py` as well as changes to `install_requirements.sh` for good
By patching runtime/jit.py with the contents of triton-lang/triton#3564 Should fix intermittent ``` SyntaxError: unterminated string literal (detected at line 1) ``` for good TODO: Delete `scripts/patch_trition.py` as well as changes to `install_requirements.sh` for good
By patching runtime/jit.py with the contents of triton-lang/triton#3564 Should fix intermittent ``` SyntaxError: unterminated string literal (detected at line 1) ``` for good TODO: Delete `scripts/patch_trition.py` as well as changes to `install_requirements.sh` for good
By patching runtime/jit.py with the contents of triton-lang/triton#3564 Should fix intermittent ``` SyntaxError: unterminated string literal (detected at line 1) ``` for good TODO: Delete `scripts/patch_trition.py` as well as changes to `install_requirements.sh` for good
Under certain conditions where we apply another decorator on top of triton.jit, there is a small chance that in the decorator body there's a "def" substring in it. This is problematic when JITFunction parses function def using str.find('def'). This causes an issue compiling some pytorch model when we happened to decorate a jit function with a hash string argument which has "def" in it.
This diff tries a bit harder to locate the real "def" position using a more complex regex which should further reduce the chance to hit this error.