Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix][MetaSchedule] MultiLevelTilingTensorCore generates inconsistent thread-binding sketch for batched matmul #17012

Merged
merged 2 commits into from
Jun 28, 2024

Conversation

tsu-bin
Copy link
Contributor

@tsu-bin tsu-bin commented May 20, 2024

Below script can be used to reproduce the issue. You may need to run it multiple times to reproduce, because sample_perfect_tile may sometime to hide the issue with some decision.

in_type="float16"
out_type="float16"
BS = 100
MM = 32
NN = 32
KK = 32

def create_batch_matmul(
    b: int = BS, m: int = MM, n: int = NN, k: int = KK, in_dtype: str = in_type, out_dtype: str = out_type
) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
    A = te.placeholder((b, m, k), name="A", dtype=in_dtype)
    B = te.placeholder((b, n, k), name="B", dtype=in_dtype)
    C = topi.nn.batch_matmul(A, B)
    return (A, B, C)

space = meta_schedule.space_generator.PostOrderApply(
        sch_rules="cuda-tensorcore",
        postprocs="cuda-tensorcore",
    )
database = meta_schedule.tune_tir(
    mod=te.create_prim_func( create_batch_matmul () ),
    target=tvm.target.Target("cuda -arch=sm_89 -max_shared_memory_per_block=49152 -max_threads_per_block=1024"),
    max_trials_global = 200,
    space=space,
    work_dir="./test_batch_matmul/",
)

The error log is something like below.

  3: tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}::operator()(tvm::tir::Stmt const&) const
        at /hostShare/tools/tvm_all/tvm-dev/src/tir/ir/stmt_functor.cc:210
  2: tvm::tir::ThreadBindingUnifier::VisitStmt_(tvm::tir::ForNode const*)
        at /hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:60
  1: tvm::tir::ThreadBindingUnifier::VisitStmt_(tvm::tir::ForNode const*)
        at /hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:63
  0: tvm::tir::Stmt tvm::tir::ThreadBindingUnifier::UnifyThreadBindingImpl<tvm::tir::ForNode>(tvm::tir::ForNode const*, tvm::tir::Var const&, tvm::tir::IterVar const&, tvm::Range const&)
        at /hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:112
  File "/hostShare/tools/tvm_all/tvm-dev/src/support/parallel_for.cc", line 139
RuntimeError: parallel_for_dynamic error with [22:30:41] /hostShare/tools/tvm_all/tvm-dev/src/tir/transforms/unify_thread_binding.cc:112: Check failed: (ana.CanProveEqual(dom->extent, new_iter_var->dom->extent)) is false: ValueError: All loops that are bound to `threadIdx.y` should have the same extent. However, there are two loops with extent 12 and 4, which are not equal

The root cause is, the Batch Loop will be treated the same way as the other two spacial loops, M and N Loops, the Batch Loop will be decomposed following the SSSRRSRS fashion. But MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLayout has the assumption that the inner most two S should only have M and N loops' segments, which causes that, the following AddWriteReuseTensorCore adds an "wmma.accumulator" cache write block, and fuses some loop vars, which only belong to M and N Loops, and binds them to "threadIdx.y". But the previous, also the first, fused loop bound to "threadIdx.y" contains Batch Loop's segment. So the inconsistency arises.

The fix is simple, just skip the outer Batch Loop from sample_perfect_tile process and fuse it into "blockIdx.y".

Actually I also tried more complex strategy that decomposes Batch Loop into SSS, with each segment binds to "blockIdx.y" "blockIdx.x" "threadIdx.y" separately, so inner most two Ss contain no Batch Loop segment. But this strategy is less performant for several typical workload. I think that's because AddWriteReuseTensorCore will reorder the inner loop var across the Batch Loop segment, which cause less data locality.

Below I also paste the trace before the fix (including postproc trace), you can replay it line by line and print the each loop extent to verify the inconsistency.

