Skip to content

Commit

Permalink
[Fix] Fix the purity flag of "vm.call_tir_dyn" and "kill" ops
Browse files Browse the repository at this point in the history
This PR fixes the purity flag of `relax.vm.call_tir_dyn` and another
few "kill" ops. Their purity flags were set to True, which made them
possible to be removed by `remove_all_unused`.

* `relax.vm.call_tir_dyn` works by mutating the input args in place,
which is not pure.
* though the "kill" ops have no actions so far, their semantics
suggest that they are impure.

A regression test is added to prevent the unexpected removal from
happening again.
  • Loading branch information
MasterJH5574 committed Mar 24, 2024
1 parent 77a7b01 commit cf4d113
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 17 deletions.
15 changes: 8 additions & 7 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -921,8 +921,8 @@ RELAY_REGISTER_OP("relax.memory.kill_storage")
.set_num_inputs(1)
.add_argument("storage", "Expr", "The storage to be killed.")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
// deallocation also isn't considered a "visible effect" as far as purity is concerned
.set_attr<Bool>("FPurity", Bool(true));
// We mark this as impure so it wouldn't be removed by "remove_all_unused"
.set_attr<Bool>("FPurity", Bool(false));

Expr MakeMemKillStorage(Expr storage) {
static const Op& op = Op::Get("relax.memory.kill_storage");
Expand All @@ -937,8 +937,8 @@ RELAY_REGISTER_OP("relax.memory.kill_tensor")
.set_num_inputs(1)
.add_argument("tensor", "Expr", "The tensor to be killed.")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
// memory deallocation also isn't considered a "visible effect" as far as purity is concerned
.set_attr<Bool>("FPurity", Bool(true));
// We mark this as impure so it wouldn't be removed by "remove_all_unused"
.set_attr<Bool>("FPurity", Bool(false));

Expr MakeMemKillTensor(Expr tensor) {
static const Op& op = Op::Get("relax.memory.kill_tensor");
Expand Down Expand Up @@ -1013,8 +1013,8 @@ TVM_REGISTER_OP("relax.vm.kill_object")
.set_num_inputs(1)
.add_argument("obj", "Expr", "The object to be killed.")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
// deallocation also isn't considered a "visible effect" as far as purity is concerned
.set_attr<Bool>("FPurity", Bool(true));
// We mark this as impure so it wouldn't be removed by "remove_all_unused"
.set_attr<Bool>("FPurity", Bool(false));

Expr MakeVMKillObject(Expr obj) {
static const Op& op = Op::Get("relax.vm.kill_object");
Expand All @@ -1031,7 +1031,8 @@ RELAY_REGISTER_OP("relax.vm.call_tir_dyn")
.add_argument("args", "Tuple",
"The input arguments (list of tensors and last argument is ShapeExpr)")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
.set_attr<Bool>("FPurity", Bool(true));
// "relax.vm.call_tir_dyn" works in an in-place way, which is impure.
.set_attr<Bool>("FPurity", Bool(false));

Expr MakeCallTIRDyn(Expr func, Tuple args) {
static const Op& op = Op::Get("relax.vm.call_tir_dyn");
Expand Down
42 changes: 34 additions & 8 deletions tests/python/relax/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,21 @@

import tvm
import tvm.testing
from tvm import tir
from tvm import relax as rx
from tvm import tir
from tvm.relax.analysis import (
has_reshape_pattern,
udchain,
remove_all_unused,
name_to_binding,
all_vars,
all_global_vars,
free_vars,
all_vars,
bound_vars,
free_vars,
has_reshape_pattern,
name_to_binding,
remove_all_unused,
udchain,
)
from tvm.script import relax as R, tir as T
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T


def var_name_set(vars: List[Union[rx.Var, rx.GlobalVar]]) -> Set[str]:
Expand Down Expand Up @@ -352,6 +354,30 @@ def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)


def test_retain_calls_to_impure_builtin_ops():
@I.ir_module
class Module:
@T.prim_func(private=True)
def my_tir(A: T.handle, B: T.handle, n: T.int64):
T.evaluate(0)

@R.function(pure=False)
def main(x: R.Tensor(("n",), "float32")):
cls = Module
n = T.int64()
storage = R.memory.alloc_storage((n * 4,), 0, "global", "float32")
alloc = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), "float32")
# "call_tir_dyn" is impure which shouldn't be removed.
R.vm.call_tir_dyn(cls.my_tir, (x, alloc, R.shape([n])))
# "kill_tensor"/"kill_storage" are impure which shouldn't be removed.
R.memory.kill_tensor(alloc)
R.memory.kill_storage(storage)
return x

after = remove_all_unused(Module["main"])
tvm.ir.assert_structural_equal(after, Module["main"], map_free_vars=True)


def test_name_to_binding_var_shadowing():
@R.function
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_transform_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def sum(
def test_do_not_eliminate_dtype():
@I.ir_module
class Before:
@R.function
@R.function(pure=False)
def foo() -> R.Tensor((32, 64), "int32"):
obj: R.Object = R.vm.alloc_storage(
R.shape([24576]), runtime_device_index=0, dtype="uint8"
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1552,7 +1552,7 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")):


def test_vm_ops():
@R.function
@R.function(pure=False)
def foo(x: R.Tensor(("m", "n"), dtype="float32")):
m = T.int64()
n = T.int64()
Expand Down

0 comments on commit cf4d113

Please sign in to comment.