Skip to content

Commit

Permalink
[HEXAGON][TOPI]Slice Op Argmax uint8 (#12472)
Browse files Browse the repository at this point in the history
  • Loading branch information
arangasa authored Sep 1, 2022
1 parent 54786bb commit 50dad0d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
7 changes: 7 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,11 @@ def argmax_schedule(argmax_func, in_layout_str, out_layout_str):
argmax_func, fp16_layout_transform, int32_layout_transform
)
return tir_s
if (in_layout_str == "nhwc-8h8w32c-2d") and (out_layout_str == "nhw-32h16w-2d"):
int8_layout_transform = get_layout_transform_fn(in_layout_str)
int32_layout_transform = get_layout_transform_fn(out_layout_str)
tir_s = argmax_stir_schedule_nhwc(
argmax_func, int8_layout_transform, int32_layout_transform
)
return tir_s
raise RuntimeError(f"Unexpected input_layout, output_layout '{in_layout_str, out_layout_str}'")
14 changes: 8 additions & 6 deletions tests/python/contrib/test_hexagon/topi/test_argmax_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
""" Tests for Hexagon slice argmax op """
import pytest
import numpy as np

import tvm
Expand All @@ -33,15 +32,18 @@ class TestArgMaxSlice:
input_shape,
input_layout,
output_layout,
dtype,
in_axis,
in_axis_sep,
out_axis_sep,
) = tvm.testing.parameters(
((1, 64, 64, 32), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", [3], [4], [3]),
((3, 32, 16, 32), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", [3], [4], [3]),
((1, 32, 32, 64), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", [3], [4], [3]),
((1, 64, 64, 32), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", "float16", [3], [4], [3]),
((3, 32, 16, 32), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", "float16", [3], [4], [3]),
((1, 32, 32, 64), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", "float16", [3], [4], [3]),
((1, 64, 64, 32), "nhwc-8h8w32c-2d", "nhw-32h16w-2d", "int8", [3], [4], [3]),
((3, 32, 16, 32), "nhwc-8h8w32c-2d", "nhw-32h16w-2d", "int8", [3], [4], [3]),
((1, 32, 32, 64), "nhwc-8h8w32c-2d", "nhw-32h16w-2d", "int8", [3], [4], [3]),
)
dtype = tvm.testing.parameter("float16")
working_scope = tvm.testing.parameter("global.vtcm")

@tvm.testing.fixture
Expand Down Expand Up @@ -96,7 +98,7 @@ def test_argmax_slice(
axis_separators=out_axis_sep,
mem_scope=working_scope,
)
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}):
with tvm.transform.PassContext(opt_level=3):
tir_irm = tvm.lower(tir_s.mod, [argmax_input, output], name="argmax")
runtime_module = tvm.build(
tir_irm, [argmax_input, output], target=target, name="argmax"
Expand Down

0 comments on commit 50dad0d

Please sign in to comment.