[Hexagon] Add hexagon user DMA intrins for tensorization #13719

Merged (4 commits) on Jan 17, 2023
54 changes: 54 additions & 0 deletions python/tvm/tir/tensor_intrin/hexagon.py
@@ -20,6 +20,54 @@
from .. import TensorIntrin


def generate_dma_load_intrin(
size: int,
dtype: str,
):
"""Generator of dma_load intrins"""

@T.prim_func
def sync_dma_load_desc(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (size), dtype, offset_factor=1, scope="global")
C = T.match_buffer(c, (size), dtype, offset_factor=1, scope="global.vtcm")
with T.block("root"):
T.reads(A[0:size])
T.writes(C[0:size])
for i in T.serial(size):
with T.block("load"):
vii = T.axis.remap("S", [i])
C[vii] = A[vii]

@T.prim_func
def sync_dma_load_impl(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (size), dtype, offset_factor=1, scope="global")
C = T.match_buffer(c, (size), dtype, offset_factor=1, scope="global.vtcm")
with T.block("root"):
T.reads(A[0:size])
T.writes(C[0:size])
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.dma_copy",
-1, # Use QueueId of -1 to not interfere with async copies.
T.address_of(C[0], dtype="handle"),
T.address_of(A[0], dtype="handle"),
size,
0, # Do not use experimental bypass mode.
dtype="int32",
)
)
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.dma_wait",
-1,
0, # Wait for the sync queue (-1) to have 0 messages.
dtype="int32",
)
)

return sync_dma_load_desc, sync_dma_load_impl


def generate_dot_product_32x4_u8u8i32(mem_scope="global"):
@T.prim_func
def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
@@ -163,3 +211,9 @@ def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> None:

VRMPY_u8i8i32_VTCM_INTRIN = "dot_32x4_u8i8i32_vtcm_vrmpy"
TensorIntrin.register(VRMPY_u8i8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8i8i32("global.vtcm"))

DMA_READ_128_u8 = "dma_read_128_u8"
TensorIntrin.register(DMA_READ_128_u8, *generate_dma_load_intrin(128, "uint8"))

DMA_READ_128_i8 = "dma_read_128_i8"
TensorIntrin.register(DMA_READ_128_i8, *generate_dma_load_intrin(128, "int8"))
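
As a usage sketch (not part of this diff): the registered intrin replaces the innermost loop of a global-to-VTCM copy block via sch.tensorize, mirroring the scheduling added to the bandwidth test below. The kernel name copy_kernel, the 1 MB size, and the int8 dtype here are illustrative assumptions.

import tvm
from tvm.script import tir as T
from tvm.tir.tensor_intrin.hexagon import DMA_READ_128_i8

SIZE = 1024 * 1024  # hypothetical buffer size; must be a multiple of 128


@T.prim_func
def copy_kernel(a: T.handle, c: T.handle) -> None:
    # Simple element-wise copy from DDR ("global") into VTCM ("global.vtcm").
    A = T.match_buffer(a, (SIZE,), "int8", scope="global")
    C = T.match_buffer(c, (SIZE,), "int8", scope="global.vtcm")
    for i in T.serial(SIZE):
        with T.block("copy"):
            vi = T.axis.remap("S", [i])
            C[vi] = A[vi]


sch = tvm.tir.Schedule(copy_kernel)
block = sch.get_block("copy")
(loop,) = sch.get_loops(block)
# Split so the inner loop matches the 128-element intrin, then tensorize it
# into the synchronous dma_copy/dma_wait sequence defined above.
_, inner = sch.split(loop, [None, 128])
sch.tensorize(inner, DMA_READ_128_i8)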
46 changes: 27 additions & 19 deletions tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
@@ -20,7 +20,9 @@
import numpy as np

import tvm
import pytest
from tvm.script import tir as T
from tvm.tir.tensor_intrin.hexagon import DMA_READ_128_i8

from .infrastructure import get_hexagon_target

@@ -30,6 +32,7 @@
"Test bandwidth with buffer size {}MB... \n"
" -Base: {} GBps \n -Vectorized: {} GBps\n"
" -Vectorized and Parallelized: {} GBps\n"
" -Sync DMA: {} GBps\n"
" -Single DMA Copy: {} GBps\n"
)

@@ -103,13 +106,12 @@ def evaluate(hexagon_session, sch, size):
a_vtcm, device=hexagon_session.device, mem_scope="global.vtcm"
)

-    # These are reduced for CI but number=100 and repeat=10 does a good job of removing noise.
-    number = 1
-    repeat = 1
+    if tvm.testing.utils.IS_IN_CI:
+        # Run with reduced number and repeat for CI
+        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=1, repeat=1)
+    else:
+        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=10, repeat=10)

-    timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
-    )
runtime = timer(a_hexagon, a_vtcm_hexagon)

gbps = round((size / 2**30) / runtime.mean, 4)
@@ -123,18 +125,11 @@ class TestMatMulVec:

# Removed most of these to speedup CI.
    size = tvm.testing.parameter(
-        # 10 * KB,
-        # 20 * KB,
-        # 40 * KB,
-        # 80 * KB,
-        # 160 * KB,
-        # 320 * KB,
-        640 * KB,
-        # MB,
-        # 2 * MB,
-        # 3 * MB,
-        # 4 * MB,
-        # 8 * MB,  # Only works on 8gen1 HDKs
+        128,
+        KB,
+        10 * KB,
+        100 * KB,
+        MB,
    )

Contributor:
Did you mean to uncomment this? Makes the test run longer in CI.

Contributor (Author):
Added a check if running in CI.

outer_split = tvm.testing.parameter(4)
@@ -144,6 +139,10 @@ class TestMatMulVec:
@tvm.testing.requires_hexagon
def test_bandwidth(self, hexagon_session, size, outer_split, unroll_split, vector_split):
"""Test bandwidth."""

if tvm.testing.utils.IS_IN_CI and (size > 128):
pytest.skip("Skipping test since it takes too long in CI.")

# Run the base memcopy operator.
sch = tvm.tir.Schedule(memcopy_operator(size))
base_gpbs = evaluate(hexagon_session, sch, size)
@@ -169,14 +168,23 @@ def test_bandwidth(self, hexagon_session, size, outer_split, unroll_split, vector_split):
sch.parallel(vbo_a)
parallel_gbps = evaluate(hexagon_session, sch, size)

        # Run with the inner copy loop tensorized into a synchronous DMA transfer.
sch = tvm.tir.Schedule(memcopy_operator(size))
block = sch.get_block("A_global.vtcm")
loops = sch.get_loops(block)
_, inner = sch.split(loops[0], [None, 128])
sch.tensorize(inner, DMA_READ_128_i8)
# print(sch.mod.script())
sync_dma_gbps = evaluate(hexagon_session, sch, size)

# Run using a single dma copy to transfer the data.
sch = tvm.tir.Schedule(single_dma_operator(size))
single_dma_gbps = evaluate(hexagon_session, sch, size)

mbs = round(size / MB, 2)
print(
TEST_OUTPUT_TEMPLATE.format(
-                mbs, base_gpbs, vectorize_gbps, parallel_gbps, single_dma_gbps
+                mbs, base_gpbs, vectorize_gbps, parallel_gbps, sync_dma_gbps, single_dma_gbps
)
)

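For reference, the GBps figures printed by the template above come from the calculation in evaluate(): bytes copied divided by 2**30, divided by the mean runtime. A small illustration with made-up numbers, not measured results:

size = 1024 * 1024        # bytes copied per run (1 MB, hypothetical)
mean_runtime_s = 2.5e-4   # hypothetical mean seconds from time_evaluator

# Same formula as evaluate(): (bytes / GiB) / seconds.
gbps = round((size / 2**30) / mean_runtime_s, 4)
print(f"{gbps} GBps")     # roughly 3.9 GBps for these illustrative numbers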