[Bugfix][Transform] Preserve symbolic variables in FuseOps #16637

Merged: Lunderberg merged 1 commit into apache:main from Lunderberg:transform_preserve_symbolic_vars_in_fuse_ops on Feb 29, 2024
Conversation
Prior to this commit, the `CompositeFunctionAnnotator` visited the body of functions without the parameters being considered in-scope. As a result, `EraseToWellDefined` would remove known shapes from the function body's `StructInfo`.
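To illustrate the failure mode, here is a minimal hypothetical sketch (not taken from this PR's tests, assuming a TVM build with Relax and TVMScript): when a function's parameters are not treated as in-scope while its body is visited, a symbolic variable such as `n` below appears undefined, so `EraseToWellDefined` weakens the body's annotations to shapes with unknown extents.

```python
# Hypothetical example; `n` is defined by the shape of the parameter `x`.
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor(["n"], "float32")) -> R.Tensor(["n"], "float32"):
        # If `x` is not considered in-scope while this body is visited,
        # `n` looks undefined, and EraseToWellDefined would erase this
        # known shape to R.Tensor(ndim=1, dtype="float32").
        y: R.Tensor(["n"], "float32") = R.add(x, x)
        return y
```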
Lunderberg changed the title from "[Unity][Transform] Preserve symbolic variables in FuseOps" to "[Bugfix][Transform] Preserve symbolic variables in FuseOps" on Feb 23, 2024
This PR was originally part of #16450. Since some of the later changes in that PR required further discussion and potential refactoring, splitting them out allows the bugfix in this commit to land separately.
slyubomirsky approved these changes on Feb 23, 2024
This is a good fix to have.
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request on Jun 10, 2024:
This is a follow-up commit to apache#16637, which updated `relax.transform.FuseOps` to provide additional parameters defining symbolic variables required by the fused functions. While this ensures that `relax.transform.FuseOps` produces well-formed Relax functions, these additional arguments can break some kernel implementations.

This commit implements a new transform `RemoveSymbolicExpressionsInSubroutine` to resolve this issue. This transform identifies function arguments whose sole purpose is to compute a symbolic expression that could instead be inferred from tensor shapes. For example, consider the following Relax function:

```python
@R.function
def func(
    data: R.Tensor(["batch_size * seq_len", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
) -> R.Tensor(["batch_size * seq_len", "intermediate_size"]):
    batch_size = T.int64()
    seq_len = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()
    output: R.Tensor([batch_size * seq_len, intermediate_size]) = R.matmul(data, weights)
    return output
```

The `data` tensor may be used to infer `hidden_size`, but cannot be used to infer `batch_size` or `seq_len`. The `R.Shape` parameter exists solely to define `batch_size` and `seq_len`, since all symbolic variables must be defined. However, neither `batch_size` nor `seq_len` is ever used outside of the expression `batch_size * seq_len`, and the value of `batch_size * seq_len` could be inferred from the shape of the `data` tensor.

This new transform identifies cases where an argument is otherwise unnecessary, and replaces the symbolic expression with a new argument. This leaves the `dummy_arg: R.Shape` parameter entirely unused, so a later application of `relax.transform.RemoveUnusedParameters()` can remove the parameter altogether.

```python
@R.function
def func(
    data: R.Tensor(["data_dim0", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
):
    data_dim0 = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()
    output: R.Tensor([data_dim0, intermediate_size]) = R.matmul(data, weights)
    return output
```
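As a usage sketch of the two passes described above (hedged: `RemoveSymbolicExpressionsInSubroutine` is the transform proposed by this commit and may not be available, or may be spelled differently, in a given TVM build):

```python
import tvm
from tvm import relax


def simplify_fused_signatures(mod: tvm.IRModule) -> tvm.IRModule:
    """Sketch only: rewrite symbolic expressions such as batch_size * seq_len
    into fresh variables inferable from tensor shapes, then drop parameters
    that are left entirely unused."""
    # Proposed in the referenced commit; the spelling is an assumption.
    mod = relax.transform.RemoveSymbolicExpressionsInSubroutine()(mod)
    # Existing transform named in the commit message above.
    mod = relax.transform.RemoveUnusedParameters()(mod)
    return mod
```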
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request on Sep 11, 2024 (same commit message as above).