b0 = sch.get_block(name="T_batch_matmul_NT", func_name="main")
b1 = sch.get_block(name="root", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
b2 = sch.reindex(block=b0, buffer=("write", 0))
b3 = sch.reindex(block=b0, buffer=("read", 0))
b4 = sch.reindex(block=b0, buffer=("read", 1))
sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda v_b, v_i, v_k: (v_b, v_i, v_k,), pad_value=None, assume_injective_transform=True)
sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda v_b, v_j, v_k: (v_b, v_j, v_k,), pad_value=None, assume_injective_transform=True)
sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda v_b, v_i, v_j: (v_b, v_i, v_j,), pad_value=None, assume_injective_transform=True)
sch.transform_block_layout(block=b2, index_map=lambda v_b, v_i, v_j: (v_b, v_i, v_j,))
sch.transform_block_layout(block=b3, index_map=lambda v_b, v_i, v_k: (v_b, v_i, v_k,))
sch.transform_block_layout(block=b4, index_map=lambda v_b, v_j, v_k: (v_b, v_j, v_k,))
sch.transform_block_layout(block=b0, index_map=lambda v_b, v_i, v_j, v_k: (v_b, v_i, v_j, v_k,))
l5, l6, l7, l8 = sch.get_loops(block=b0)
l9, l10 = sch.split(loop=l8, factors=[None, 16], preserve_unit_iters=True, disable_predication=False)
l11, l12 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True, disable_predication=False)
l13, l14 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True, disable_predication=False)
l15, l16, l17, l18, l19, l20, l21 = sch.get_loops(block=b0)
sch.reorder(l18, l20, l14, l12, l10)
b22 = sch.blockize(target=l14, preserve_unit_iters=True)
sch.annotate(block_or_loop=b22, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f16_trans")
sch.annotate(block_or_loop=b22, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f16")
sch.annotate(block_or_loop=b22, ann_key="warp_execution", ann_val=1)
l23, l24, l25, l26 = sch.get_loops(block=b22)
v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l23, n=5, max_innermost_factor=4, decision=[10, 1, 2, 5, 1])
l32, l33, l34, l35, l36 = sch.split(loop=l23, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True, disable_predication=False)
v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l24, n=5, max_innermost_factor=4, decision=[1, 1, 2, 1, 1])
l42, l43, l44, l45, l46 = sch.split(loop=l24, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True, disable_predication=False)
v47, v48, v49, v50, v51 = sch.sample_perfect_tile(loop=l25, n=5, max_innermost_factor=4, decision=[2, 1, 1, 1, 1])
l52, l53, l54, l55, l56 = sch.split(loop=l25, factors=[v47, v48, v49, v50, v51], preserve_unit_iters=True, disable_predication=False)
v57, v58, v59 = sch.sample_perfect_tile(loop=l26, n=3, max_innermost_factor=4, decision=[1, 2, 1])
l60, l61, l62 = sch.split(loop=l26, factors=[v57, v58, v59], preserve_unit_iters=True, disable_predication=False)
sch.reorder(l32, l42, l52, l33, l43, l53, l34, l44, l54, l60, l61, l35, l45, l55, l62, l36, l46, l56)
l63 = sch.fuse(l32, l42, l52, preserve_unit_iters=True)
sch.bind(loop=l63, thread_axis="blockIdx.y")
l64 = sch.fuse(l33, l43, l53, preserve_unit_iters=True)
sch.bind(loop=l64, thread_axis="blockIdx.x")
l65 = sch.fuse(l34, l44, l54, preserve_unit_iters=True)
sch.bind(loop=l65, thread_axis="threadIdx.y")
sch.annotate(block_or_loop=b22, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)
sch.annotate(block_or_loop=b22, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)
sch.transform_layout(block=b22, buffer=("write", 0), index_map=lambda i0, i1, i2: (i0, i1 // 16 // (v40 * v41), i2 // 16 // (v50 * v51), i1 // 16 % (v40 * v41), i2 // 16 % (v50 * v51), i1 % 16, i2 % 16,), pad_value=None, assume_injective_transform=True)
b66 = sch.cache_write(block=b22, write_buffer_index=0, storage_scope="shared.dyn")
sch.reverse_compute_at(block=b66, loop=l64, preserve_unit_loops=True, index=-1)
b67 = sch.cache_write(block=b22, write_buffer_index=0, storage_scope="wmma.accumulator")
l68, l69, l70, l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b66)
sch.reorder(l73, l71, l72, l74)
sch.compute_at(block=b67, loop=l73, preserve_unit_loops=True, index=-1)
l77, l78, l79, l80, l81, l82, l83, l84, l85, l86, l87 = sch.get_loops(block=b67)
l88 = sch.fuse(l82, l83, preserve_unit_iters=True)
sch.bind(loop=l88, thread_axis="threadIdx.y")
sch.reverse_compute_inline(block=b2)
l89, l90, l91, l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b67)
b99 = sch.blockize(target=l97, preserve_unit_iters=True)
sch.annotate(block_or_loop=b99, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f16_shared_dyn")
l100, l101, l102, l103, l104, l105, l106, l107, l108 = sch.get_loops(block=b66)
l109 = sch.fuse(l104, l105, l106, l107, l108, preserve_unit_iters=True)
v110 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0)
sch.annotate(block_or_loop=b66, ann_key="meta_schedule.cooperative_fetch", ann_val=v110)
b111 = sch.cache_read(block=b22, read_buffer_index=0, storage_scope="shared.dyn", consumer_blocks=[b22])
sch.compute_at(block=b111, loop=l60, preserve_unit_loops=True, index=-1)
l112, l113, l114, l115, l116, l117, l118 = sch.get_loops(block=b111)
l119 = sch.fuse(l116, l117, l118, preserve_unit_iters=True)
v120 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1)
sch.annotate(block_or_loop=b111, ann_key="meta_schedule.cooperative_fetch", ann_val=v120)
b121 = sch.cache_read(block=b22, read_buffer_index=1, storage_scope="shared.dyn", consumer_blocks=[b22])
sch.compute_at(block=b121, loop=l60, preserve_unit_loops=True, index=-1)
l122, l123, l124, l125, l126, l127, l128 = sch.get_loops(block=b121)
l129 = sch.fuse(l126, l127, l128, preserve_unit_iters=True)
v130 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2)
sch.annotate(block_or_loop=b121, ann_key="meta_schedule.cooperative_fetch", ann_val=v130)
b131 = sch.cache_read(block=b22, read_buffer_index=0, storage_scope="wmma.matrix_a")
sch.compute_at(block=b131, loop=l61, preserve_unit_loops=True, index=-1)
l132, l133, l134, l135, l136, l137, l138, l139 = sch.get_loops(block=b131)
l140, l141 = sch.split(loop=l139, factors=[None, 16], preserve_unit_iters=True, disable_predication=False)
l142, l143 = sch.split(loop=l138, factors=[None, 16], preserve_unit_iters=True, disable_predication=False)
l144, l145, l146, l147, l148, l149, l150, l151, l152, l153 = sch.get_loops(block=b131)
sch.reorder(l152, l143, l141)
b154 = sch.blockize(target=l143, preserve_unit_iters=True)
sch.annotate(block_or_loop=b154, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a_shared_dyn")
b155 = sch.cache_read(block=b22, read_buffer_index=1, storage_scope="wmma.matrix_b")
sch.compute_at(block=b155, loop=l61, preserve_unit_loops=True, index=-1)
l156, l157, l158, l159, l160, l161, l162, l163 = sch.get_loops(block=b155)
l164, l165 = sch.split(loop=l163, factors=[None, 16], preserve_unit_iters=True, disable_predication=False)
l166, l167 = sch.split(loop=l162, factors=[None, 16], preserve_unit_iters=True, disable_predication=False)
l168, l169, l170, l171, l172, l173, l174, l175, l176, l177 = sch.get_loops(block=b155)
sch.reorder(l176, l167, l165)
b178 = sch.blockize(target=l167, preserve_unit_iters=True)
sch.annotate(block_or_loop=b178, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b_trans_shared_dyn")
b179, = sch.get_producers(block=b111)
sch.compute_inline(block=b179)
sch.storage_align(block=b111, buffer_index=0, axis=-2, factor=32, offset=8)
b180, = sch.get_producers(block=b121)
sch.compute_inline(block=b180)
sch.storage_align(block=b121, buffer_index=0, axis=-2, factor=32, offset=8)
v181 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3)
sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v181)
sch.enter_postproc()
sch.unannotate(block_or_loop=b66, ann_key="meta_schedule.cooperative_fetch")
l182, l183, l184, l185, l186 = sch.get_loops(block=b66)
l187, l188, l189 = sch.split(loop=l186, factors=[None, 2, 32], preserve_unit_iters=True, disable_predication=False)
sch.bind(loop=l189, thread_axis="threadIdx.x")
sch.bind(loop=l188, thread_axis="threadIdx.y")
sch.unannotate(block_or_loop=b111, ann_key="meta_schedule.cooperative_fetch")
l190, l191, l192, l193, l194 = sch.get_loops(block=b111)
l195, l196, l197, l198 = sch.split(loop=l194, factors=[None, 2, 32, 2], preserve_unit_iters=True, disable_predication=False)
sch.vectorize(loop=l198)
sch.bind(loop=l197, thread_axis="threadIdx.x")
sch.bind(loop=l196, thread_axis="threadIdx.y")
sch.unannotate(block_or_loop=b121, ann_key="meta_schedule.cooperative_fetch")
l199, l200, l201, l202, l203 = sch.get_loops(block=b121)
l204, l205, l206, l207 = sch.split(loop=l203, factors=[None, 2, 32, 4], preserve_unit_iters=True, disable_predication=False)
sch.vectorize(loop=l207)
sch.bind(loop=l206, thread_axis="threadIdx.x")
sch.bind(loop=l205, thread_axis="threadIdx.y")
b208 = sch.get_block(name="root", func_name="main")
sch.unannotate(block_or_loop=b208, ann_key="meta_schedule.unroll_explicit")
b209, b210, b211, b212, b213, b214, b215 = sch.get_child_blocks(b208)
l216, l217, l218, l219, l220, l221, l222, l223 = sch.get_loops(block=b209)
l224, l225, l226, l227, l228, l229, l230, l231 = sch.get_loops(block=b210)
l232, l233, l234, l235, l236, l237, l238, l239 = sch.get_loops(block=b211)
l240, l241, l242, l243, l244, l245, l246, l247 = sch.get_loops(block=b212)
l248, l249, l250, l251, l252, l253, l254, l255, l256, l257, l258, l259 = sch.get_loops(block=b213)
sch.annotate(block_or_loop=l248, ann_key="pragma_auto_unroll_max_step", ann_val=512)
sch.annotate(block_or_loop=l248, ann_key="pragma_unroll_explicit", ann_val=1)
l260, l261, l262, l263, l264, l265, l266, l267 = sch.get_loops(block=b214)
l268, l269, l270, l271, l272, l273, l274 = sch.get_loops(block=b215)
b275 = sch.get_block(name="T_batch_matmul_NT_o", func_name="main")
l276, l277, l278, l279, l280, l281, l282, l283, l284, l285, l286, l287 = sch.get_loops(block=b275)
b288 = sch.decompose_reduction(block=b275, loop=l279)
sch.unannotate(block_or_loop=b288, ann_key="meta_schedule.auto_tensorize")
sch.annotate(block_or_loop=b288, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_fill_16x16x16_f16")
sch.unannotate(block_or_loop=b275, ann_key="meta_schedule.auto_tensorize_init")
sch.unannotate(block_or_loop=b288, ann_key="meta_schedule.auto_tensorize_init")
b289 = sch.get_block(name="T_batch_matmul_NT_o_init", func_name="main")
sch.unannotate(block_or_loop=b289, ann_key="meta_schedule.auto_tensorize")
sch.tensorize(block_or_loop=b289, tensor_intrin="wmma_fill_16x16x16_f16", preserve_unit_iters=True)
b290 = sch.get_block(name="A_reindex_shared.dyn_wmma.matrix_a_o", func_name="main")
sch.unannotate(block_or_loop=b290, ann_key="meta_schedule.auto_tensorize")
sch.tensorize(block_or_loop=b290, tensor_intrin="wmma_load_16x16x16_f16_a_shared_dyn", preserve_unit_iters=True)
b291 = sch.get_block(name="B_reindex_shared.dyn_wmma.matrix_b_o", func_name="main")
sch.unannotate(block_or_loop=b291, ann_key="meta_schedule.auto_tensorize")
sch.tensorize(block_or_loop=b291, tensor_intrin="wmma_load_16x16x16_f16_b_trans_shared_dyn", preserve_unit_iters=True)
b292 = sch.get_block(name="T_batch_matmul_NT_o_update", func_name="main")
sch.unannotate(block_or_loop=b292, ann_key="meta_schedule.auto_tensorize")
sch.tensorize(block_or_loop=b292, tensor_intrin="wmma_sync_16x16x16_f16f16f16_trans", preserve_unit_iters=True)
b293 = sch.get_block(name="T_batch_matmul_NT_reindex_shared.dyn_wmma.accumulator_o", func_name="main")
sch.unannotate(block_or_loop=b293, ann_key="meta_schedule.auto_tensorize")
sch.tensorize(block_or_loop=b293, tensor_intrin="wmma_store_16x16x16_f16_shared_dyn", preserve_unit_iters=True)

@tsu-bin tsu-bin force-pushed the debug_ms_batchmatmul branch 9 times, most recently from 5f51d19 to 7c4c620 Compare May 22, 2024 03:09
@tsu-bin tsu-bin marked this pull request as draft May 22, 2024 08:37
@tsu-bin tsu-bin marked this pull request as ready for review May 22, 2024 15:11
@tsu-bin tsu-bin force-pushed the debug_ms_batchmatmul branch from aae484f to 7c4b085 Compare May 22, 2024 15:14
@tsu-bin tsu-bin force-pushed the debug_ms_batchmatmul branch from 7c4b085 to b9a0248 Compare May 23, 2024 04:01
@tsu-bin
Copy link
Contributor Author

tsu-bin commented Jun 4, 2024

@tvm-bot rerun

@tsu-bin
Copy link
Contributor Author

tsu-bin commented Jun 28, 2024

hi @vinx13 could you help to merge this PR, I known relax + dlight is starting to prevail for LLM workloads, but many 'old' workloads like ours still heavily rely on relay+metaschedule.

@vinx13 vinx13 merged commit 4a5e22e into apache:main Jun 28, 2024
19 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants