[Unity] Implement relax.Function.bind_symbolic_vars #15509

Merged
Lunderberg merged 8 commits into apache:unity from unity_bind_symbolic_vars
Aug 16, 2023


Lunderberg
Contributor

If a function has dynamic shape parameters, it can be useful to replace them with static parameters (e.g. when producing several models within the same family). This commit introduces a utility function relax.Function.bind_symbolic_vars, which allows symbolic variables to be replaced with static values.

This is related to the parameter binding done in relax.transform.BindParams, but does not require the bound parameter to be a fully static data array.

@tvm-bot
Collaborator

tvm-bot commented Aug 8, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

Lunderberg and others added 2 commits August 9, 2023 13:35
Previously, `ExprBinder` only checked whether a `PrimExpr` was a
symbolic variable to be replaced, but did not handle cases where a
`PrimExpr` contained a symbolic variable to be replaced.  As a result,
when binding symbolic variables `{N: 16}`, a shape of `[N,2*N]` would be
updated to `[16,2*N]` instead of `[16,32]`.  This commit updates
`ExprBinder` to use `tir::Substitute` to ensure all occurrences of the
symbolic variable are replaced.
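
The difference between the two behaviours can be illustrated with a small, TVM-free sketch. The `Var` and `Mul` classes and both helper functions below are hypothetical stand-ins, not the actual `ExprBinder` or `tir::Substitute` code, and the toy `Var` compares by name for brevity, whereas real TIR variables compare by reference.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Var:
    """Toy symbolic variable (hypothetical stand-in for tir.Var)."""
    name: str


@dataclass(frozen=True)
class Mul:
    """Toy product node (hypothetical stand-in for a TIR multiply)."""
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[int, Var, Mul]


def bind_top_level_only(expr, binding):
    """Replace `expr` only when the whole expression is a bound variable."""
    if isinstance(expr, Var) and expr in binding:
        return binding[expr]
    return expr


def substitute(expr, binding):
    """Replace every occurrence of a bound variable, folding int * int."""
    if isinstance(expr, Var):
        return binding.get(expr, expr)
    if isinstance(expr, Mul):
        lhs = substitute(expr.lhs, binding)
        rhs = substitute(expr.rhs, binding)
        if isinstance(lhs, int) and isinstance(rhs, int):
            return lhs * rhs
        return Mul(lhs, rhs)
    return expr


N = Var("N")
shape = [N, Mul(2, N)]
binding = {N: 16}

# Checking only the top-level expression leaves 2*N untouched.
old_shape = [bind_top_level_only(dim, binding) for dim in shape]
# Recursive substitution reaches the N nested inside 2*N.
new_shape = [substitute(dim, binding) for dim in shape]
```

Here `old_shape` still contains the unresolved product, while `new_shape` is fully static.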
@sunggg
Contributor

sunggg commented Aug 9, 2023

Thanks, @Lunderberg! We might have implemented a similar feature, haha.
Here is my PR #15246 under review.

@Lunderberg
Contributor Author

Ooh, thank you for pointing it out! That would explain why I didn't see the functionality anywhere in the unity branch itself.

Looking through it, I think the main differences are:

Thoughts on making a merged PR with the best of both implementations:

  • Providing both an IRModule transform and a method that updates a single relax Function would offer the most flexibility.
  • Implementing the IRModule transform on top of the relax Function update would require constructing fewer intermediates than implementing the relax Function update on top of an IRModule transform.

Any preference on which of the two PRs should be updated overall? Since both provide similar functionality, I don't have a huge preference.

@sunggg
Contributor

sunggg commented Aug 9, 2023

Thank you for the great summary!
In fact, I like your PR in the sense that it offers broader support for TIR and it accepts both strings and tir.Var.
Since I think a pass would be convenient when we are dealing with an IRModule with multiple functions inside, how about creating a function pass that wraps your utility function? I agree that this would be an easier path to merge our PRs.

@Lunderberg
Contributor Author

Lunderberg commented Aug 9, 2023

@sunggg How does the IRModule pass in the latest commit look?

I pulled in the unit tests and interface from your PR, but made the following changes while integrating the two.

  • Swapped the order of binding_map and func_name, to allow a default value of func_name = None. If no function name is specified, all functions within the module will be updated.
  • Pulled in the error checking from relax.Function.bind_symbolic_vars for the IRModule pass. If a single function is updated, each entry in the binding_map must be used in that function. If multiple functions are updated, each entry in the binding_map must occur as a replacement in at least one function.
  • Allows replacement either by TIR variable or by name. If multiple functions are updated, a replacement by name may result in updates to a different TIR variable in each function, so long as it uniquely maps to a TIR variable within a given function. (e.g. Replacing "batch_size" in all functions in a module.)

The error checking was why I ended up going with a module pass instead of just wrapping the utility with a function pass. Since the string names would be very convenient to hand-write, they're also the greatest risk of a user typo, which I wanted to guard against. Having a function pass would require either disabling the "all replacements are used" check, or restricting the updates to a single function.
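
As a rough illustration of the module-level checks described above, here is a TVM-free sketch. The `Var` class, the `resolve_by_name` helper, and the representation of a module as a dict of variable lists are all assumptions made for this example and are not the code in this PR.

```python
class Var:
    """Toy stand-in for tir.Var: compared by identity, carries a name hint."""

    def __init__(self, name_hint):
        self.name_hint = name_hint


def resolve_by_name(functions, binding_map):
    """Resolve string keys to per-function variables.

    `functions` maps a function name to the symbolic variables it uses.
    A name may resolve to a different variable in each function, but it
    must be unique within a function and used by at least one function.
    """
    resolved = {}
    used = set()
    for func_name, sym_vars in functions.items():
        per_func = {}
        for name, value in binding_map.items():
            matches = [v for v in sym_vars if v.name_hint == name]
            if len(matches) > 1:
                raise ValueError(f"{name!r} is ambiguous in {func_name!r}")
            if matches:
                per_func[matches[0]] = value
                used.add(name)
        resolved[func_name] = per_func

    # Guard against typos: every requested replacement must apply somewhere.
    unused = set(binding_map) - used
    if unused:
        raise ValueError(f"Unused bindings (possible typo): {sorted(unused)}")
    return resolved


batch_a, batch_b, seq = Var("batch_size"), Var("batch_size"), Var("seq_len")
functions = {"encode": [batch_a, seq], "decode": [batch_b]}

# "batch_size" maps to a different variable object in each function.
resolved = resolve_by_name(functions, {"batch_size": 4})
```

A per-function pass would only see one entry of `functions` at a time, so it could not tell a typo apart from a name that is simply absent from that particular function. That is the motivation for running the "unused" check once over the whole module.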

Contributor

@sunggg sunggg left a comment

Overall, LGTM. Thank you for merging with my PR and making it better!
One comment that is not directly related to this PR.



def test_error_with_duplicate_var_names():
    """Duplicate variable names may not be replaced by string
Contributor

I'm surprised to see this case passing the well_formed check. I think this should be filtered before any pass is applied, because this seems incorrect to me.

Contributor Author

On the C++ side, both TIR variables and Relax variables have reference identity, and duplicate names can exist within the same function. (Technically, the variables only have a name_hint, not a name, which I believe is to avoid the assumption of unique names. (Though, double parenthetically, some of the python APIs expose name_hint as name.)) TVMScript has a tighter restriction, where variables tracked by the parser may not repeat within the same scope. In this test, I define the variables with the same TIR variable name outside of TVMScript, to produce the ambiguous test case anyways.

As a bit of fun, the function can be inspected, and even though TVMScript modifies the variable names to avoid printing duplicate variables, the underlying objects still have duplicate names.

import tvm
from tvm import tir
from tvm.script import relax as R

M = tir.Var("M", "int64")
N1 = tir.Var("N", "int64")
N2 = tir.Var("N", "int64")

dtype = "float32"


@R.function
def func(
    A: R.Tensor([M, N2], dtype), B: R.Tensor([N1, N2], dtype)
) -> R.Tensor([M, N1], dtype):
    A: R.Tensor([M, N2], dtype) = A
    B: R.Tensor([N1, N2], dtype) = B
    out = R.linear(A, B)
    return out


# Prints  B: R.Tensor(("N", "N_1"), dtype="float32")
func.show()
# But the true name can still be extracted
assert func.params[1].struct_info.shape[0].name == "N"
assert func.params[1].struct_info.shape[1].name == "N"
# And the function is well-formed, despite the duplicate variable
# names.
assert tvm.relax.analysis.well_formed(tvm.IRModule.from_expr(func))

Contributor

Thanks for the explanation, and yeah I see why it currently passes the well-formed checker. But I'm wondering if this is what we want or what it is supposed to be. I feel like this duplicated name would be very confusing if we are debugging things, especially when we print out IRs. Is there any necessity for this feature?

Contributor Author

Hmm. It is really useful if you're merging parts of TIR/Relax that were generated independently. Those are guaranteed to have different tir.Var and relax.Var objects, but can very easily have duplicate names. I think the only place we de-duplicate the names themselves is in an IRModule, and even that gets kind of hairy with both String and GlobalVar acting as unrelated unique keys.

I think that re-use of the same name should be well-formed, but there should be a utility pass (potentially called several times during lowering) that de-duplicates the name within a function. That way, it provides the better debugging experience, but doesn't require renaming variables when constructing TIR. It also avoids promising that variable names will remain stable, or that they should be relied upon, since that isn't a stable assumption at the moment.
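
To make the idea of such a de-duplication utility concrete, here is a minimal TVM-free sketch. The `Var` class and `unique_display_names` helper are hypothetical, and the `_1`, `_2` suffix scheme is an assumption for illustration rather than the scheme any TVM printer is guaranteed to use. The sketch returns a mapping instead of renaming in place, so the underlying objects keep their original name hints.

```python
class Var:
    """Toy stand-in for a variable with reference identity and a name hint."""

    def __init__(self, name_hint):
        self.name_hint = name_hint


def unique_display_names(variables):
    """Assign each distinct variable object a name unique within the list."""
    taken = set()
    names = {}
    for var in variables:
        if var in names:
            continue  # Same object seen again: reuse its display name.
        candidate = var.name_hint
        suffix = 0
        while candidate in taken:
            suffix += 1
            candidate = f"{var.name_hint}_{suffix}"
        taken.add(candidate)
        names[var] = candidate
    return names


M, N1, N2 = Var("M"), Var("N"), Var("N")
names = unique_display_names([M, N1, N2, N1])
```

After this runs, `N1` and `N2` have distinct display names while their `name_hint` attributes are still both `"N"`.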

Contributor

Ah, my apologies. I missed that your example already shows the renamed variable names when printing the function. I just realized this by trying it myself. My concern was printing the same name for different variables, which turned out to be a false alarm. So, I agree with you, this seems okay. Thank you for clarifying!

@Lunderberg
Contributor Author

@sunggg Also, I've added you as a co-author for this PR.

Comment on lines +71 to +81
// Special case for strided_slice
//
// The strided_slice operator currently stores the begins/ends in
// the CallNode::attrs. Because the CallNode::attrs is only
// intended to store static information, any PrimExpr members in
// the attributes are not visited by `ExprMutator::VisitPrimExpr`.
// Therefore, these must be explicitly visited.
//
// When the strided_slice operator is updated to store begins/ends
// as a tuple of `relax::PrimValue` in the arguments, this special
// case can be removed.
Contributor

I wonder if this is worth changing to avoid needing a special case here. Namely, we can use PrimExpr arguments for strided_slice (the reason it's implemented this way is probably because we didn't have PrimExpr at the time we added this op).

Contributor Author

That's true, and I'd like to update strided_slice at some point as well. The comments in the strided_slice implementation state that it should be updated when relax.PrimValue is more widely supported, so that is the long-term goal. However, I think that's a larger change, as it would be a change to the Relax IR, and not just a transformation between Relax functions, and should probably be left for a later PR.

Comment on lines +1007 to +1009
// time to collect usages. Otherwise, a symbolic variable
// defined by a later parameter may be treated as undefined when
// used by an earlier parameter.
Contributor

When does this happen? I think by definition, the symbolic variables are defined at their first appearance in the signature (going left to right)

Contributor Author

The symbolic variables are defined in their first appearance in the signature, but only if that appearance is sufficient to provide a definition. If the symbolic variable can be matched to a specific expression in the caller (e.g. ), then it is bound. If the symbolic variable cannot be matched to a specific expression in the caller, then it is a usage-site instead. So, R.Tensor(shape=[16]) binding to R.Tensor(shape=[sym_var]) implies a binding of {sym_var: 16}, but R.Tensor(shape=[16]) binding to R.Tensor(shape=[sym_var1 * sym_var2]) cannot provide any definitions.

So, this edge case occurs in cases like the one shown below, where the first occurrence is a usage, and the definition doesn't occur until a later parameter.

def func(A: R.Tensor(shape=[sym_var1*sym_var2]), B: R.Tensor(shape=[sym_var1, sym_var2])):
    #                       ^^^^^^^^                                ^^^^^^^^
    #                       Usage of sym_var1                       Definition of sym_var1
    ...
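
The edge case can be reproduced with a small TVM-free model. The `Var` and `Mul` classes and both checker functions below are hypothetical; the model assumes that a bare variable in a shape is a definition site and that any variable inside a compound expression is a usage.

```python
class Var:
    """Toy symbolic variable, compared by identity."""

    def __init__(self, name_hint):
        self.name_hint = name_hint


class Mul:
    """Toy product of two shape expressions."""

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs


def free_vars(expr):
    """Collect the variables appearing anywhere in an expression."""
    if isinstance(expr, Var):
        return [expr]
    if isinstance(expr, Mul):
        return free_vars(expr.lhs) + free_vars(expr.rhs)
    return []


def undefined_single_pass(param_shapes):
    """Define and check in one left-to-right sweep over the parameters."""
    defined, undefined = set(), []
    for shape in param_shapes:
        for dim in shape:
            if isinstance(dim, Var):
                defined.add(dim)  # A bare variable is a definition site.
            else:
                undefined += [v for v in free_vars(dim) if v not in defined]
    return undefined


def undefined_two_pass(param_shapes):
    """Collect every definition first, then check the usages."""
    defined = {
        dim for shape in param_shapes for dim in shape if isinstance(dim, Var)
    }
    return [
        v
        for shape in param_shapes
        for dim in shape
        for v in free_vars(dim)
        if v not in defined
    ]


sym_var1, sym_var2 = Var("sym_var1"), Var("sym_var2")
# A has shape [sym_var1 * sym_var2]; B has shape [sym_var1, sym_var2].
param_shapes = [[Mul(sym_var1, sym_var2)], [sym_var1, sym_var2]]

single = undefined_single_pass(param_shapes)
double = undefined_two_pass(param_shapes)
```

The single sweep reports both variables as undefined because it reaches the product in `A` before the definitions in `B`, while the two-pass version reports nothing.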

Contributor

Aha, I hadn't thought of that case. Very good.

const auto& var = Downcast<tir::Var>(expr);
if (defined_symbolic_var_.count(var) == 0) {
  defined_symbolic_var_.insert(var);
  if (mode_ >= VisitMode::kMatchVarDef) {
Contributor

Do we use this pattern elsewhere? I'm not personally very fond of it because it could create problems for adding more modes later, but I'm not against adding it if we do this already.

Contributor Author

The pattern of GT/LT enums is used in a couple of places, most commonly in the OpPatternKind when checking if an operation is as permissible or more permissible than required for a transformation. That said, I agree that this isn't the most readable, and it probably would be better to have separate visitors to collect defined variables and to collect free variables.

Contributor Author

@Lunderberg Lunderberg Aug 15, 2023

Making two separate visitors ended up with too much duplicate code, so I instead made VisitMode explicitly be a bitflag, with kProvideDefinition and kRequireDefinition. I think this cleans up the implementation, allowing the behavior to be specified with more granularity, and avoids the use of >= and <= for enums.
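
For readers less familiar with the bitflag pattern, here is a Python analogue using `enum.Flag`. The two flag names mirror kProvideDefinition and kRequireDefinition from the comment above, but the `visit_var` helper and its exact semantics are assumptions for illustration, not a translation of the C++ visitor.

```python
from enum import Flag, auto


class VisitMode(Flag):
    """Independent behaviours that can be combined with `|`."""

    PROVIDE_DEFINITION = auto()  # An undefined variable becomes defined here.
    REQUIRE_DEFINITION = auto()  # An undefined variable here is an error.


def visit_var(name, mode, defined, errors):
    """Handle one variable occurrence according to the active flags."""
    if name in defined:
        return
    if mode & VisitMode.PROVIDE_DEFINITION:
        defined.add(name)
    elif mode & VisitMode.REQUIRE_DEFINITION:
        errors.append(name)


defined, errors = set(), []

# Usage before any definition is reported.
visit_var("n", VisitMode.REQUIRE_DEFINITION, defined, errors)
# A definition site records the variable.
visit_var("n", VisitMode.PROVIDE_DEFINITION, defined, errors)
# Later usages are now accepted.
visit_var("n", VisitMode.REQUIRE_DEFINITION, defined, errors)
# Both flags together: the first occurrence defines the variable.
visit_var("m", VisitMode.PROVIDE_DEFINITION | VisitMode.REQUIRE_DEFINITION, defined, errors)
```

Each flag is tested on its own with `&`, so adding a third behaviour later does not depend on where it falls in an ordering, which was the concern with the `>=` comparison.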

Contributor

@slyubomirsky slyubomirsky left a comment

These seem like solid changes and the implementation is straightforward. Thank you very much! Please have a look at my comments if you have the chance.

@Lunderberg Lunderberg merged commit 2e2126f into apache:unity Aug 16, 2023
@Lunderberg Lunderberg deleted the unity_bind_symbolic_vars branch August 16, 2023 14:10
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Aug 27, 2023
Similar to `relax.Function.bind_symbolic_vars`, implemented in
apache#15509, this commit introduces
`relax.Function.bind_params` to allow Relax parameters to be
manipulated on a per-function basis.  This utility function and the
existing `BindParams` transform both use the same underlying
implementation.
yongwww pushed a commit to yongwww/tvm that referenced this pull request Aug 28, 2023
* [Unity] Implement relax.Function.bind_symbolic_vars

If a function has dynamic shape parameters, it can be useful to
replace them with static parameters (e.g. when producing several
models within the same family).  This commit introduces a utility
function `relax.Function.bind_symbolic_vars`, which allows symbolic
variables to be replaced with static values.

This is related to the parameter binding done in
`relax.transform.BindParams`, but does not require the bound parameter
to be a fully static data array.

* Updating ExprBinder to use tir::Substitute

Previously, `ExprBinder` only checked whether a `PrimExpr` was a
symbolic variable to be replaced, but did not handle cases where a
`PrimExpr` contained a symbolic variable to be replaced.  As a result,
when binding symbolic variables `{N: 16}`, a shape of `[N,2*N]` would be
updated to `[16,2*N]` instead of `[16,32]`.  This commit updates
`ExprBinder` to use `tir::Substitute` to ensure all occurrences of the
symbolic variable are replaced.

* Special case for updating symbolic vars in strided_slice attrs

* Added IRModule pass to bind symbolic vars

* Update unit test to include pytest

* Co-authored-by: Sunghyun Park <sunggg@umich.edu>

* Correct match mode in kProvideDefinitions context

* Clean up implementation with VisitMode as a bitflag
csullivan pushed a commit that referenced this pull request Sep 6, 2023
* [Unity] Implement relax.Function.bind_params

Similar to `relax.Function.bind_symbolic_vars`, implemented in
#15509, this commit introduces
`relax.Function.bind_params` to allow Relax parameters to be
manipulated on a per-function basis.  This utility function and the
existing `BindParams` transform both use the same underlying
implementation.

* Update relay_translator unit tests to avoid duplicate binding

* Updated unit test that attempted to bind non-existent parameter
4 participants