[MetaSchedule] Relax conditions of rule Cross-Thread Reduction (apache#12825)

This PR relaxes the conditions of the MetaSchedule schedule rule CrossThreadReduction. The conditions were previously over-strict: workloads with a small reduction loop length (e.g., 32) could not be optimized by cross-thread reduction automatically. This PR relaxes the conditions so that such workloads can be optimized.
MasterJH5574 authored and xinetzone committed Nov 25, 2022
1 parent f9de7aa commit 9df2f54
Showing 2 changed files with 100 additions and 4 deletions.
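
For context: cross-thread reduction maps a reduction loop onto the threads of a GPU thread block so that the threads reduce cooperatively, which is what lets even a short reduction (e.g., length 32) exploit the hardware. Below is a minimal hand-written sketch of the transform this rule automates; the `sum_32` workload is illustrative and not part of this commit, and the `T.Buffer[...]` annotation style assumes a TVM build contemporary with this change.

```python
# A minimal sketch of the transform CrossThreadReduction automates: bind a
# short reduction loop directly to threadIdx.x. Illustrative only; `sum_32`
# is not part of this commit.
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def sum_32(
    a: T.Buffer[(1, 32), "float32"],
    b: T.Buffer[(1,), "float32"],
) -> None:
    for i0, i1 in T.grid(1, 32):
        with T.block("sum"):
            i, k = T.axis.remap("SR", [i0, i1])
            T.reads(a[i, k])
            T.writes(b[i])
            with T.init():
                b[i] = T.float32(0)
            b[i] = b[i] + a[i, k]


sch = tir.Schedule(sum_32)
i, k = sch.get_loops(sch.get_block("sum"))
# Binding the reduction loop to threadIdx.x is the core of cross-thread
# reduction; the LowerCrossThreadReduction pass finishes the job at build time.
sch.bind(k, "threadIdx.x")
print(sch.mod.script())
```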
6 changes: 2 additions & 4 deletions src/tir/schedule/analysis/analysis.cc
@@ -1640,11 +1640,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,  //
   if (NeedsMultiLevelTiling(self, block_sref)) {
     // Do not use rfactor/cross-thread-reduction if we have enough parallelism on spatial loops.
     return !(cum_space_len >= cum_reduce_len || cum_space_len > max_parallel_extent);
-  } else if (cum_reduce_len > 1) {
-    // Always try rfactor/cross-thread-reduction for other reduction blocks.
-    return cum_reduce_len > max_parallel_basic;
   } else {
-    return false;
+    // Always try rfactor/cross-thread-reduction for other reduction blocks.
+    return cum_reduce_len > 1;
   }
 }
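
To make the relaxation concrete, here is a hedged Python paraphrase of the branch above for reduction blocks that do not need multi-level tiling (the variable names come from the surrounding C++; the helper functions are illustrative, not TVM API):

```python
# Illustrative paraphrase of the condition before and after this commit;
# not TVM API.
def needs_rfactor_or_ctr_before(cum_reduce_len: int, max_parallel_basic: int) -> bool:
    # Old rule: fire only when the reduction is longer than the basic
    # parallelism unit, so e.g. a length-32 reduction could be skipped.
    return cum_reduce_len > 1 and cum_reduce_len > max_parallel_basic


def needs_rfactor_or_ctr_after(cum_reduce_len: int) -> bool:
    # New rule: fire for any non-trivial reduction, however short.
    return cum_reduce_len > 1
```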

98 changes: 98 additions & 0 deletions tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -589,6 +589,28 @@ def argmax(
argmax_v1[i] = v_argmax_v1


@T.prim_func
def argmax_32(
idx: T.Buffer[(1, 32), "int32"],
val: T.Buffer[(1, 32), "float32"],
argmax_v0: T.Buffer[(1,), "int32"],
argmax_v1: T.Buffer[(1,), "float32"],
) -> None:
for i0, i1 in T.grid(1, 32):
with T.block("argmax"):
i = T.axis.spatial(1, i0)
k = T.axis.reduce(32, i1)
T.reads(idx[i, k], val[i, k])
T.writes(argmax_v0[i], argmax_v1[i])
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.min_value("float32")
v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k])
v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k])
argmax_v0[i] = v_argmax_v0
argmax_v1[i] = v_argmax_v1


def test_gpu_argmax():
@T.prim_func
def argmax_0(
@@ -663,8 +685,84 @@ def argmax_1(
)


def test_gpu_argmax_32():
@T.prim_func
def argmax_0(
idx: T.Buffer[(1, 32), "int32"],
val: T.Buffer[(1, 32), "float32"],
argmax_v0: T.Buffer[(1,), "int32"],
argmax_v1: T.Buffer[(1,), "float32"],
) -> None:
# body
# with T.block("root")
for i0, i1 in T.grid(1, 32):
with T.block("argmax"):
i, k = T.axis.remap("SR", [i0, i1])
T.reads(idx[i, k], val[i, k])
T.writes(argmax_v0[i], argmax_v1[i])
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.float32(-3.4028234663852886e38)
v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k])
v_argmax_v1: T.float32 = T.Select(
argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
)
argmax_v0[i] = v_argmax_v0
argmax_v1[i] = v_argmax_v1

@T.prim_func
def argmax_1(
idx: T.Buffer[(1, 32), "int32"],
val: T.Buffer[(1, 32), "float32"],
argmax_v0: T.Buffer[(1,), "int32"],
argmax_v1: T.Buffer[(1,), "float32"],
) -> None:
# body
# with T.block("root")
for i0, i1_0 in T.grid(1, 1):
for i1_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("argmax"):
i = T.axis.spatial(1, i0)
k = T.axis.reduce(32, i1_0 * 64 + i1_1)
T.where(i1_0 * 64 + i1_1 < 32)
T.reads(idx[i, k], val[i, k])
T.writes(argmax_v0[i], argmax_v1[i])
with T.init():
argmax_v0[i] = -1
argmax_v1[i] = T.float32(-3.4028234663852886e38)
v_argmax_v0: T.int32 = T.Select(
argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]
)
v_argmax_v1: T.float32 = T.Select(
argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]
)
argmax_v0[i] = v_argmax_v0
argmax_v1[i] = v_argmax_v1

decision_0 = [] # type: ignore
decision_1 = [
("SampleCategorical", 4),
]

mod = argmax_32
actual = ms.TuneContext(
mod=mod,
target=Target("nvidia/geforce-rtx-3090", host="llvm"),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction),
task_name="test",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[argmax_0, argmax_1],
expected_decisions=[decision_0, decision_1],
)


if __name__ == "__main__":
test_gpu_softmax_mn()
test_gpu_softmax_mn_after_inline()
test_gpu_batch_norm_bmn()
test_gpu_argmax()
test_gpu_argmax_32()
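
A note on `decision_1` above: assuming the default CUDA `CrossThreadReduction` rule draws its thread extent from the candidates `[4, 8, 16, 32, 64, 128, 256, 512]` (an assumption about `get_rules("cuda", ...)`, not stated in this diff), `("SampleCategorical", 4)` selects index 4, i.e. extent 64, which matches the `T.thread_binding(64, thread="threadIdx.x")` loop and the `T.where(i1_0 * 64 + i1_1 < 32)` guard in `argmax_1`:

```python
# Hedged: assumes the default CUDA thread-extent candidates for this rule.
thread_extents = [4, 8, 16, 32, 64, 128, 256, 512]
assert thread_extents[4] == 64  # ("SampleCategorical", 4) -> threadIdx.x extent 64
```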
