Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[Runtime] Change default alignment to 64 bytes. (apache#12586)
Browse files Browse the repository at this point in the history
* Change default alignment to 64 bytes.

* Run dlpack test a few times.

* Update alignment in tests.

* Revert mma alignment change.

* Change default printing of buffer.

* Change crt runtime default allocation.
  • Loading branch information
Josh Fromm authored and xinetzone committed Nov 25, 2022
1 parent e1ba316 commit 02682e4
Show file tree
Hide file tree
Showing 24 changed files with 303 additions and 299 deletions.
4 changes: 2 additions & 2 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ enum DeviceAttrKind : int {
};

/*! \brief Number of bytes each allocation must align to */
constexpr int kAllocAlignment = 128;
constexpr int kAllocAlignment = 64;

/*! \brief Number of bytes each allocation must align to in temporary allocation */
constexpr int kTempAllocaAlignment = 128;
constexpr int kTempAllocaAlignment = 64;

/*! \brief Maximum size that can be allocated on stack */
constexpr int kMaxStackAlloca = 1024;
Expand Down
54 changes: 26 additions & 28 deletions python/tvm/tir/tensor_intrin/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
shared_handle,
shmem_shape,
dtype,
align=128,
align=64,
offset_factor=16,
scope=shared_scope,
)
warp = T.match_buffer(
warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
warp_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=16, scope="warp"
)

with T.block("root"):
Expand All @@ -149,13 +149,13 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
shared_handle,
shmem_shape,
dtype,
align=128,
align=64,
offset_factor=16,
scope=shared_scope,
strides=[s0, s1],
)
warp = T.match_buffer(
warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
warp_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=16, scope="warp"
)

with T.block("root"):
Expand Down Expand Up @@ -222,13 +222,13 @@ def maybe_swap(i, j):
@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(
a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
a, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
)
B = T.match_buffer(
b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
b, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
)
C = T.match_buffer(
c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
c, (WARP_SIZE, local_size_out), out_dtype, align=64, offset_factor=16, scope="warp"
)

with T.block("root"):
Expand Down Expand Up @@ -262,13 +262,13 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(
a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
a, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
)
B = T.match_buffer(
b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
b, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
)
C = T.match_buffer(
c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
c, (WARP_SIZE, local_size_out), out_dtype, align=64, offset_factor=16, scope="warp"
)

