-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Unity] Implement relax.Function.bind_symbolic_vars #15509
Conversation
If a function has dynamic shape parameters, it can be useful to replace them with static parameters (e.g. when producing several models within the same family). This commit introduces a utility function `relax.Function.bind_symbolic_vars`, which allows symbolic variables to be replaced with static values. This is a related to the parameter binding done in `relax.transform.BindParam`, but does not require the bound parameter to be fully static data array.
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
Previously, `ExprBinder` only checked whether a `PrimExpr` was a symbolic variable to be replaced, but did not handle cases where a `PrimExpr` contained a symbolic variable to be replaced. As a result, when binding symbolic variables `{N: 16}`, a shape of `[N,2*N]` would be updated to `[16,2*N]` instead of `[16,32]`. This commit updates `ExprBinder` to use `tir::Substitute` to ensure all occurrences of the symbolic variable are replaced.
Thanks, @Lunderberg! We might have implemented the similar feature haha |
Ooh, thank you for pointing it out! That would explain why I didn't see the functionality anywhere in the Looking through it, I think the main differences are:
Thoughts on making a merged PR with the best of both implementations:
Any preference on which of the two PRs should be updated overall? Since both provide similar functionality, I don't have a huge preference. |
Thank you for the great summary! |
@sunggg How does the IRModule pass in the latest commit look? I pulled in the unit tests and interface from your PR, but made the following changes while integrating the two.
The error checking was why I ended up going with a module pass instead of just wrapping the utility with a function pass. Since the string names would be very convenient to hand-write, they're also the greatest risk of a user typo, which I wanted to guard against. Having a function pass would require either disabling the "all replacements are used", or restricting the updates to a single function. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, LGTM. Thank you for merging with my PR and making it better!
One comment that is not directly related to this PR.
|
||
|
||
def test_error_with_duplicate_var_names(): | ||
"""Duplicate variable names may not be replaced by string |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm surprised to see this case passing the well_formed check. I think this should be filtered before any pass is applied, because this seems incorrect to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On the C++ side, both TIR variables and Relax variables have reference identity, and duplicate names can exist within the same function. (Technically, the variables only have a name_hint
, not a name, which I believe is to avoid the assumption of unique names. (Though, double parenthetically, some of the python APIs expose name_hint
as name
.)) TVMScript has a tighter restriction, where variables tracked by the parser may not repeat within the same scope. In this test, I define the variables with the same TIR variable name outside of TVMScript, to produce the ambiguous test case anyways.
As a bit of fun, the function can be inspected, and even though TVMScript modifies the variable names to avoid printing duplicate variables, the underlying objects still have duplicate names.
M = tir.Var("M", "int64")
N1 = tir.Var("N", "int64")
N2 = tir.Var("N", "int64")
dtype = "float32"
@R.function
def func(
A: R.Tensor([M, N2], dtype), B: R.Tensor([N1, N2], dtype)
) -> R.Tensor([M, N1], dtype):
A: R.Tensor([M, N2], dtype) = A
B: R.Tensor([N1, N2], dtype) = B
out = R.linear(A, B)
return out
# Prints B: R.Tensor(("N", "N_1"), dtype="float32")
func.show()
# But the true name can still be extracted
assert func.params[1].struct_info.shape[0].name == "N"
assert func.params[1].struct_info.shape[1].name == "N"
# And the function is well-formed, despite the duplicate variable
# names.
assert tvm.relax.analysis.well_formed(tvm.IRModule.from_expr(func))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation, and yeah I see why it currently passes the well-formed checker. But I'm wondering if this is what we want or what it is supposed to be. I feel like this duplicated name would be very confusing if we are debugging things, especially when we print out IRs. Is there any necessity for this feature?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm. It is really useful if you're merging parts of TIR/Relax that were generated independently. Those are guaranteed to have different tir.Var
and relax.Var
objects, but can very easily have duplicate names. I think the only place we de-duplicate the names themselves is in an IRModule
, and even that gets kind of hairy with both String
and GlobalVar
acting as unrelated unique keys.
I think that re-use of the same name should be well-formed, but there should be a utility pass (potentially called several times during lowering) that de-duplicates the name within a function. That way, it provides the better debugging experience, but doesn't require renaming variables when constructing TIR. It also avoids promising that variable names will remain stable, or that they should be relied upon, since that isn't a stable assumption at the moment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, my apology. I missed the part that it already shows renamed varname when printing functions in your example. Just realized by trying by my own. My concern was printing the same name for different variables, which is turned out to be a false alarm. So, I agree with you, this seems okay. Thank you for clarifying!
@sunggg Also, I've added you as a co-author for this PR. |
// Special case for strided_slice | ||
// | ||
// The strided_slice operator currently stores the begins/ends in | ||
// the CallNode::attrs. Because the CallNode::attrs is only | ||
// intended to store static information, any PrimExpr members in | ||
// the attributes are not visited by `ExprMutator::VisitPrimExpr`. | ||
// Therefore, these must be explicitly visited. | ||
// | ||
// When the strided_slice operator is updated to store begins/ends | ||
// as a tuple of `relax::PrimValue` in the arguments, this special | ||
// case can be removed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if this is worth changing to avoid needing a special case here. Namely, we can use PrimExpr arguments for strided_slice (the reason it's implemented this way is probably because we didn't have PrimExpr at the time we added this op).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's true, and I'd like to update strided_slice
at some point as well. The comments in the strided_slice
implementation state that it should be updated when relax.PrimValue
is more widely supported, so that is the long-term goal. However, I think that's a larger change, as it would be a change to the Relax IR, and not just a transformation between Relax functions, and should probably be left for a later PR.
// time to collect usages. Otherwise, a symbolic variable | ||
// defined by a later parameter may be treated as undefined when | ||
// used by an earlier parameter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When does this happen? I think by definition, the symbolic variables are defined at their first appearance in the signature (going left to right)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The symbolic variables are defined in their first appearance in the signature, but only if that appearance is sufficient to provide a definition. If the symbolic variable can be matched to a specific expression in the caller (e.g. ), then it is bound. If the symbolic variable cannot be matched to a specific expression in the caller, then it is a usage-site instead. So, R.Tensor(shape=[16])
binding to R.Tensor(shape=[sym_var])
implies a binding of {sym_var: 16}
, but R.Tensor(shape=[16])
binding to R.Tensor(shape=[sym_var1 * sym_var2])
cannot provide any definitions.
So, this edge case occurs in cases like the one shown below, where the first occurrence is a usage, and the definition doesn't occur until a later parameter.
def func(A: R.Tensor(shape=[sym_var1*sym_var2]), B: R.Tensor(shape=[sym_var1, sym_var2])):
# ^^^^^^^^ ^^^^^^^^
# Usage of sym_var1 Definition of sym_var1
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aha, I hadn't thought of that case. Very good.
const auto& var = Downcast<tir::Var>(expr); | ||
if (defined_symbolic_var_.count(var) == 0) { | ||
defined_symbolic_var_.insert(var); | ||
if (mode_ >= VisitMode::kMatchVarDef) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we use this pattern elsewhere? I'm not personally very fond of it because it could create problems for adding more modes later, but I'm not against adding it if we do this already.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The pattern of GT/LT enums is used a couple of places, most commonly in the OpPatternKind
when checking if an operation is as permissible or more permissible than required for a transformation. That said, I agree that this isn't the most readable, and it probably would be better to have separate visitors to collect defined variables and to collect free variables.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making two separate visitors ended up with too much duplicate code, so I instead made VisitMode
explicitly be a bitflag, with kProvideDefinition
and kRequireDefinition
. I think this cleans up the implementation, allowing the behavior to be specified with more granularity, and avoids the use of >=
and <=
for enums.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These seem like solid changes and the implementation is straightforward. Thank you very much! Please have a look at my comments if you have the chance.
Similar to `relax.Function.bind_symbolic_vars`, implemented in apache#15509, this commit introduces `relax.Function.bind_params` to allow Relax parameters to be manipulated on a per-function basis. This utility function and the existing `BindParams` transform both use the same underlying implementation.
* [Unity] Implement relax.Function.bind_symbolic_vars If a function has dynamic shape parameters, it can be useful to replace them with static parameters (e.g. when producing several models within the same family). This commit introduces a utility function `relax.Function.bind_symbolic_vars`, which allows symbolic variables to be replaced with static values. This is a related to the parameter binding done in `relax.transform.BindParam`, but does not require the bound parameter to be fully static data array. * Updating ExprBinder to use tir::Substitute Previously, `ExprBinder` only checked whether a `PrimExpr` was a symbolic variable to be replaced, but did not handle cases where a `PrimExpr` contained a symbolic variable to be replaced. As a result, when binding symbolic variables `{N: 16}`, a shape of `[N,2*N]` would be updated to `[16,2*N]` instead of `[16,32]`. This commit updates `ExprBinder` to use `tir::Substitute` to ensure all occurrences of the symbolic variable are replaced. * Special case for updating symbolic vars in strided_slice attrs * Added IRModule pass to bind symbolic vars * Update unit test to include pytest * Co-authored-by: Sunghyun Park <sunggg@umich.edu> * Correct match mode in kProvideDefinitions context * Clean up implementation with VisitMode as a bitflag
* [Unity] Implement relax.Function.bind_params Similar to `relax.Function.bind_symbolic_vars`, implemented in #15509, this commit introduces `relax.Function.bind_params` to allow Relax parameters to be manipulated on a per-function basis. This utility function and the existing `BindParams` transform both use the same underlying implementation. * Update relay_translator unit tests to avoid duplicate binding * Updated unit test that attempted to bind non-existent parameter
If a function has dynamic shape parameters, it can be useful to replace them with static parameters (e.g. when producing several models within the same family). This commit introduces a utility function
relax.Function.bind_symbolic_vars
, which allows symbolic variables to be replaced with static values.This is a related to the parameter binding done in
relax.transform.BindParam
, but does not require the bound parameter to be fully static data array.