[Runtime] Change default alignment to 64 bytes. #12586

Merged: 6 commits, Aug 25, 2022

4 changes: 2 additions & 2 deletions include/tvm/runtime/device_api.h
@@ -52,10 +52,10 @@ enum DeviceAttrKind : int {
};

/*! \brief Number of bytes each allocation must align to */
-constexpr int kAllocAlignment = 128;
+constexpr int kAllocAlignment = 64;

/*! \brief Number of bytes each allocation must align to in temporary allocation */
-constexpr int kTempAllocaAlignment = 128;
+constexpr int kTempAllocaAlignment = 64;

/*! \brief Maximum size that can be allocated on stack */
constexpr int kMaxStackAlloca = 1024;
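
Halving the default from 128 to 64 bytes also halves the worst-case padding added per allocation. As a quick illustration, here is a TVM-independent Python sketch of the round-up arithmetic a 64-byte alignment implies (the round_up helper is ours, not a TVM API):

def round_up(nbytes: int, align: int = 64) -> int:
    """Round an allocation size up to the next multiple of `align`."""
    return (nbytes + align - 1) // align * align

assert round_up(100) == 128  # a 100-byte request pads to 128 under 64-byte alignment
assert round_up(130) == 192  # the old 128-byte default would have padded this to 256
assert round_up(64) == 64    # already a multiple of 64: no padding
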
54 changes: 26 additions & 28 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -120,12 +120,12 @@ def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
            shared_handle,
            shmem_shape,
            dtype,
-            align=128,
+            align=64,
            offset_factor=16,
            scope=shared_scope,
        )
        warp = T.match_buffer(
-            warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
+            warp_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=16, scope="warp"
        )

        with T.block("root"):
@@ -149,13 +149,13 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
            shared_handle,
            shmem_shape,
            dtype,
-            align=128,
+            align=64,
            offset_factor=16,
            scope=shared_scope,
            strides=[s0, s1],
        )
        warp = T.match_buffer(
-            warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
+            warp_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=16, scope="warp"
        )

        with T.block("root"):
@@ -222,13 +222,13 @@ def maybe_swap(i, j):
    @T.prim_func
    def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(
-            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+            a, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
        )
        B = T.match_buffer(
-            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+            b, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
        )
        C = T.match_buffer(
-            c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
+            c, (WARP_SIZE, local_size_out), out_dtype, align=64, offset_factor=16, scope="warp"
        )

        with T.block("root"):
@@ -262,13 +262,13 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    @T.prim_func
    def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(
-            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+            a, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
        )
        B = T.match_buffer(
-            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+            b, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
        )
        C = T.match_buffer(
-            c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
+            c, (WARP_SIZE, local_size_out), out_dtype, align=64, offset_factor=16, scope="warp"
        )

        with T.block("root"):
@@ -510,11 +510,9 @@ def get_wmma_load_intrin(

    @T.prim_func
    def wmma_load_desc(a: T.handle, c: T.handle) -> None:
-        A = T.match_buffer(
-            a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=shared_scope
-        )
+        A = T.match_buffer(a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=shared_scope)
        C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=wmma_fragment_scope
+            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=wmma_fragment_scope
        )
        with T.block("root"):
            T.reads(A[0:m_dim, 0:n_dim])
@@ -532,13 +530,13 @@ def wmma_load_impl(a: T.handle, c: T.handle) -> None:
            a,
            (m_dim, n_dim),
            dtype,
-            align=128,
+            align=64,
            offset_factor=16,
            scope=shared_scope,
            strides=[s1, s0],
        )
        C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=wmma_fragment_scope
+            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=wmma_fragment_scope
        )
        with T.block("root"):
            T.reads(A[0:m_dim, 0:n_dim])
@@ -569,7 +567,7 @@ def get_wmma_fill_intrin(
    @T.prim_func
    def wmma_fill_desc(c: T.handle) -> None:
        C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
+            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
        )
        with T.block("root"):
            T.reads()
@@ -582,7 +580,7 @@ def wmma_fill_desc(c: T.handle) -> None:
    @T.prim_func
    def wmma_fill_impl(c: T.handle) -> None:
        C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
+            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
        )
        with T.block("root"):
            T.reads()
@@ -610,9 +608,9 @@ def get_wmma_store_intrin(
    @T.prim_func
    def wmma_store_desc(a: T.handle, c: T.handle) -> None:
        A = T.match_buffer(
-            a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
+            a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
        )
-        C = T.match_buffer(c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=scope)
+        C = T.match_buffer(c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope)
        with T.block("root"):
            T.reads(A[0:m_dim, 0:n_dim])
            T.writes(C[0:m_dim, 0:n_dim])
@@ -626,10 +624,10 @@ def wmma_store_impl(a: T.handle, c: T.handle) -> None:
        s1 = T.var("int32")
        s0 = T.var("int32")
        A = T.match_buffer(
-            a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
+            a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
        )
        C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=scope, strides=[s1, s0]
+            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope, strides=[s1, s0]
        )
        with T.block("root"):
            T.reads(A[0:m_dim, 0:n_dim])
@@ -671,18 +669,18 @@ def maybe_swap(i, j):
    @T.prim_func
    def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(
-            a, (m_dim, k_dim), in_dtype, align=128, offset_factor=16, scope="wmma.matrix_a"
+            a, (m_dim, k_dim), in_dtype, align=64, offset_factor=16, scope="wmma.matrix_a"
        )
        B = T.match_buffer(
            b,
            maybe_swap(k_dim, n_dim),
            in_dtype,
-            align=128,
+            align=64,
            offset_factor=16,
            scope="wmma.matrix_b",
        )
        C = T.match_buffer(
-            c, (m_dim, n_dim), out_dtype, align=128, offset_factor=16, scope="wmma.accumulator"
+            c, (m_dim, n_dim), out_dtype, align=64, offset_factor=16, scope="wmma.accumulator"
        )

        with T.block("root"):
@@ -699,18 +697,18 @@ def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    @T.prim_func
    def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(
-            a, (m_dim, k_dim), in_dtype, align=128, offset_factor=16, scope="wmma.matrix_a"
+            a, (m_dim, k_dim), in_dtype, align=64, offset_factor=16, scope="wmma.matrix_a"
        )
        B = T.match_buffer(
            b,
            maybe_swap(k_dim, n_dim),
            in_dtype,
-            align=128,
+            align=64,
            offset_factor=16,
            scope="wmma.matrix_b",
        )
        C = T.match_buffer(
-            c, (m_dim, n_dim), out_dtype, align=128, offset_factor=16, scope="wmma.accumulator"
+            c, (m_dim, n_dim), out_dtype, align=64, offset_factor=16, scope="wmma.accumulator"
        )

        with T.block("root"):
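
All of the changes above are the same mechanical substitution: the align argument of T.match_buffer drops from 128 to 64 to match the new runtime default. As a self-contained sketch of the API being touched (the function name, shape, and dtype here are illustrative, not taken from the PR):

from tvm.script import tir as T


@T.prim_func
def copy_16x16(a: T.handle, b: T.handle) -> None:
    # Match both buffers with the new 64-byte alignment, as the intrinsics above now do.
    A = T.match_buffer(a, (16, 16), "float16", align=64, offset_factor=16)
    B = T.match_buffer(b, (16, 16), "float16", align=64, offset_factor=16)
    for i, j in T.grid(16, 16):
        with T.block("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]
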
2 changes: 1 addition & 1 deletion src/printer/tir_text_printer.cc
@@ -251,7 +251,7 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) {
  if (GetRef<Buffer>(buf).scope() != "global") {
    doc << ", scope=" << Doc::StrLiteral(GetRef<Buffer>(buf).scope());
  }
-  if (buf->data_alignment != 128) {
+  if (buf->data_alignment != runtime::kAllocAlignment) {
    doc << ", align=" << buf->data_alignment;
  }
  if (buf->offset_factor != 1) {
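
The printer now compares against runtime::kAllocAlignment instead of a hard-coded 128, so buffers at the default alignment continue to print without an align= annotation after the default moves to 64. A small Python-side sketch of the attribute involved (the assertions are ours; tvm.tir.decl_buffer is the standard buffer-declaration API):

import tvm

# A buffer at the new default alignment and one at a custom alignment; only
# the non-default value would be spelled out as ", align=..." in TIR text.
default_buf = tvm.tir.decl_buffer((8,), "float32", data_alignment=64)
custom_buf = tvm.tir.decl_buffer((8,), "float32", data_alignment=256)
assert default_buf.data_alignment == 64
assert custom_buf.data_alignment == 256
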
2 changes: 1 addition & 1 deletion src/runtime/crt/common/crt_runtime_api.c
@@ -104,7 +104,7 @@ int TVMDeviceAllocDataSpaceWithScope(DLDevice dev, int ndim, const int64_t* shap
  }
  nbytes *= (dtype.bits * dtype.lanes + 7) / 8;

-  int kAllocAlignment = 128;
+  int kAllocAlignment = 64;
  size_t align = (dtype.bits / 8) * dtype.lanes;
  if (align < kAllocAlignment) align = kAllocAlignment;
  return TVMDeviceAllocDataSpace(dev, nbytes, align, dtype, out_data);
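
The CRT keeps the same rule, only with the lower floor: the natural element alignment (bits / 8 * lanes) is used unless it falls below kAllocAlignment. A pure-Python sketch of that computation (the crt_align helper name is ours):

def crt_align(bits: int, lanes: int, k_alloc_alignment: int = 64) -> int:
    # Mirror of the C logic above: natural element alignment, floored at kAllocAlignment.
    align = (bits // 8) * lanes
    return max(align, k_alloc_alignment)

assert crt_align(32, 1) == 64    # float32 scalar: 4 bytes, floored up to 64
assert crt_align(8, 128) == 128  # int8x128 vector keeps its natural 128-byte alignment
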
10 changes: 8 additions & 2 deletions tests/python/contrib/test_dlpack.py
@@ -21,7 +21,7 @@
from tvm.contrib.dlpack import to_pytorch_func


-def test():
+def verify_torch_dlpack():
    a = np.random.randn(1337)
    tvm_a = tvm.nd.array(a)
    np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).numpy(), a)
@@ -63,5 +63,11 @@ def test():
        pass


+def test_torch_dlpack():
+    # Run dlpack interoperability test a few times to make sure it's stable.
+    for i in range(5):
+        verify_torch_dlpack()
+
+
if __name__ == "__main__":
-    test()
+    test_torch_dlpack()