[TIR] Support narrow dtype for let binding #16947

Merged
merged 1 commit on May 6, 2024
1 change: 1 addition & 0 deletions include/tvm/tir/data_type_rewriter.h
@@ -110,6 +110,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
Stmt VisitStmt_(const IfThenElseNode* op) override;
Stmt VisitStmt_(const DeclBufferNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const LetStmtNode* op) override;
PrimExpr VisitExpr_(const EQNode* op) override;
PrimExpr VisitExpr_(const NENode* op) override;
PrimExpr VisitExpr_(const LTNode* op) override;
6 changes: 5 additions & 1 deletion python/tvm/relax/backend/dispatch_sort_scan.py
@@ -155,9 +155,13 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
tgt = self._get_target(call.struct_info)
axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
shape = call.struct_info.shape
# TODO(tvm-team): Support fully dynamic case with `shape=None`
if shape is None:
raise ValueError("non-symbolic shape is not supported for now")
kwargs = {}
if (
(axis == -1 or axis == len(shape) - 1)
shape is not None
and (axis == -1 or axis == len(shape) - 1)
and is_gpu_target(tgt)
and not can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan")
and call.op.name == "relax.cumsum"
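The guard added above rejects a scan/cumsum call whose struct info carries no symbolic shape before any kernel dispatch happens. Below is a minimal, untested sketch of that failure mode; the import path for DispatchSortScan follows the test file further down, and the assumption that a rank-only tensor annotation yields struct_info.shape of None is mine, not part of this PR.

# Hypothetical sketch of the new shape=None guard; names and import path are
# assumptions based on the test file below, not content of this PR.
import tvm
from tvm.relax.backend import DispatchSortScan
from tvm.script import ir as I, relax as R


@I.ir_module
class DynShapeModule:
    @R.function
    def main(x: R.Tensor(dtype="int32", ndim=2)):
        # cumsum over a tensor whose shape is fully unknown (struct_info.shape is None)
        gv = R.cumsum(x, axis=-1)
        return gv


with tvm.target.Target("cuda"):
    try:
        DispatchSortScan()(DynShapeModule)
    except ValueError as err:
        print(err)  # expected: "non-symbolic shape is not supported for now"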
19 changes: 19 additions & 0 deletions src/tir/ir/data_type_rewriter.cc
@@ -27,6 +27,10 @@
#include <tvm/tir/op.h>

#include "./functor_common.h"
#include "tvm/ir/expr.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
#include "tvm/tir/var.h"

namespace tvm {
namespace tir {
@@ -556,6 +560,21 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) {
}
}

Stmt IndexDataTypeRewriter::VisitStmt_(const LetStmtNode* op) {
  LetStmt let_stmt = Downcast<LetStmt>(DataTypeLegalizer::VisitStmt_(op));
  if (var_remap_.find(let_stmt->var.get()) == var_remap_.end()) {
    return let_stmt;
  }
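  // The let var was remapped to a narrower index dtype, so re-visit the bound
  // value with rewriting enabled to keep the value's dtype in sync with it.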
  bool is_enabled = is_enabled_;
  is_enabled_ = true;
  PrimExpr value = VisitExpr(op->value);
  Var var = var_remap_[let_stmt->var.get()];
  is_enabled_ = is_enabled;
  ICHECK(value.dtype() == var.dtype());
  // No need to re-visit body
  return LetStmt(var, value, let_stmt->body, let_stmt->span);
}

#define TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) { \
bool is_enabled = is_enabled_; \
22 changes: 10 additions & 12 deletions tests/python/relax/test_backend_dispatch_sort_scan.py
@@ -273,7 +273,7 @@ def foo2(y: R.Tensor((2, 3), "float32")):
if can_use_thrust(target, "tvm.contrib.thrust.sort"):
workspace = bb.emit(
relax.op.builtin.alloc_tensor(
R.shape([4194568]), R.dtype("uint8"), R.prim_value(0), R.str("global")
R.shape([8388872]), R.dtype("uint8"), R.prim_value(0), R.str("global")
)
)
out = bb.emit_te(
@@ -400,8 +400,8 @@ def foo(x: R.Tensor((2, 3), "float32", "vulkan")):
assert_structural_equal(mod, expected_mod)


@tvm.testing.requires_cuda
def test_dispatch_cumsum_gpu():
@tvm.testing.parametrize_targets("cuda", "vulkan -supports_int64=1")
def test_dispatch_cumsum_gpu(target, dev):
"""Test cumsum kernel dispatch and numerical correctness"""

@I.ir_module
@@ -416,15 +416,13 @@ def main(x: R.Tensor(("m", "n"), "int32")):
size = (8, 2000)
np_data = np.random.randint(0, 10, size).astype("int32")
np_cumsum = np.cumsum(np_data, axis=-1)
for target in ["cuda", "vulkan -supports_int64=1"]:
with tvm.target.Target(target):
mod = DispatchSortScan()(Module)
ex = tvm.relax.build(mod, target)
device = tvm.device(target, 0)
vm = tvm.relax.VirtualMachine(ex, device)
tvm_data = tvm.nd.array(np_data, device)
cumsum = vm["main"](tvm_data)
tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum)
with tvm.target.Target(target):
mod = DispatchSortScan()(Module)
ex = tvm.relax.build(mod, target)
vm = tvm.relax.VirtualMachine(ex, dev)
tvm_data = tvm.nd.array(np_data, dev)
cumsum = vm["main"](tvm_data)
tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum)


if __name__ == "__main__":
@@ -278,5 +278,30 @@ def main(B: T.Buffer((4,), "int32")):
tvm.ir.assert_structural_equal(Expected, after)


def test_let_binding():
    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def main(buf: T.handle):
            n = T.int64()
            Buf = T.match_buffer(buf, [n], "int32")
            ceil_log2 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
            for i in T.serial(ceil_log2):
                T.evaluate(0)

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def main(buf: T.handle):
            n = T.int32()
            Buf = T.match_buffer(buf, [n], "int32")
            ceil_log2 = T.Cast("int32", T.ceil(T.log2(T.Cast("float32", n))))
            for i in range(ceil_log2):
                T.evaluate(0)

    after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before)
    tvm.ir.assert_structural_equal(Expected, after)


if __name__ == "__main__":
tvm.testing.main()