Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith][SVE] Add rewrite rules for indices split by scalable expressions #17046

Merged
merged 1 commit into from
Jun 7, 2024

Conversation

Anndrey24
Copy link
Contributor

This commit introduces rewrite rules for indices which can arise from splitting axes by scalable factors (e.g. xo, xi = sch.split(x, factors = [None, 8 * T.vscale()])):

(v_x_o * T.Cast("int64", T.vscale()) * T.int64(8) + v_x_i) // (T.Cast("int64", T.vscale()) * T.int64(8)) == v_x_o
(v_x_o * T.Cast("int64", T.vscale()) * T.int64(8) + v_x_i) % (T.Cast("int64", T.vscale()) * T.int64(8)) == v_x_i

The rewrites help prove checks needed by sch.tensorize() (e.g. CompareBufferRegion).

cc @ekalda @lhutton1

@github-actions github-actions bot requested review from ekalda and lhutton1 May 30, 2024 10:31
src/arith/rewrite_simplify.cc Show resolved Hide resolved
src/arith/rewrite_simplify.cc Show resolved Hide resolved
src/arith/rewrite_simplify.cc Show resolved Hide resolved
@lhutton1
Copy link
Contributor

cc @Lunderberg

Copy link
Contributor

@ekalda ekalda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for cleaning it up @Anndrey24! I'm a bit surprised by how permissive the rewrite rules have been towards division by zero, I hope I'm not missing something there. Let's see what the CI does, I hope the stack isn't relying on that behaviour too much 🙈

This commit introduces rewrite rules for indices which can arise from splitting axes by scalable factors (e.g. `xo, xi = sch.split(x, factors = [None, 8 * T.vscale()])`):

```
(v_x_o * T.Cast("int64", T.vscale()) * T.int64(8) + v_x_i) // (T.Cast("int64", T.vscale()) * T.int64(8)) == v_x_o
(v_x_o * T.Cast("int64", T.vscale()) * T.int64(8) + v_x_i) % (T.Cast("int64", T.vscale()) * T.int64(8)) == v_x_i
```

The rewrites help prove checks needed by `sch.tensorize()` (e.g. CompareBufferRegion).
@Anndrey24 Anndrey24 force-pushed the vscale-rewrite-simplify branch from e351cbc to bf66ecd Compare June 6, 2024 10:18
@ekalda
Copy link
Contributor

ekalda commented Jun 7, 2024

We decided to roll back to the initial version of this patch since a large number of Relax tests fail when division by zero is disabled in the rewrite rules (so dealing with this is out of scope for this patch).

@ekalda ekalda merged commit 5d077c5 into apache:main Jun 7, 2024
18 checks passed
@ekalda
Copy link
Contributor

ekalda commented Jun 7, 2024

Thanks @Anndrey24!

@Anndrey24 Anndrey24 deleted the vscale-rewrite-simplify branch June 7, 2024 17:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants