
[Bugfix][Transform] Preserve symbolic variables in FuseOps #16637

Conversation

Lunderberg (Contributor)

Prior to this commit, the `CompositeFunctionAnnotator` visited the
body of functions without the parameters being considered in-scope.
As a result, `EraseToWellDefined` would remove known shapes from the
function body's `StructInfo`.
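
The effect can be illustrated with a minimal stand-in for the erasure step (a sketch only, not TVM's actual `EraseToWellDefined` implementation): a symbolic dimension survives only if its variable is considered in-scope, so visiting the body without the parameters in scope drops known shapes.

```python
# Minimal model (not TVM's implementation) of EraseToWellDefined's
# behavior: any symbolic dimension whose defining variable is not in
# scope is erased to an unknown placeholder (None).

def erase_to_well_defined(shape, vars_in_scope):
    """Replace out-of-scope symbolic dims with None (unknown)."""
    return tuple(
        dim if not isinstance(dim, str) or dim in vars_in_scope else None
        for dim in shape
    )

# With the function parameters in scope, the known shape survives:
print(erase_to_well_defined(("n", 16), {"n"}))   # ('n', 16)

# Visiting the body *without* the parameters in scope (the bug)
# erases the known symbolic dimension:
print(erase_to_well_defined(("n", 16), set()))   # (None, 16)
```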
Lunderberg changed the title from [Unity][Transform] Preserve symbolic variables in FuseOps to [Bugfix][Transform] Preserve symbolic variables in FuseOps on Feb 23, 2024
Lunderberg (Contributor, Author)

This PR was originally part of #16450. Since some of the later changes in that PR required further discussion and potential refactoring, splitting it up allows the bugfix in this commit to land on its own.

slyubomirsky (Contributor) left a comment:

This is a good fix to have.

Lunderberg merged commit e56c5e1 into apache:main on Feb 29, 2024
20 checks passed
Lunderberg deleted the transform_preserve_symbolic_vars_in_fuse_ops branch on February 29, 2024 at 14:08
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Mar 12, 2024

[Unity][Transform] Preserve symbolic variables in FuseOps

Prior to this commit, the `CompositeFunctionAnnotator` visited the
body of functions without the parameters being considered in-scope.
As a result, `EraseToWellDefined` would remove known shapes from the
function body's `StructInfo`.
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Jun 10, 2024
This is a follow-up commit to
apache#16637, which updated
`relax.transform.FuseOps` to provide additional parameters defining
symbolic variables required by the fused functions.  While this
ensures that `relax.transform.FuseOps` produces well-formed Relax
functions, these additional arguments can break some kernel
implementations.

This commit implements a new transform
`RemoveSymbolicExpressionsInSubroutine` to resolve this issue.  This
transform identifies function arguments whose sole purpose is to
compute a symbolic expression, when that symbolic expression could be
inferred from tensor shapes.

For example, consider the following Relax function:

```python
@R.function
def func(
    data: R.Tensor(["batch_size * seq_len", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
  ) -> R.Tensor(["batch_size * seq_len", "intermediate_size"]):

    batch_size = T.int64()
    seq_len = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()

    output: R.Tensor([batch_size * seq_len, intermediate_size]) = R.matmul(data, weights)
    return output
```

The `data` tensor may be used to infer `hidden_size`, but cannot be
used to infer `batch_size` or `seq_len`.  The `R.Shape` parameter
exists solely to define `batch_size` and `seq_len`, since all symbolic
variables must be defined.  However, neither `batch_size` nor
`seq_len` is ever used outside of the expression `batch_size *
seq_len`, and the value of `batch_size * seq_len` could be inferred
from the shape of the `data` tensor.
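
The distinction between inferable and non-inferable variables can be sketched as follows (a hypothetical helper, not part of TVM's API; each dimension is represented by the tuple of variable names its shape expression uses):

```python
# Sketch of the inference rule: a symbolic dimension that is a single
# variable can be solved directly from the corresponding concrete
# dimension, but an expression over several variables (such as
# "batch_size * seq_len") only determines the value of the whole
# expression, not the values of its individual factors.

def solvable_vars(symbolic_shape):
    """Return the variables directly solvable from a tensor's shape."""
    return {dim[0] for dim in symbolic_shape if len(dim) == 1}

# data: R.Tensor(["batch_size * seq_len", "hidden_size"]), with each
# dim represented by the variables its expression mentions:
data_shape = [("batch_size", "seq_len"), ("hidden_size",)]

print(solvable_vars(data_shape))  # {'hidden_size'}
```

Only `hidden_size` is recoverable from `data` alone, which is why the `R.Shape` parameter is needed to define `batch_size` and `seq_len` in the first place.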

This new transform identifies cases where an argument is otherwise
unnecessary, and replaces the symbolic expression with a new variable
that can be inferred from a tensor's shape.  This leaves the
`dummy_arg: R.Shape` parameter entirely unused, so a later use of
`relax.transform.RemoveUnusedParameters()` can remove the parameter
altogether.

```python
@R.function
def func(
    data: R.Tensor(["data_dim0", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
  ):

    data_dim0 = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()

    output: R.Tensor([data_dim0, intermediate_size]) = R.matmul(data, weights)
    return output
```
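
The core rewrite above can be modeled as a substitution over the shape annotations (a sketch under the same simplified dims-as-strings representation, not TVM's implementation; `data_dim0` is the fresh variable the pass introduces):

```python
# Sketch (not TVM's implementation) of the rewrite: every occurrence
# of the multi-variable expression "batch_size * seq_len" is replaced
# by a fresh variable.  Afterwards, the R.Shape parameter that existed
# only to define batch_size and seq_len is no longer referenced.

def replace_expression(shapes, old_expr, fresh_var):
    """Substitute fresh_var for old_expr in every shape annotation."""
    return [
        tuple(fresh_var if dim == old_expr else dim for dim in shape)
        for shape in shapes
    ]

annotations = [
    ("batch_size * seq_len", "hidden_size"),        # data
    ("hidden_size", "intermediate_size"),           # weights
    ("batch_size * seq_len", "intermediate_size"),  # output
]

rewritten = replace_expression(annotations, "batch_size * seq_len", "data_dim0")
print(rewritten[0])  # ('data_dim0', 'hidden_size')
print(rewritten[2])  # ('data_dim0', 'intermediate_size')
```

After the substitution, no annotation mentions `batch_size` or `seq_len`, which is what lets a later dead-parameter pass drop `dummy_arg` entirely.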
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Sep 11, 2024