Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TIR] Support pattern matching argmax/argmin generated by TOPI (apach…
Browse files Browse the repository at this point in the history
…e#12827)

This PR introduces two reducers to TIR reduction part, so that rfactor and cross-thread reduction can be applied to those functions who contains argmax/argmin computation generated by TOPI.
  • Loading branch information
MasterJH5574 authored and xinetzone committed Nov 25, 2022
1 parent 856bdcd commit 0e504a5
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 57 deletions.
134 changes: 80 additions & 54 deletions src/tir/schedule/primitive/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,60 +297,86 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
*/
struct ReducerRegistry {
ReducerRegistry()
: reducer_getters{CreateReducerGetter(
/*n_buffers=*/1,
[](const Array<Var>& x, const Array<Var>& y) {
return Array<PrimExpr>{x[0] + y[0]};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{make_const(values[0]->dtype, 0)};
}),
CreateReducerGetter(
/*n_buffers=*/1,
[](const Array<Var>& x, const Array<Var>& y) {
return Array<PrimExpr>{x[0] * y[0]};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{make_const(values[0]->dtype, 1)};
}),
CreateReducerGetter(
/*n_buffers=*/1,
[](const Array<Var>& x, const Array<Var>& y) {
return Array<PrimExpr>{min(x[0], y[0])};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{max_value(values[0]->dtype)};
}),
CreateReducerGetter(
/*n_buffers=*/1,
[](const Array<Var>& x, const Array<Var>& y) {
return Array<PrimExpr>{max(x[0], y[0])};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{min_value(values[0]->dtype)};
}),
CreateReducerGetter(
/*n_buffers=*/2,
[](const Array<Var>& x, const Array<Var>& y) {
PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]);
PrimExpr val = Select(x[1] >= y[1], x[1], y[1]);
return Array<PrimExpr>{idx, val};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
min_value(values[1]->dtype)};
}),
CreateReducerGetter(
/*n_buffers=*/2,
[](const Array<Var>& x, const Array<Var>& y) {
PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]);
PrimExpr val = Select(x[1] <= y[1], x[1], y[1]);
return Array<PrimExpr>{idx, val};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
max_value(values[1]->dtype)};
})} {}
: reducer_getters{
CreateReducerGetter(
/*n_buffers=*/1,
[](const Array<Var>& x, const Array<Var>& y) {
return Array<PrimExpr>{x[0] + y[0]};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{make_const(values[0]->dtype, 0)};
}),
CreateReducerGetter(
/*n_buffers=*/1,
[](const Array<Var>& x, const Array<Var>& y) {
return Array<PrimExpr>{x[0] * y[0]};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{make_const(values[0]->dtype, 1)};
}),
CreateReducerGetter(
/*n_buffers=*/1,
[](const Array<Var>& x, const Array<Var>& y) {
return Array<PrimExpr>{min(x[0], y[0])};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{max_value(values[0]->dtype)};
}),
CreateReducerGetter(
/*n_buffers=*/1,
[](const Array<Var>& x, const Array<Var>& y) {
return Array<PrimExpr>{max(x[0], y[0])};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{min_value(values[0]->dtype)};
}),
CreateReducerGetter(
/*n_buffers=*/2,
[](const Array<Var>& x, const Array<Var>& y) {
PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]);
PrimExpr val = Select(x[1] >= y[1], x[1], y[1]);
return Array<PrimExpr>{idx, val};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
min_value(values[1]->dtype)};
}),
CreateReducerGetter(
/*n_buffers=*/2,
[](const Array<Var>& x, const Array<Var>& y) {
PrimExpr idx =
Select(Or(greater(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))),
x[0], y[0]);
PrimExpr val = Select(greater(x[1], y[1]), x[1], y[1]);
return Array<PrimExpr>{idx, val};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
min_value(values[1]->dtype)};
}),
CreateReducerGetter(
/*n_buffers=*/2,
[](const Array<Var>& x, const Array<Var>& y) {
PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]);
PrimExpr val = Select(x[1] <= y[1], x[1], y[1]);
return Array<PrimExpr>{idx, val};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
max_value(values[1]->dtype)};
}),
CreateReducerGetter(
/*n_buffers=*/2,
[](const Array<Var>& x, const Array<Var>& y) {
PrimExpr idx = Select(
Or(less(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), x[0], y[0]);
PrimExpr val = Select(less(x[1], y[1]), x[1], y[1]);
return Array<PrimExpr>{idx, val};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
max_value(values[1]->dtype)};
})} {}

