From e09ae693c23c537279dfb0c60edfdbf4171ed872 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 16 Sep 2022 14:12:57 -0700 Subject: [PATCH 1/2] Support TOPI argmax/argmin op definition --- src/tir/schedule/primitive/reduction.cc | 134 ++++++++++++++---------- 1 file changed, 80 insertions(+), 54 deletions(-) diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 2dc47fa15bea..dd2bcf727c40 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -297,60 +297,86 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, */ struct ReducerRegistry { ReducerRegistry() - : reducer_getters{CreateReducerGetter( - /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{x[0] + y[0]}; - }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, 0)}; - }), - CreateReducerGetter( - /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{x[0] * y[0]}; - }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, 1)}; - }), - CreateReducerGetter( - /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{min(x[0], y[0])}; - }, - [](const Array& values) { - return Array{max_value(values[0]->dtype)}; - }), - CreateReducerGetter( - /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{max(x[0], y[0])}; - }, - [](const Array& values) { - return Array{min_value(values[0]->dtype)}; - }), - CreateReducerGetter( - /*n_buffers=*/2, - [](const Array& x, const Array& y) { - PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]); - PrimExpr val = Select(x[1] >= y[1], x[1], y[1]); - return Array{idx, val}; - }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, -1), - min_value(values[1]->dtype)}; - }), - CreateReducerGetter( - /*n_buffers=*/2, - [](const Array& x, const Array& y) { - PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]); - PrimExpr val = Select(x[1] <= y[1], x[1], y[1]); - return Array{idx, val}; - }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, -1), - max_value(values[1]->dtype)}; - })} {} + : reducer_getters{ + CreateReducerGetter( + /*n_buffers=*/1, + [](const Array& x, const Array& y) { + return Array{x[0] + y[0]}; + }, + [](const Array& values) { + return Array{make_const(values[0]->dtype, 0)}; + }), + CreateReducerGetter( + /*n_buffers=*/1, + [](const Array& x, const Array& y) { + return Array{x[0] * y[0]}; + }, + [](const Array& values) { + return Array{make_const(values[0]->dtype, 1)}; + }), + CreateReducerGetter( + /*n_buffers=*/1, + [](const Array& x, const Array& y) { + return Array{min(x[0], y[0])}; + }, + [](const Array& values) { + return Array{max_value(values[0]->dtype)}; + }), + CreateReducerGetter( + /*n_buffers=*/1, + [](const Array& x, const Array& y) { + return Array{max(x[0], y[0])}; + }, + [](const Array& values) { + return Array{min_value(values[0]->dtype)}; + }), + CreateReducerGetter( + /*n_buffers=*/2, + [](const Array& x, const Array& y) { + PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]); + PrimExpr val = Select(x[1] >= y[1], x[1], y[1]); + return Array{idx, val}; + }, + [](const Array& values) { + return Array{make_const(values[0]->dtype, -1), + min_value(values[1]->dtype)}; + }), + CreateReducerGetter( + /*n_buffers=*/2, + [](const Array& x, const Array& y) { + PrimExpr idx = + Select(Or(greater(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), + x[0], y[0]); + PrimExpr val = Select(greater(x[1], y[1]), x[1], y[1]); + return Array{idx, val}; + }, + [](const Array& values) { + return Array{make_const(values[0]->dtype, -1), + min_value(values[1]->dtype)}; + }), + CreateReducerGetter( + /*n_buffers=*/2, + [](const Array& x, const Array& y) { + PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]); + PrimExpr val = Select(x[1] <= y[1], x[1], y[1]); + return Array{idx, val}; + }, + [](const Array& values) { + return Array{make_const(values[0]->dtype, -1), + max_value(values[1]->dtype)}; + }), + CreateReducerGetter( + /*n_buffers=*/2, + [](const Array& x, const Array& y) { + PrimExpr idx = Select( + Or(less(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), x[0], y[0]); + PrimExpr val = Select(less(x[1], y[1]), x[1], y[1]); + return Array{idx, val}; + }, + [](const Array& values) { + return Array{make_const(values[0]->dtype, -1), + max_value(values[1]->dtype)}; + })} {} static void RegisterReducer( int n_buffers, TypedPackedFunc(Array, Array)> combiner_getter, From 40774f29aa584653a6283c734a753e663f78bd9d Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 16 Sep 2022 14:46:17 -0700 Subject: [PATCH 2/2] Tests --- .../unittest/test_tir_schedule_rfactor.py | 156 +++++++++++++++++- 1 file changed, 153 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index f6db79f3ed23..964fe772d8af 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -15,12 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring -import sys - import pytest import tvm import tvm.testing -from tvm import tir +from tvm import te, tir, topi from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip @@ -1133,6 +1131,128 @@ def argmin_split_rfactor( argmin_v1[i] = v_argmin_v1 +@T.prim_func +def argmax_topi_rfactor( + placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1, "int32"] +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32") + placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32") + placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32") + placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32") + for i0, i1_0, i1_1 in T.grid(1, 4, 8): + with T.block("placeholder_red_temp_rf"): + vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) + T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1]) + T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) + with T.init(): + placeholder_red_temp_v0_rf[ax0, vi1_1] = -1 + placeholder_red_temp_v1_rf[ax0, vi1_1] = -2147483648 + v_placeholder_red_temp_v0_rf: T.int32 = T.Select( + placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0, vi1_0 * 8 + vi1_1] + or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0, vi1_0 * 8 + vi1_1] + and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1, + placeholder_red_temp_v0_rf[ax0, vi1_1], + vi1_0 * 8 + vi1_1, + ) + v_placeholder_red_temp_v1_rf: T.int32 = T.Select( + placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0, vi1_0 * 8 + vi1_1], + placeholder_red_temp_v1_rf[ax0, vi1_1], + placeholder[ax0, vi1_0 * 8 + vi1_1], + ) + placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf + placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf + for i0, i1_1 in T.grid(1, 8): + with T.block("placeholder_red_temp"): + vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0]) + T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) + T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0]) + with T.init(): + placeholder_red_temp_v0[ax0] = -1 + placeholder_red_temp_v1[ax0] = -2147483648 + v_placeholder_red_temp_v0: T.int32 = T.Select( + placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0, vi1_1] + or placeholder_red_temp_v1[ax0] == placeholder_red_temp_v1_rf[ax0, vi1_1] + and placeholder_red_temp_v0[ax0] < placeholder_red_temp_v0_rf[ax0, vi1_1], + placeholder_red_temp_v0[ax0], + placeholder_red_temp_v0_rf[ax0, vi1_1], + ) + v_placeholder_red_temp_v1: T.int32 = T.Select( + placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0, vi1_1], + placeholder_red_temp_v1[ax0], + placeholder_red_temp_v1_rf[ax0, vi1_1], + ) + placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0 + placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1 + for i0 in T.serial(1): + with T.block("placeholder_red"): + ax0 = T.axis.spatial(1, i0) + T.reads(placeholder_red_temp_v0[ax0]) + T.writes(placeholder_red[ax0]) + placeholder_red[ax0] = placeholder_red_temp_v0[ax0] + + +@T.prim_func +def argmin_topi_rfactor( + placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1, "int32"] +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32") + placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32") + placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32") + placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32") + for i0, i1_0, i1_1 in T.grid(1, 4, 8): + with T.block("placeholder_red_temp_rf"): + vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) + T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1]) + T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) + with T.init(): + placeholder_red_temp_v0_rf[ax0, vi1_1] = -1 + placeholder_red_temp_v1_rf[ax0, vi1_1] = 2147483647 + v_placeholder_red_temp_v0_rf: T.int32 = T.Select( + placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0, vi1_0 * 8 + vi1_1] + or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0, vi1_0 * 8 + vi1_1] + and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1, + placeholder_red_temp_v0_rf[ax0, vi1_1], + vi1_0 * 8 + vi1_1, + ) + v_placeholder_red_temp_v1_rf: T.int32 = T.Select( + placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0, vi1_0 * 8 + vi1_1], + placeholder_red_temp_v1_rf[ax0, vi1_1], + placeholder[ax0, vi1_0 * 8 + vi1_1], + ) + placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf + placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf + for i0, i1_1 in T.grid(1, 8): + with T.block("placeholder_red_temp"): + vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0]) + T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) + T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0]) + with T.init(): + placeholder_red_temp_v0[ax0] = -1 + placeholder_red_temp_v1[ax0] = 2147483647 + v_placeholder_red_temp_v0: T.int32 = T.Select( + placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0, vi1_1] + or placeholder_red_temp_v1[ax0] == placeholder_red_temp_v1_rf[ax0, vi1_1] + and placeholder_red_temp_v0[ax0] < placeholder_red_temp_v0_rf[ax0, vi1_1], + placeholder_red_temp_v0[ax0], + placeholder_red_temp_v0_rf[ax0, vi1_1], + ) + v_placeholder_red_temp_v1: T.int32 = T.Select( + placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0, vi1_1], + placeholder_red_temp_v1[ax0], + placeholder_red_temp_v1_rf[ax0, vi1_1], + ) + placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0 + placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1 + for i0 in T.serial(1): + with T.block("placeholder_red"): + ax0 = T.axis.spatial(1, i0) + T.reads(placeholder_red_temp_v0[ax0]) + T.writes(placeholder_red[ax0]) + placeholder_red[ax0] = placeholder_red_temp_v0[ax0] + + # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -1490,5 +1610,35 @@ def test_reduction_rfactor_argmax_init_buffer_not_match(): s.rfactor(ki, 1) +def test_reduction_rfactor_topi_argmax(): + A = te.placeholder((1, 32), dtype="int32") + B = topi.argmax(A, axis=1) + argmax_topi = te.create_prim_func([A, B]) + s = tir.Schedule(argmax_topi, debug_mask="all") + argmax = s.get_block("placeholder_red_temp") + _, k = s.get_loops(argmax) + _, ki = s.split(k, [None, 8]) + rf_block = s.rfactor(ki, 1) + tvm.ir.assert_structural_equal(s.mod["main"], argmax_topi_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf"))) + assert s.get(argmax).same_as(s.get(s.get_block("placeholder_red_temp"))) + verify_trace_roundtrip(s, mod=argmax_topi) + + +def test_reduction_rfactor_topi_argmin(): + A = te.placeholder((1, 32), dtype="int32") + B = topi.argmin(A, axis=1) + argmin_topi = te.create_prim_func([A, B]) + s = tir.Schedule(argmin_topi, debug_mask="all") + argmin = s.get_block("placeholder_red_temp") + _, k = s.get_loops(argmin) + _, ki = s.split(k, [None, 8]) + rf_block = s.rfactor(ki, 1) + tvm.ir.assert_structural_equal(s.mod["main"], argmin_topi_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf"))) + assert s.get(argmin).same_as(s.get(s.get_block("placeholder_red_temp"))) + verify_trace_roundtrip(s, mod=argmin_topi) + + if __name__ == "__main__": tvm.testing.main()