Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hexagon] Add E2E test demonstrating how to apply blocked layout schedule to conv2d via metaschedule #13180

Merged
merged 8 commits into from
Oct 31, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@
import numpy as np
import pytest
import tempfile
from typing import Optional

import tvm
import tvm.testing
from tvm import relay
from tvm._ffi import register_func
from tvm.meta_schedule import postproc, schedule_rule
from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN
from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner
from tvm import meta_schedule as ms
from tvm.tir.schedule import BlockRV, Schedule
from ..infrastructure import get_hexagon_target


Expand Down Expand Up @@ -184,3 +187,174 @@ def test_resnet50(hexagon_launcher):
hexagon_lowered.get_graph_json(), hexagon_lowered.lib
)
print(debug_ex.profile(input_name=inp.copy()))


def _schedule_packed_8x8x32_conv2d(do_tune: bool):
"""Manually schedule a conv2d block, created from TE compute op via CreatePrimFunc,
using 8x8x32 packed layout.
"""

def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool:
if conv2d_block == None:
try:
conv2d_block = sch.get_block("conv2d_NCHWc_int8")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My PR #13206 has has_block(sch, block_name) utility, which can remove this hack. I'll update this code after it is merged.

except:
return False

assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"]

# Apply scheduling

post_blocks = sch.get_consumers(conv2d_block)
if len(post_blocks) > 0:
# Fuse all intermediate post ops into the last op.
# This is equivalent to the traverse_inline function used in TE schedules.
while True:
next_post_blocks = []
for post_block in post_blocks:
next_consumers = sch.get_consumers(post_block)
if len(next_consumers) > 0:
sch.compute_inline(post_block)
next_post_blocks += next_consumers
if len(next_post_blocks) == 0:
assert len(post_blocks) == 1
outer_block = post_blocks[0]
break
post_blocks = next_post_blocks
else:
outer_block = conv2d_block

# Move the conv2d mma into the injective post mma compute block
if outer_block != conv2d_block:
loops = sch.get_loops(outer_block)
# TODO(csullivan): Currently does all post conv2d mma steps
# directly after accumulation for one spatial pixel. May
# be desirable to do this with coarser spatial granularity
sch.compute_at(conv2d_block, loops[4])

def index_map_nchw32c_nchw8h8w32c(n, c, h, w, c32):
return [n, c, h // 8, w // 8, h % 8, w % 8, c32]

# Add cache for input and output activation layout transform,
# note that weight is already in correct layout
input_cache = sch.cache_read(conv2d_block, 0, "global")
output_cache = sch.cache_write(outer_block, 0, "global")
# Transform the layout of the input
sch.transform_layout(
conv2d_block, ("read", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0
)
# Transform the layout of the int32 accumulator
sch.transform_layout(
conv2d_block, ("write", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0
)
# Transform the layout of the output
sch.transform_layout(
outer_block, ("write", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0
)
return True

return schedule_fn


def tune_packed_8x8x32_template(mod, params, hexagon_launcher):
def schedule_rule_conv2d_packed_8x8x32(sch: Schedule, conv2d_block: BlockRV):
_schedule_packed_8x8x32_conv2d(do_tune=True)(sch, conv2d_block)
return [sch]

register_func("meta_schedule.conv2d_NCHWc_int8", schedule_rule_conv2d_packed_8x8x32)

def schedule_conv2d_for_tune(sch: Schedule):
_schedule_packed_8x8x32_conv2d(do_tune=True)(sch)

# This line is necessary for link-params to take effect during
# task extraction and relay.build(...).
mod = mod.with_attr("executor", executor)

with tempfile.TemporaryDirectory() as work_dir:
database = ms.relay_integration.tune_relay(
mod=mod,
target=target,
params=params,
work_dir=work_dir,
max_trials_global=20000,
max_trials_per_task=1,
num_trials_per_iter=1,
strategy="replay-trace",
builder=get_hexagon_local_builder(),
runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
# Apply MS auto scheduling rules for all blocks, but utilize
# the custom block scheduling strategy registered above for
# blocks annotated as `schedule_rule:meta_schedule.conv2d_NCHWc_int8`
# space=ms.space_generator.PostOrderApply(
# f_block_filter=None,
# sch_rules="from-target",
# postprocs=[],
# mutator_probs="from-target",
# ),
# Constrain search space to only be the single
# schedule provided for all blocks. No auto
# scheduling will be possible.
space=ms.space_generator.ScheduleFn(
schedule_conv2d_for_tune,
sch_rules=[],
postprocs=[],
mutator_probs={},
),
# Without this, the same workloads with different constant weights
# are treated as distinct tuning tasks.
module_equality="ignore-ndarray",
)
return ms.relay_integration.compile_relay(
database=database,
mod=mod,
target=target,
params=params,
)


@pytest.mark.skip("End-to-end tuning is skipped on CI.")
@tvm.testing.requires_hexagon
def test_packed_8x8x32_resnet50(hexagon_launcher):
if not os.path.exists(model_json):
pytest.skip(msg="Run python export_models.py first.")

with open(model_json, "r") as fi:
mod = tvm.ir.load_json(fi.read())

with open(model_params, "rb") as fi:
params = relay.load_param_dict(fi.read())
inp = np.random.randn(1, 3, 224, 224).astype("float32")
input_name = "image"

do_tune = True

if do_tune:
hexagon_lowered = tune_packed_8x8x32_template(mod, params, hexagon_launcher)
else:
with tvm.transform.PassContext(opt_level=3):
hexagon_lowered = relay.build(
mod,
tvm.target.Target(target, host=target),
params=params,
executor=executor,
)

with tvm.transform.PassContext(opt_level=3):
llvm_lowered = tvm.relay.build(
mod,
tvm.target.Target(target_llvm, host=target_llvm),
params=params,
)

with hexagon_launcher.start_session() as session:
graph_mod = session.get_executor_from_factory(hexagon_lowered)
graph_mod.set_input(input_name, inp.copy())
graph_mod.run()
hexagon_output = graph_mod.get_output(0).numpy()

llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
llvm_graph_mod.set_input(input_name, inp.copy())
llvm_graph_mod.run()
ref_result = llvm_graph_mod.get_output(0).numpy()

np.testing.assert_allclose(ref_result, hexagon_output, atol=1e-4, rtol=1e-5)