with T.block("root"):
Expand Down Expand Up @@ -510,11 +510,9 @@ def get_wmma_load_intrin(

@T.prim_func
def wmma_load_desc(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(
a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=shared_scope
)
A = T.match_buffer(a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=shared_scope)
C = T.match_buffer(
c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=wmma_fragment_scope
c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=wmma_fragment_scope
)
with T.block("root"):
T.reads(A[0:m_dim, 0:n_dim])
Expand All @@ -532,13 +530,13 @@ def wmma_load_impl(a: T.handle, c: T.handle) -> None:
a,
(m_dim, n_dim),
dtype,
align=128,
align=64,
offset_factor=16,
scope=shared_scope,
strides=[s1, s0],
)
C = T.match_buffer(
c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=wmma_fragment_scope
c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=wmma_fragment_scope
)
with T.block("root"):
T.reads(A[0:m_dim, 0:n_dim])
Expand Down Expand Up @@ -569,7 +567,7 @@ def get_wmma_fill_intrin(
@T.prim_func
def wmma_fill_desc(c: T.handle) -> None:
C = T.match_buffer(
c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
)
with T.block("root"):
T.reads()
Expand All @@ -582,7 +580,7 @@ def wmma_fill_desc(c: T.handle) -> None:
@T.prim_func
def wmma_fill_impl(c: T.handle) -> None:
C = T.match_buffer(
c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
)
with T.block("root"):
T.reads()
Expand Down Expand Up @@ -610,9 +608,9 @@ def get_wmma_store_intrin(
@T.prim_func
def wmma_store_desc(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(
a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
)
C = T.match_buffer(c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=scope)
C = T.match_buffer(c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope)
with T.block("root"):
T.reads(A[0:m_dim, 0:n_dim])
T.writes(C[0:m_dim, 0:n_dim])
Expand All @@ -626,10 +624,10 @@ def wmma_store_impl(a: T.handle, c: T.handle) -> None:
s1 = T.var("int32")
s0 = T.var("int32")
A = T.match_buffer(
a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
)
C = T.match_buffer(
c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=scope, strides=[s1, s0]
c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope, strides=[s1, s0]
)
with T.block("root"):
T.reads(A[0:m_dim, 0:n_dim])
Expand Down Expand Up @@ -671,18 +669,18 @@ def maybe_swap(i, j):
@T.prim_func
def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(
a, (m_dim, k_dim), in_dtype, align=128, offset_factor=16, scope="wmma.matrix_a"
a, (m_dim, k_dim), in_dtype, align=64, offset_factor=16, scope="wmma.matrix_a"
)
B = T.match_buffer(
b,
maybe_swap(k_dim, n_dim),
in_dtype,
align=128,
align=64,
offset_factor=16,
scope="wmma.matrix_b",
)
C = T.match_buffer(
c, (m_dim, n_dim), out_dtype, align=128, offset_factor=16, scope="wmma.accumulator"
c, (m_dim, n_dim), out_dtype, align=64, offset_factor=16, scope="wmma.accumulator"
)

with T.block("root"):
Expand All @@ -699,18 +697,18 @@ def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
@T.prim_func
def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(
a, (m_dim, k_dim), in_dtype, align=128, offset_factor=16, scope="wmma.matrix_a"
a, (m_dim, k_dim), in_dtype, align=64, offset_factor=16, scope="wmma.matrix_a"
)
B = T.match_buffer(
b,
maybe_swap(k_dim, n_dim),
in_dtype,
align=128,
align=64,
offset_factor=16,
scope="wmma.matrix_b",
)
C = T.match_buffer(
c, (m_dim, n_dim), out_dtype, align=128, offset_factor=16, scope="wmma.accumulator"
c, (m_dim, n_dim), out_dtype, align=64, offset_factor=16, scope="wmma.accumulator"
)

with T.block("root"):
Expand Down
2 changes: 1 addition & 1 deletion src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) {
if (GetRef<Buffer>(buf).scope() != "global") {
doc << ", scope=" << Doc::StrLiteral(GetRef<Buffer>(buf).scope());
}
if (buf->data_alignment != 128) {
if (buf->data_alignment != runtime::kAllocAlignment) {
doc << ", align=" << buf->data_alignment;
}
if (buf->offset_factor != 1) {
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ int TVMDeviceAllocDataSpaceWithScope(DLDevice dev, int ndim, const int64_t* shap
}
nbytes *= (dtype.bits * dtype.lanes + 7) / 8;

int kAllocAlignment = 128;
int kAllocAlignment = 64;
size_t align = (dtype.bits / 8) * dtype.lanes;
if (align < kAllocAlignment) align = kAllocAlignment;
return TVMDeviceAllocDataSpace(dev, nbytes, align, dtype, out_data);
Expand Down
10 changes: 8 additions & 2 deletions tests/python/contrib/test_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm.contrib.dlpack import to_pytorch_func


def test():
def verify_torch_dlpack():
a = np.random.randn(1337)
tvm_a = tvm.nd.array(a)
np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).numpy(), a)
Expand Down Expand Up @@ -63,5 +63,11 @@ def test():
pass


def test_torch_dlpack():
# Run dlpack interoperability test a few times to make sure it's stable.
for i in range(5):
verify_torch_dlpack()


if __name__ == "__main__":
test()
test_torch_dlpack()
Loading

0 comments on commit 02682e4

Please sign in to comment.