Skip to content

Commit

Permalink
Update tests to use util functions from infrastructure.py
Browse files Browse the repository at this point in the history
  • Loading branch information
abhikran-quic committed Jun 16, 2022
1 parent d7eed47 commit 9338ce3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 22 deletions.
6 changes: 6 additions & 0 deletions tests/python/contrib/test_hexagon/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,12 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
n, h, w, c = arr_np.shape
assert h == 1 and w == 1, "The size of h and w must be 1"
return arr_np.reshape([n, 1, 1, c // 1024, 1024])
if new_layout == "nc-1d":
N, C = arr_np.shape
return arr_np.reshape([N, C // 1024, 1024])
if new_layout == "nhwc-1024c-1d":
N, H, W, C = arr_np.shape
return arr_np.reshape([N, H, W, C // 1024, 1024])

raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
raise RuntimeError(f"Unexpected current_layout '{current_layout}'")
Original file line number Diff line number Diff line change
Expand Up @@ -25,33 +25,17 @@
from tvm.contrib.hexagon.build import HexagonLauncher
from tvm.topi import testing

from .infrastructure import allocate_hexagon_array
from ..infrastructure import allocate_hexagon_array, transform_numpy


def n11c_1024c_1d(n, h, w, c):
def nhwc_1024c_1d(n, h, w, c):
return [n, h, w, c // 1024, tvm.te.AXIS_SEPARATOR, c % 1024]


def nc_1024_1d(n, c):
return [n, c // 1024, tvm.te.AXIS_SEPARATOR, c % 1024]


def transform_numpy(arr_np, layout):
if layout == "nhwc":
return arr_np
elif layout == "n11c-1024c-1d":
N, H, W, C = arr_np.shape
return arr_np.reshape([N, H, W, C // 1024, 1024])
elif layout == "nc-1d":
N, C = arr_np.shape
return arr_np.reshape([N, C // 1024, 1024])


@tvm.testing.fixture
def transformed_expected_output_np(expected_output_np, output_layout):
return transform_numpy(expected_output_np, output_layout)


class BaseTestBatchFlatten:
input_shape = tvm.testing.parameter(
(1, 1, 1, 2048),
Expand All @@ -60,7 +44,7 @@ class BaseTestBatchFlatten:
(2, 4, 8, 1024),
(2, 3, 5, 2048),
)
input_layout, input_axis_sep = tvm.testing.parameters(("n11c-1024c-1d", [4]))
input_layout, input_axis_sep = tvm.testing.parameters(("nhwc-1024c-1d", [4]))
output_layout, output_axis_sep = tvm.testing.parameters(("nc-1d", [2]))
data_type = tvm.testing.parameter("float16")

Expand Down Expand Up @@ -90,7 +74,7 @@ def test_batch_flatten(
D,
A,
nc_1024_1d,
n11c_1024c_1d,
nhwc_1024c_1d,
)
func_name = "batch_flatten"
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}):
Expand All @@ -101,8 +85,8 @@ def test_batch_flatten(
a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
ref = np.reshape(a_numpy, output_shape)

input_np_transformed = transform_numpy(a_numpy, input_layout)
ref_np_transformed = transform_numpy(ref, output_layout)
input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)

a_tvm = allocate_hexagon_array(
hexagon_session.device,
Expand Down

0 comments on commit 9338ce3

Please sign in to comment.