diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
index 4c8d91dd27ef..0a2bcc229924 100644
--- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
+++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
@@ -18,14 +18,17 @@
 import numpy as np
 import pytest
 import tempfile
+from typing import Optional
 
 import tvm
 import tvm.testing
 from tvm import relay
+from tvm._ffi import register_func
 from tvm.meta_schedule import postproc, schedule_rule
 from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN
 from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner
 from tvm import meta_schedule as ms
+from tvm.tir.schedule import BlockRV, Schedule
 
 from ..infrastructure import get_hexagon_target
 
@@ -184,3 +187,174 @@ def test_resnet50(hexagon_launcher):
             hexagon_lowered.get_graph_json(), hexagon_lowered.lib
         )
         print(debug_ex.profile(input_name=inp.copy()))
+
+
+def _schedule_packed_8x8x32_conv2d(do_tune: bool):
+    """Manually schedule a conv2d block, created from a TE compute op via
+    CreatePrimFunc, using the 8x8x32 packed layout.
+    """
+
+    def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool:
+        if conv2d_block is None:
+            try:
+                conv2d_block = sch.get_block("conv2d_NCHWc_int8")
+            except Exception:
+                return False
+
+        assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"]
+
+        # Apply scheduling
+
+        post_blocks = sch.get_consumers(conv2d_block)
+        if len(post_blocks) > 0:
+            # Fuse all intermediate post ops into the last op.
+            # This is equivalent to the traverse_inline function used in TE schedules.
+            while True:
+                next_post_blocks = []
+                for post_block in post_blocks:
+                    next_consumers = sch.get_consumers(post_block)
+                    if len(next_consumers) > 0:
+                        sch.compute_inline(post_block)
+                    next_post_blocks += next_consumers
+                if len(next_post_blocks) == 0:
+                    assert len(post_blocks) == 1
+                    outer_block = post_blocks[0]
+                    break
+                post_blocks = next_post_blocks
+        else:
+            outer_block = conv2d_block
+
+        # Move the conv2d mma into the injective post-mma compute block
+        if outer_block != conv2d_block:
+            loops = sch.get_loops(outer_block)
+            # TODO(csullivan): Currently all post-conv2d mma steps are computed
+            # directly after accumulation for one spatial pixel. It may be
+            # desirable to do this with coarser spatial granularity.
+            sch.compute_at(conv2d_block, loops[4])
+
+        def index_map_nchw32c_nchw8h8w32c(n, c, h, w, c32):
+            return [n, c, h // 8, w // 8, h % 8, w % 8, c32]
+
+        # Add cache stages for the input and output activation layout transforms;
+        # note that the weight is already in the correct layout.
+        input_cache = sch.cache_read(conv2d_block, 0, "global")
+        output_cache = sch.cache_write(outer_block, 0, "global")
+        # Transform the layout of the input
+        sch.transform_layout(
+            conv2d_block, ("read", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0
+        )
+        # Transform the layout of the int32 accumulator
+        sch.transform_layout(
+            conv2d_block, ("write", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0
+        )
+        # Transform the layout of the output
+        sch.transform_layout(
+            outer_block, ("write", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0
+        )
+        return True
+
+    return schedule_fn
+
+
+def tune_packed_8x8x32_template(mod, params, hexagon_launcher):
+    def schedule_rule_conv2d_packed_8x8x32(sch: Schedule, conv2d_block: BlockRV):
+        _schedule_packed_8x8x32_conv2d(do_tune=True)(sch, conv2d_block)
+        return [sch]
+
+    register_func("meta_schedule.conv2d_NCHWc_int8", schedule_rule_conv2d_packed_8x8x32)
+
+    def schedule_conv2d_for_tune(sch: Schedule):
+        _schedule_packed_8x8x32_conv2d(do_tune=True)(sch)
+
+    # This line is necessary for link-params to take effect during
+    # task extraction and relay.build(...).
+    mod = mod.with_attr("executor", executor)
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        database = ms.relay_integration.tune_relay(
+            mod=mod,
+            target=target,
+            params=params,
+            work_dir=work_dir,
+            max_trials_global=20000,
+            max_trials_per_task=1,
+            num_trials_per_iter=1,
+            strategy="replay-trace",
+            builder=get_hexagon_local_builder(),
+            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
+            # To apply the MS auto scheduling rules to all blocks while still
+            # using the custom block strategy registered above for blocks
+            # annotated `schedule_rule:meta_schedule.conv2d_NCHWc_int8`, use:
+            # space=ms.space_generator.PostOrderApply(
+            #     f_block_filter=None,
+            #     sch_rules="from-target",
+            #     postprocs=[],
+            #     mutator_probs="from-target",
+            # ),
+            # Constrain the search space to only the single schedule
+            # provided above for all blocks. No auto scheduling
+            # will be possible.
+            space=ms.space_generator.ScheduleFn(
+                schedule_conv2d_for_tune,
+                sch_rules=[],
+                postprocs=[],
+                mutator_probs={},
+            ),
+            # Without this, the same workloads with different constant weights
+            # are treated as distinct tuning tasks.
+            module_equality="ignore-ndarray",
+        )
+        return ms.relay_integration.compile_relay(
+            database=database,
+            mod=mod,
+            target=target,
+            params=params,
+        )
+
+
+@pytest.mark.skip("End-to-end tuning is skipped on CI.")
+@tvm.testing.requires_hexagon
+def test_packed_8x8x32_resnet50(hexagon_launcher):
+    if not os.path.exists(model_json):
+        pytest.skip(msg="Run python export_models.py first.")
+
+    with open(model_json, "r") as fi:
+        mod = tvm.ir.load_json(fi.read())
+
+    with open(model_params, "rb") as fi:
+        params = relay.load_param_dict(fi.read())
+    inp = np.random.randn(1, 3, 224, 224).astype("float32")
+    input_name = "image"
+
+    do_tune = True
+
+    if do_tune:
+        hexagon_lowered = tune_packed_8x8x32_template(mod, params, hexagon_launcher)
+    else:
+        with tvm.transform.PassContext(opt_level=3):
+            hexagon_lowered = relay.build(
+                mod,
+                tvm.target.Target(target, host=target),
+                params=params,
+                executor=executor,
+            )
+
+    with tvm.transform.PassContext(opt_level=3):
+        llvm_lowered = tvm.relay.build(
+            mod,
+            tvm.target.Target(target_llvm, host=target_llvm),
+            params=params,
+        )
+
+    with hexagon_launcher.start_session() as session:
+        graph_mod = session.get_executor_from_factory(hexagon_lowered)
+        graph_mod.set_input(input_name, inp.copy())
+        graph_mod.run()
+        hexagon_output = graph_mod.get_output(0).numpy()
+
+        llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
+        llvm_graph_mod.set_input(input_name, inp.copy())
+        llvm_graph_mod.run()
+        ref_result = llvm_graph_mod.get_output(0).numpy()
+
+        np.testing.assert_allclose(ref_result, hexagon_output, atol=1e-4, rtol=1e-5)
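
A quick illustration (not part of the patch): the index_map_nchw32c_nchw8h8w32c mapping used by the transform_layout calls above can be sanity-checked with plain numpy. The sketch below assumes a logical NCHW-32c tensor whose H and W are multiples of 8, so the pad_value path is never exercised; all names are local to the example.

    import numpy as np

    def index_map_nchw32c_nchw8h8w32c(n, c, h, w, c32):
        return [n, c, h // 8, w // 8, h % 8, w % 8, c32]

    # Toy NCHW-32c tensor: axes (n, c_outer, h, w, c_inner=32).
    n_, c_, h_, w_, c32_ = 1, 2, 16, 16, 32
    src = np.arange(n_ * c_ * h_ * w_ * c32_).reshape(n_, c_, h_, w_, c32_)

    # Packed nchw-8h8w32c buffer: split H and W into 8x8 tiles, giving axes
    # (n, c_outer, h // 8, w // 8, h % 8, w % 8, c_inner).
    packed = src.reshape(n_, c_, h_ // 8, 8, w_ // 8, 8, c32_).transpose(0, 1, 2, 4, 3, 5, 6)

    # Each logical coordinate lands where the index map says it should.
    for idx in [(0, 1, 9, 14, 5), (0, 0, 7, 8, 31)]:
        assert src[idx] == packed[tuple(index_map_nchw32c_nchw8h8w32c(*idx))]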