[inductor] Fix torch.split bug on unbacked symint (pytorch#113406)
torch.split(x, l) fails when the split sizes in l are unbacked symints.

E.g. l = y.tolist() makes the elements of l unbacked symints, because
their values depend on the data of y. The downstream call
`SliceView.create()` eagerly evaluates the slice bounds (via
`sizevars.evaluate_min`) even when they contain unbacked symints, which
triggers the bug: unbacked symints carry no hints to evaluate against.
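
A minimal repro sketch of the failing pattern (illustrative values; it
mirrors the test added below):

    import torch
    import torch._dynamo

    def fn(x, y):
        l = y.tolist()            # elements of l become unbacked symints when traced
        return torch.split(x, l)  # previously failed inside SliceView.create()

    x = torch.randn(32, device="cuda")
    y = torch.tensor([7, 16, 9])
    with torch._dynamo.config.patch(capture_scalar_outputs=True):
        out = torch.compile(fn, fullgraph=True)(x, y)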

Test Plan:
python test/inductor/test_unbacked_symints.py -k test_split_with_sizes

Pull Request resolved: pytorch#113406
Approved by: https://github.com/aakhundov, https://github.com/ezyang
andrewlee302 authored and pytorchmergebot committed Nov 24, 2023
1 parent 5139072 commit cd7d693
Showing 5 changed files with 29 additions and 6 deletions.
15 changes: 15 additions & 0 deletions test/inductor/test_unbacked_symints.py
@@ -54,6 +54,21 @@ def fn(x, y):
 
     torch.testing.assert_close(actual, expected)
 
+    def test_split_with_sizes(self):
+        def fn(x, y):
+            l = y.tolist()
+            s = torch.split(x, l)
+            d = l[0] + l[1] + l[2]
+            return s[0].sum(), d
+
+        example_inputs = (torch.randn((32), device="cuda"), torch.tensor((7, 16, 9)))
+
+        with dynamo_config.patch({"capture_scalar_outputs": True}):
+            actual = torch.compile(fn, fullgraph=True)(*example_inputs)
+            expected = fn(*example_inputs)
+
+        torch.testing.assert_close(actual, expected)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
4 changes: 4 additions & 0 deletions torch/_inductor/codegen/common.py
@@ -378,6 +378,10 @@ def _print_Max(self, expr):
         assert len(expr.args) >= 2
         return f"max({', '.join(map(self._print, expr.args))})"
 
+    def _print_Min(self, expr):
+        assert len(expr.args) >= 2
+        return f"min({', '.join(map(self._print, expr.args))})"
+
 
 class OpOverrides:
     def __init__(self, parent):
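
For context, sympy printers dispatch on the expression's class name
(`_print_Min` handles `sympy.Min`), so this method lowers a symbolic min
to a Python `min(...)` call. A standalone sketch of the same mechanism
(`MinMaxPrinter` is a stand-in, not the actual Inductor printer class):

    import sympy
    from sympy.printing.str import StrPrinter

    class MinMaxPrinter(StrPrinter):  # stand-in for Inductor's expression printer
        def _print_Min(self, expr):
            assert len(expr.args) >= 2
            return f"min({', '.join(map(self._print, expr.args))})"

    s0, i0 = sympy.symbols("s0 i0", positive=True, integer=True)
    print(MinMaxPrinter().doprint(sympy.Min(s0, i0)))  # e.g. "min(i0, s0)"; sympy orders args canonically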
3 changes: 2 additions & 1 deletion torch/_inductor/codegen/triton.py
@@ -19,6 +19,7 @@
 from torch._prims_common import is_integer_dtype
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.value_ranges import ValueRanges
+
 from ..._dynamo.utils import counters
 from .. import config, ir, scheduler
 from ..codecache import code_hash, get_path, PyCodeCache
@@ -1143,7 +1144,7 @@ def indexing(
                 # indirect indexing
                 cse_var = self.cse.varname_map[var.name]
                 mask_vars.update(cse_var.mask_vars)
-            elif var.name.startswith(("s", "ps")):
+            elif var.name.startswith(("s", "ps", "i")):
                 pass
             else:
                 # var is one of xN, yN or rN
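
Note: at the time of this commit, Inductor named backed sizes s0, s1, …,
precomputed sizes ps0, …, and unbacked symints i0, i1, …; adding the "i"
prefix lets unbacked symbols take the same no-mask path as the other size
symbols. A trivial check of the predicate (symbol names assumed from
those conventions):

    for name in ("s0", "ps0", "i3", "x0"):
        print(name, name.startswith(("s", "ps", "i")))
    # s0 True, ps0 True, i3 True, x0 False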
11 changes: 6 additions & 5 deletions torch/_inductor/ir.py
@@ -2088,11 +2088,12 @@ def create(cls, x, dim, start, end, step=1):
         start = cls.handle_negative_index(start, new_size[dim])
         end = cls.handle_negative_index(end, new_size[dim])
 
-        end = sizevars.evaluate_min(end, new_size[dim])
-        start = sizevars.evaluate_min(start, end)
-        if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1:
-            sizevars.guard_equals(end, new_size[dim])
-            return x
+        if free_unbacked_symbols(start) or free_unbacked_symbols(end):
+            end = sympy.Min(end, new_size[dim])
+            start = sympy.Min(start, end)
+        else:
+            end = sizevars.evaluate_min(end, new_size[dim])
+            start = sizevars.evaluate_min(start, end)
 
         new_size[dim] = FloorDiv(end - start + (step - 1), step)
Expand Down
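This is the core of the fix: `sizevars.evaluate_min` must decide which
operand is smaller immediately, which requires a hint and therefore fails
for hint-less unbacked symints; wrapping the bounds in `sympy.Min` keeps
them symbolic, and the new `_print_Min` above renders them as `min(...)`
at codegen time. A small sketch of the difference (`i0` stands in for an
unbacked symint):

    import sympy

    i0 = sympy.Symbol("i0", positive=True, integer=True)  # stand-in for an unbacked symint
    end = sympy.Min(i0, 32)  # stays symbolic; nothing is guarded or evaluated here
    print(end)               # Min(32, i0)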
2 changes: 2 additions & 0 deletions torch/fx/experimental/validator.py
@@ -363,6 +363,8 @@ def __getattr__(self, name: str) -> Any:
             "not_": z3.Not,
             "floor": self._ops.floor,
             "ceil": self._ops.ceil,
+            "minimum": self._ops.min,
+            "maximum": self._ops.max,
         }
 
         if name in REPLACEMENT:
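
The translation validator maps symbolic-shape ops to z3 terms, so
`minimum`/`maximum` now resolve to the validator's min/max helpers. z3
has no built-in integer min, so the usual encoding is an `If` term; a
sketch under that assumption (`_ops.min` is presumably equivalent):

    import z3

    a, b = z3.Ints("a b")
    min_ab = z3.If(a < b, a, b)  # common z3 encoding of min(a, b)
    s = z3.Solver()
    s.add(a == 3, b == 5)
    assert s.check() == z3.sat
    print(s.model().eval(min_ab))  # 3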
