Skip to content

Commit

Permalink
Update tests to use util functions
Browse files Browse the repository at this point in the history
  • Loading branch information
abhikran-quic committed Jun 21, 2022
1 parent d7eed47 commit 89479ba
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 32 deletions.
5 changes: 3 additions & 2 deletions python/tvm/topi/hexagon/slice_ops/batch_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import typing

from tvm import te, tir, topi
from ..utils import get_layout_transform_fn


def batch_flatten_compute(inp: te.Tensor) -> te.Tensor:
Expand Down Expand Up @@ -66,8 +67,8 @@ def batch_flatten_stir_schedule(
sch = tir.Schedule(batch_flatten_func, debug_mask="all")
compute = sch.get_block("compute")

sch.transform_layout(compute, inp.name, in_layout)
sch.transform_layout(compute, out.name, out_layout)
sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
i, j = sch.get_loops(compute)
jout, channel = sch.split(j, [None, inp.shape[3]])
height, width = sch.split(jout, [inp.shape[1], inp.shape[2]])
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/topi/hexagon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ def nhwc_8h2w32c2w_1d(n, h, w, c):
return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2]


def nhwc_1024c_1d(n, h, w, c):
    """Index map for the nhwc-1024c 1d layout.

    The channel axis is split into 1024-element chunks; the axis separator
    before the innermost chunk marks the 1d physical-buffer boundary.
    """
    c_outer = c // 1024
    c_inner = c % 1024
    return [n, h, w, c_outer, te.AXIS_SEPARATOR, c_inner]


def nc_1024_1d(n, c):
    """Index map for the nc_1024 1d layout.

    Splits the channel axis into 1024-wide chunks, separating the innermost
    chunk with an axis separator for the 1d physical layout.
    """
    c_outer = c // 1024
    c_inner = c % 1024
    return [n, c_outer, te.AXIS_SEPARATOR, c_inner]


def get_layout_transform_fn(layout):
"""Return index map function as per the layout string"""
if layout == "nhwc-8h2w32c2w-2d":
Expand All @@ -49,4 +59,8 @@ def get_layout_transform_fn(layout):
return n11c_1024c_2d
if layout == "n11c-1024c-1d":
return n11c_1024c_1d
if layout == "nhwc-1024c-1d":
return nhwc_1024c_1d
if layout == "nc-1d":
return nc_1024_1d
raise RuntimeError(f"Unexpected layout '{layout}'")
6 changes: 6 additions & 0 deletions tests/python/contrib/test_hexagon/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,12 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
n, h, w, c = arr_np.shape
assert h == 1 and w == 1, "The size of h and w must be 1"
return arr_np.reshape([n, 1, 1, c // 1024, 1024])
if new_layout == "nc-1d":
N, C = arr_np.shape
return arr_np.reshape([N, C // 1024, 1024])
if new_layout == "nhwc-1024c-1d":
N, H, W, C = arr_np.shape
return arr_np.reshape([N, H, W, C // 1024, 1024])

raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
raise RuntimeError(f"Unexpected current_layout '{current_layout}'")
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,7 @@
from tvm.contrib.hexagon.build import HexagonLauncher
from tvm.topi import testing

from .infrastructure import allocate_hexagon_array


def n11c_1024c_1d(n, h, w, c):
    """Map logical (n, h, w, c) indices to the n11c-1024c-1d layout."""
    c_outer = c // 1024
    c_inner = c % 1024
    return [n, h, w, c_outer, tvm.te.AXIS_SEPARATOR, c_inner]


def nc_1024_1d(n, c):
    """Map logical (n, c) indices to the nc-1024 1d layout."""
    c_outer = c // 1024
    c_inner = c % 1024
    return [n, c_outer, tvm.te.AXIS_SEPARATOR, c_inner]


def transform_numpy(arr_np, layout):
    """Reshape a numpy array into the physical form implied by *layout*.

    Parameters
    ----------
    arr_np : numpy.ndarray
        Input array in logical ("nhwc" or "nc") layout.
    layout : str
        Target layout string: "nhwc" (identity), "n11c-1024c-1d", or "nc-1d".

    Returns
    -------
    numpy.ndarray
        The reshaped array (the original array for "nhwc").

    Raises
    ------
    RuntimeError
        For an unsupported layout string. Previously the function fell
        through and silently returned None, which surfaced later as a
        confusing attribute error at the call site.
    """
    if layout == "nhwc":
        return arr_np
    if layout == "n11c-1024c-1d":
        n, h, w, c = arr_np.shape
        # Split channels into 1024-wide chunks for the 1d physical layout.
        return arr_np.reshape([n, h, w, c // 1024, 1024])
    if layout == "nc-1d":
        n, c = arr_np.shape
        return arr_np.reshape([n, c // 1024, 1024])
    raise RuntimeError(f"Unexpected layout '{layout}'")


@tvm.testing.fixture
def transformed_expected_output_np(expected_output_np, output_layout):
    """Fixture: expected output rearranged into the requested output layout."""
    transformed = transform_numpy(expected_output_np, output_layout)
    return transformed
from ..infrastructure import allocate_hexagon_array, transform_numpy


class BaseTestBatchFlatten:
Expand All @@ -60,7 +36,7 @@ class BaseTestBatchFlatten:
(2, 4, 8, 1024),
(2, 3, 5, 2048),
)
input_layout, input_axis_sep = tvm.testing.parameters(("n11c-1024c-1d", [4]))
input_layout, input_axis_sep = tvm.testing.parameters(("nhwc-1024c-1d", [4]))
output_layout, output_axis_sep = tvm.testing.parameters(("nc-1d", [2]))
data_type = tvm.testing.parameter("float16")

Expand Down Expand Up @@ -89,8 +65,8 @@ def test_batch_flatten(
tir_s = sl.batch_flatten_stir_schedule(
D,
A,
nc_1024_1d,
n11c_1024c_1d,
output_layout,
input_layout,
)
func_name = "batch_flatten"
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}):
Expand All @@ -101,8 +77,8 @@ def test_batch_flatten(
a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
ref = np.reshape(a_numpy, output_shape)

input_np_transformed = transform_numpy(a_numpy, input_layout)
ref_np_transformed = transform_numpy(ref, output_layout)
input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)

a_tvm = allocate_hexagon_array(
hexagon_session.device,
Expand Down

0 comments on commit 89479ba

Please sign in to comment.