diff --git a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py index 58022290f7ba..07230296412e 100644 --- a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py +++ b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py @@ -19,6 +19,7 @@ import typing from tvm import te, tir, topi +from ..utils import get_layout_transform_fn def batch_flatten_compute(inp: te.Tensor) -> te.Tensor: @@ -66,8 +67,8 @@ def batch_flatten_stir_schedule( sch = tir.Schedule(batch_flatten_func, debug_mask="all") compute = sch.get_block("compute") - sch.transform_layout(compute, inp.name, in_layout) - sch.transform_layout(compute, out.name, out_layout) + sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout)) + sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout)) i, j = sch.get_loops(compute) jout, channel = sch.split(j, [None, inp.shape[3]]) height, width = sch.split(jout, [inp.shape[1], inp.shape[2]]) diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index af6e3de9c350..1ceeb186ab87 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -39,6 +39,16 @@ def nhwc_8h2w32c2w_1d(n, h, w, c): return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2] +def nhwc_1024c_1d(n, h, w, c): + """Return index map for nhwc-1024c 1d layout""" + return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024] + + +def nc_1024_1d(n, c): + """Return index map for nc-1024 1d layout""" + return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024] + + def get_layout_transform_fn(layout): """Return index map function as per the layout string""" if layout == "nhwc-8h2w32c2w-2d": @@ -49,4 +59,8 @@ def get_layout_transform_fn(layout): return n11c_1024c_2d if layout == "n11c-1024c-1d": return n11c_1024c_1d + if layout == "nhwc-1024c-1d": + return nhwc_1024c_1d + if layout == "nc-1d": + return nc_1024_1d raise RuntimeError(f"Unexpected layout '{layout}'") diff --git 
a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index 57a9dff8b424..5d031871509b 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -245,6 +245,12 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str): n, h, w, c = arr_np.shape assert h == 1 and w == 1, "The size of h and w must be 1" return arr_np.reshape([n, 1, 1, c // 1024, 1024]) + if new_layout == "nc-1d": + N, C = arr_np.shape + return arr_np.reshape([N, C // 1024, 1024]) + if new_layout == "nhwc-1024c-1d": + N, H, W, C = arr_np.shape + return arr_np.reshape([N, H, W, C // 1024, 1024]) raise RuntimeError(f"Unexpected new_layout '{new_layout}'") raise RuntimeError(f"Unexpected current_layout '{current_layout}'") diff --git a/tests/python/contrib/test_hexagon/test_batch_flatten.py b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py similarity index 76% rename from tests/python/contrib/test_hexagon/test_batch_flatten.py rename to tests/python/contrib/test_hexagon/topi/test_batch_flatten.py index d1e7c8143caa..cd7a9ec51591 100644 --- a/tests/python/contrib/test_hexagon/test_batch_flatten.py +++ b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py @@ -25,31 +25,7 @@ from tvm.contrib.hexagon.build import HexagonLauncher from tvm.topi import testing -from .infrastructure import allocate_hexagon_array - - -def n11c_1024c_1d(n, h, w, c): - return [n, h, w, c // 1024, tvm.te.AXIS_SEPARATOR, c % 1024] - - -def nc_1024_1d(n, c): - return [n, c // 1024, tvm.te.AXIS_SEPARATOR, c % 1024] - - -def transform_numpy(arr_np, layout): - if layout == "nhwc": - return arr_np - elif layout == "n11c-1024c-1d": - N, H, W, C = arr_np.shape - return arr_np.reshape([N, H, W, C // 1024, 1024]) - elif layout == "nc-1d": - N, C = arr_np.shape - return arr_np.reshape([N, C // 1024, 1024]) - - -@tvm.testing.fixture -def transformed_expected_output_np(expected_output_np, 
output_layout): - return transform_numpy(expected_output_np, output_layout) +from ..infrastructure import allocate_hexagon_array, transform_numpy class BaseTestBatchFlatten: @@ -60,7 +36,7 @@ class BaseTestBatchFlatten: (2, 4, 8, 1024), (2, 3, 5, 2048), ) - input_layout, input_axis_sep = tvm.testing.parameters(("n11c-1024c-1d", [4])) + input_layout, input_axis_sep = tvm.testing.parameters(("nhwc-1024c-1d", [4])) output_layout, output_axis_sep = tvm.testing.parameters(("nc-1d", [2])) data_type = tvm.testing.parameter("float16") @@ -89,8 +65,8 @@ def test_batch_flatten( tir_s = sl.batch_flatten_stir_schedule( D, A, - nc_1024_1d, - n11c_1024c_1d, + output_layout, + input_layout, ) func_name = "batch_flatten" with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}): @@ -101,8 +77,8 @@ def test_batch_flatten( a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type) ref = np.reshape(a_numpy, output_shape) - input_np_transformed = transform_numpy(a_numpy, input_layout) - ref_np_transformed = transform_numpy(ref, output_layout) + input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout) + ref_np_transformed = transform_numpy(ref, "nhwc", output_layout) a_tvm = allocate_hexagon_array( hexagon_session.device,