static void RegisterReducer(
int n_buffers, TypedPackedFunc<Array<PrimExpr>(Array<Var>, Array<Var>)> combiner_getter,
Expand Down
156 changes: 153 additions & 3 deletions tests/python/unittest/test_tir_schedule_rfactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import sys

import pytest
import tvm
import tvm.testing
from tvm import tir
from tvm import te, tir, topi
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip

Expand Down Expand Up @@ -1133,6 +1131,128 @@ def argmin_split_rfactor(
argmin_v1[i] = v_argmin_v1


@T.prim_func
def argmax_topi_rfactor(
placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1, "int32"]
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32")
placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32")
placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32")
placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32")
for i0, i1_0, i1_1 in T.grid(1, 4, 8):
with T.block("placeholder_red_temp_rf"):
vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0])
T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1])
T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
with T.init():
placeholder_red_temp_v0_rf[ax0, vi1_1] = -1
placeholder_red_temp_v1_rf[ax0, vi1_1] = -2147483648
v_placeholder_red_temp_v0_rf: T.int32 = T.Select(
placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0, vi1_0 * 8 + vi1_1]
or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0, vi1_0 * 8 + vi1_1]
and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1,
placeholder_red_temp_v0_rf[ax0, vi1_1],
vi1_0 * 8 + vi1_1,
)
v_placeholder_red_temp_v1_rf: T.int32 = T.Select(
placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0, vi1_0 * 8 + vi1_1],
placeholder_red_temp_v1_rf[ax0, vi1_1],
placeholder[ax0, vi1_0 * 8 + vi1_1],
)
placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf
placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf
for i0, i1_1 in T.grid(1, 8):
with T.block("placeholder_red_temp"):
vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0])
T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0])
with T.init():
placeholder_red_temp_v0[ax0] = -1
placeholder_red_temp_v1[ax0] = -2147483648
v_placeholder_red_temp_v0: T.int32 = T.Select(
placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0, vi1_1]
or placeholder_red_temp_v1[ax0] == placeholder_red_temp_v1_rf[ax0, vi1_1]
and placeholder_red_temp_v0[ax0] < placeholder_red_temp_v0_rf[ax0, vi1_1],
placeholder_red_temp_v0[ax0],
placeholder_red_temp_v0_rf[ax0, vi1_1],
)
v_placeholder_red_temp_v1: T.int32 = T.Select(
placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0, vi1_1],
placeholder_red_temp_v1[ax0],
placeholder_red_temp_v1_rf[ax0, vi1_1],
)
placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0
placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1
for i0 in T.serial(1):
with T.block("placeholder_red"):
ax0 = T.axis.spatial(1, i0)
T.reads(placeholder_red_temp_v0[ax0])
T.writes(placeholder_red[ax0])
placeholder_red[ax0] = placeholder_red_temp_v0[ax0]


@T.prim_func
def argmin_topi_rfactor(
placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1, "int32"]
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32")
placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32")
placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32")
placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32")
for i0, i1_0, i1_1 in T.grid(1, 4, 8):
with T.block("placeholder_red_temp_rf"):
vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0])
T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1])
T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
with T.init():
placeholder_red_temp_v0_rf[ax0, vi1_1] = -1
placeholder_red_temp_v1_rf[ax0, vi1_1] = 2147483647
v_placeholder_red_temp_v0_rf: T.int32 = T.Select(
placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0, vi1_0 * 8 + vi1_1]
or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0, vi1_0 * 8 + vi1_1]
and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1,
placeholder_red_temp_v0_rf[ax0, vi1_1],
vi1_0 * 8 + vi1_1,
)
v_placeholder_red_temp_v1_rf: T.int32 = T.Select(
placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0, vi1_0 * 8 + vi1_1],
placeholder_red_temp_v1_rf[ax0, vi1_1],
placeholder[ax0, vi1_0 * 8 + vi1_1],
)
placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf
placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf
for i0, i1_1 in T.grid(1, 8):
with T.block("placeholder_red_temp"):
vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0])
T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0])
with T.init():
placeholder_red_temp_v0[ax0] = -1
placeholder_red_temp_v1[ax0] = 2147483647
v_placeholder_red_temp_v0: T.int32 = T.Select(
placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0, vi1_1]
or placeholder_red_temp_v1[ax0] == placeholder_red_temp_v1_rf[ax0, vi1_1]
and placeholder_red_temp_v0[ax0] < placeholder_red_temp_v0_rf[ax0, vi1_1],
placeholder_red_temp_v0[ax0],
placeholder_red_temp_v0_rf[ax0, vi1_1],
)
v_placeholder_red_temp_v1: T.int32 = T.Select(
placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0, vi1_1],
placeholder_red_temp_v1[ax0],
placeholder_red_temp_v1_rf[ax0, vi1_1],
)
placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0
placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1
for i0 in T.serial(1):
with T.block("placeholder_red"):
ax0 = T.axis.spatial(1, i0)
T.reads(placeholder_red_temp_v0[ax0])
T.writes(placeholder_red[ax0])
placeholder_red[ax0] = placeholder_red_temp_v0[ax0]


# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg


Expand Down Expand Up @@ -1490,5 +1610,35 @@ def test_reduction_rfactor_argmax_init_buffer_not_match():
s.rfactor(ki, 1)


def test_reduction_rfactor_topi_argmax():
A = te.placeholder((1, 32), dtype="int32")
B = topi.argmax(A, axis=1)
argmax_topi = te.create_prim_func([A, B])
s = tir.Schedule(argmax_topi, debug_mask="all")
argmax = s.get_block("placeholder_red_temp")
_, k = s.get_loops(argmax)
_, ki = s.split(k, [None, 8])
rf_block = s.rfactor(ki, 1)
tvm.ir.assert_structural_equal(s.mod["main"], argmax_topi_rfactor)
assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf")))
assert s.get(argmax).same_as(s.get(s.get_block("placeholder_red_temp")))
verify_trace_roundtrip(s, mod=argmax_topi)


def test_reduction_rfactor_topi_argmin():
A = te.placeholder((1, 32), dtype="int32")
B = topi.argmin(A, axis=1)
argmin_topi = te.create_prim_func([A, B])
s = tir.Schedule(argmin_topi, debug_mask="all")
argmin = s.get_block("placeholder_red_temp")
_, k = s.get_loops(argmin)
_, ki = s.split(k, [None, 8])
rf_block = s.rfactor(ki, 1)
tvm.ir.assert_structural_equal(s.mod["main"], argmin_topi_rfactor)
assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf")))
assert s.get(argmin).same_as(s.get(s.get_block("placeholder_red_temp")))
verify_trace_roundtrip(s, mod=argmin_topi)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 0e504a5

Please sign in to comment.