[TIR][DLight] Enable SimdGroup op for Metal #17112

Merged 1 commit on Jun 24, 2024
44 changes: 43 additions & 1 deletion include/tvm/tir/builtin.h
@@ -746,7 +746,7 @@ TVM_DLL const Op& create_barriers();
TVM_DLL const Op& mma_store();

/*!
* \brief tvm intrinsic for zero-initalizing an MMA accumulation registor.
* \brief tvm intrinsic for zero-initializing an MMA accumulation register.
* For example, if each thread in a warp of size 32 has 8 elements from the A matrix in
* m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its
* 4 accumulation registers.
@@ -758,6 +758,48 @@ TVM_DLL const Op& mma_store();
*/
TVM_DLL const Op& mma_fill();

// Metal SimdGroup matrix intrinsics

/*!
* \brief tvm intrinsic for initializing a simdgroup matrix with a given value.
* \note only the 8x8 shape is supported by the Metal Spec and TVM, but we still keep the shape
* as params to stay close to the Metal Spec interface.
*
* void make_filled_simdgroup_matrix(Var d, PrimExpr index, PrimExpr value,
* int col = 8, int row = 8);
*/
TVM_DLL const Op& make_filled_simdgroup_matrix();

/*!
* \brief tvm intrinsic for loading data from device memory or threadgroup memory into a simdgroup matrix.
* \note only the 8x8 shape is supported by the Metal Spec and TVM, but we still keep the shape
* as params to stay close to the Metal Spec interface.
*
* void simdgroup_load(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride,
*                       int col = 8, int row = 8, bool transpose_matrix = false);
*/
TVM_DLL const Op& simdgroup_load();

/*!
* \brief tvm intrinsic for storing data from a simdgroup matrix to device memory or threadgroup memory.
* \note only the 8x8 shape is supported by the Metal Spec and TVM, but we still keep the shape
* as params to stay close to the Metal Spec interface.
*
* void simdgroup_store(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride,
* int col = 8, int row = 8, bool transpose_matrix = false);
*/
TVM_DLL const Op& simdgroup_store();

/*!
* \brief tvm intrinsic for the multiply-accumulate of two matrices in a simdgroup.
* \note only the 8x8 shape is supported by the Metal Spec and TVM, but we still keep the shape
* as params to stay close to the Metal Spec interface.
*
* void simdgroup_mma(Var d, PrimExpr index_d, Var a, PrimExpr index_a,
* Var b, PrimExpr index_b, Var c, PrimExpr index_c);
*/
TVM_DLL const Op& simdgroup_multiply_accumulate();

// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
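Taken together, these four intrinsics mirror Metal's simdgroup matrix API: fill an accumulator fragment, load A and B tiles, multiply-accumulate, and store the result. Below is a minimal sketch (not part of this diff) of how they might compose in TVMScript for a single 8x8 float16 tile; the buffer names, the metal.simdgroup staging buffers, and the exact Python-side argument spelling are assumptions based on the signatures documented above.

```python
# Illustrative sketch only: one 8x8 tile of C = A * B using the new simdgroup
# intrinsics. Argument order follows the builtin.h doc comments; the TVMScript
# wrappers added later in this diff are assumed to mirror it.
from tvm.script import tir as T


@T.prim_func
def simdgroup_gemm_8x8(a: T.handle, b: T.handle, c: T.handle):
    A = T.match_buffer(a, (8, 8), "float16", scope="shared")
    B = T.match_buffer(b, (8, 8), "float16", scope="shared")
    C = T.match_buffer(c, (8, 8), "float16", scope="shared")
    with T.block("root"):
        A_frag = T.alloc_buffer((8, 8), "float16", scope="metal.simdgroup")
        B_frag = T.alloc_buffer((8, 8), "float16", scope="metal.simdgroup")
        C_frag = T.alloc_buffer((8, 8), "float16", scope="metal.simdgroup")
        # C_frag <- 0 (fragment 0 of the accumulator buffer)
        T.evaluate(T.make_filled_simdgroup_matrix(C_frag.data, 0, T.float16(0), 8, 8))
        # Load the A and B tiles from threadgroup (shared) memory; the stride is
        # the row pitch in elements, here 8.
        T.evaluate(T.simdgroup_load(A_frag.data, 0, A.access_ptr("r"), 8, 8, 8, False))
        T.evaluate(T.simdgroup_load(B_frag.data, 0, B.access_ptr("r"), 8, 8, 8, False))
        # C_frag += A_frag * B_frag
        T.evaluate(
            T.simdgroup_multiply_accumulate(
                C_frag.data, 0, A_frag.data, 0, B_frag.data, 0, C_frag.data, 0
            )
        )
        # Write the accumulated tile back out.
        T.evaluate(T.simdgroup_store(C_frag.data, 0, C.access_ptr("w"), 8, 8, 8, False))
```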
145 changes: 145 additions & 0 deletions python/tvm/dlight/gpu/matmul.py
@@ -313,6 +313,146 @@ def check_sm_version(arch: str) -> int:
return int(sm_version) if sm_version.isdigit() else -1


class MetalMatmul(GPUScheduleRule):
"""
The schedule rule for Metal matmul computation.
"""

def apply( # pylint: disable=too-many-locals,missing-docstring
self,
func: tir.PrimFunc,
target: Target,
_: bool,
) -> Optional[tir.Schedule]:
from tvm.tir.tensor_intrin.metal import ( # pylint: disable=import-outside-toplevel
get_simdgroup_intrin_group,
)

if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
return None
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)

reduction_blocks = get_reduction_blocks(sch, blocks)
if reduction_blocks is None:
return None

main_block = reduction_blocks[0]
block_stmt = sch.get(main_block)
index_maps = get_index_map(block_stmt)
if index_maps is None:
return None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

# Step 0. Configs
block_size_x: int = 16
block_size_y: int = 16
block_size_k: int = 32
micro_size: int = 8
warp_size: int = 32
ty_len: int = 1
tz_len: int = 4
vector_size: int = 4
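# With these configs, a threadgroup holds ty_len * tz_len = 4 simdgroups of warp_size = 32
# threads; each simdgroup computes a block_size_x x block_size_y (16x16) patch of C as
# 2x2 micro_size (8x8) fragments, so one threadgroup covers a 16x64 output tile and steps
# over the reduction axis in block_size_k = 32 chunks.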

# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1))
sch.transform_layout(block, ("write", 0), b_index_map)
block = sch.reindex(main_block, ("write", 0))
sch.transform_layout(block, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)

# Step 2. Padding for dynamic shape kernels
sch.pad_einsum(
main_block,
[
1,
ty_len * block_size_x,
tz_len * block_size_y,
block_size_k,
],
)

# Step 3. Schedule matmul to use simdgroup intrinsics
batch, i, j, k = sch.get_loops(main_block)
bx, ty, i0, i1 = sch.split(i, [None, ty_len, block_size_x // micro_size, micro_size])
by, tz, j0, j1 = sch.split(j, [None, tz_len, block_size_y // micro_size, micro_size])
k0, k1, k2 = sch.split(k, [None, block_size_k // micro_size, micro_size])
sch.reorder(bx, by, ty, tz, k0, k1, i0, j0, i1, j1, k2)
sch.bind(bx, "blockIdx.x")
sch.bind(by, "blockIdx.y")
sch.bind(batch, "blockIdx.z")
sch.bind(ty, "threadIdx.y")
sch.bind(tz, "threadIdx.z")

def fetch_to_shared(block, idx):
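# Cooperatively stage one K-tile of the operand in shared memory: the copy loop is
# fused, then split across all tz_len * ty_len * warp_size threads, with each thread
# moving vector_size contiguous elements per vectorized step.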
block_read = sch.cache_read(block, idx, "shared")
sch.compute_at(block_read, k0, preserve_unit_loops=True)
fused = sch.fuse(*sch.get_loops(block_read)[-2:])
_, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])

sch.bind(_tz, "threadIdx.z")
sch.bind(_ty, "threadIdx.y")
sch.bind(_tx, "threadIdx.x")
sch.vectorize(vec)

return block_read

a_g2s = fetch_to_shared(main_block, 0)
b_g2s = fetch_to_shared(main_block, 1)

auto_inline_producers(sch, a_g2s)
auto_inline_producers(sch, b_g2s)

# create read caches to load matrices from shared memory into simdgroup fragments
A_simdgroup = sch.cache_read(main_block, 0, "metal.simdgroup")
B_simdgroup = sch.cache_read(main_block, 1, "metal.simdgroup")
sch.compute_at(A_simdgroup, k1)
sch.compute_at(B_simdgroup, k1)

C_simd2s = sch.cache_write(main_block, 0, "metal.simdgroup")
C_s2g = sch.cache_write(C_simd2s, 0, "shared")
sch.reverse_compute_at(C_simd2s, tz, preserve_unit_loops=True)
sch.reverse_compute_at(C_s2g, by, preserve_unit_loops=True)
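# Epilogue: the main block accumulates into simdgroup fragments; C_simd2s stores the
# fragments to shared memory (tensorized with the "store" intrin below), and C_s2g
# performs the final shared-to-global copy as a fused, vectorized, thread-bound loop.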

intrin_group = get_simdgroup_intrin_group(
load_scope="shared",
store_scope="shared",
dtype="float16",
trans_a=False,
trans_b=True,
)
sch.transform_layout(B_simdgroup, ("write", 0), lambda s, i, j: (s, j, i))
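# Step 1 normalized B to [S, J, K] (reduction axis innermost), so B fragments are loaded
# transposed (trans_b=True); swapping the last two axes of the simdgroup cache keeps its
# layout consistent with the transposed load.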

def tensorize_block(block: tir.schedule.BlockRV, intrin: str):
*_, i, j = sch.get_loops(block)
io, ii = sch.split(i, [None, micro_size])
jo, ji = sch.split(j, [None, micro_size])
sch.reorder(io, jo, ii, ji)
sch.tensorize(ii, intrin)

C_init = sch.decompose_reduction(main_block, k0)
tensorize_block(A_simdgroup, intrin_group["load_a"])
tensorize_block(B_simdgroup, intrin_group["load_b"])
tensorize_block(C_simd2s, intrin_group["store"])
tensorize_block(C_init, intrin_group["init"])

*_, i, j, k = sch.get_loops(main_block)
sch.tensorize(i, intrin_group["compute"])

auto_inline_consumer_chain(sch, C_s2g)
fused = sch.fuse(*sch.get_loops(C_s2g)[-2:])
_, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])
sch.bind(_tz, "threadIdx.z")
sch.bind(_ty, "threadIdx.y")
sch.bind(_tx, "threadIdx.x")
sch.vectorize(vec)

return sch


class MatmulTensorization(GPUScheduleRule):
"""
The schedule rule for float16 tensor core matmul computation.
@@ -848,6 +988,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
tensorize_sch = MatmulTensorization().apply(func, target, _)
if tensorize_sch is not None:
return tensorize_sch
elif target.kind.name == "metal":
try:
return MetalMatmul().apply(func, target, _)
except: # pylint: disable=bare-except
pass

# Step 2. Get schedule config.
config = self.get_configs(target)
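For context, here is a hedged end-to-end sketch (not part of this diff) of exercising the new rule: either calling MetalMatmul directly on a matmul PrimFunc, or going through the generic dlight Matmul rule, which the hunk above teaches to try MetalMatmul on metal targets before falling back to the generic schedule. The workload and shapes are illustrative.

```python
import tvm
from tvm import dlight as dl
from tvm.dlight.gpu.matmul import MetalMatmul
from tvm.script import tir as T


@T.prim_func
def matmul(A: T.Buffer((256, 256), "float16"),
           B: T.Buffer((256, 256), "float16"),
           C: T.Buffer((256, 256), "float16")):
    for i, j, k in T.grid(256, 256, 256):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float16(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


target = tvm.target.Target("metal")

# Apply the rule directly to a single PrimFunc.
sch = MetalMatmul().apply(matmul, target, False)
if sch is not None:
    sch.mod.show()  # scheduled TIR with metal.simdgroup caches and tensorized loops

# Or let the generic dlight pipeline dispatch; on metal targets the Matmul rule
# now tries MetalMatmul before falling back to the generic schedule.
mod = tvm.IRModule({"main": matmul})
with target:
    mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
```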
8 changes: 8 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
@@ -1887,6 +1887,10 @@ def wrapped(*args, **kwargs):
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
make_filled_simdgroup_matrix = _op_wrapper(_tir_op.make_filled_simdgroup_matrix)
simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
create_barriers = _op_wrapper(_tir_op.create_barriers)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
@@ -2177,6 +2181,10 @@ def wrapped(*args, **kwargs):
"ptx_arrive_barrier",
"ptx_arrive_barrier_expect_tx",
"ptx_wait_barrier",
"make_filled_simdgroup_matrix",
"simdgroup_load",
"simdgroup_store",
"simdgroup_multiply_accumulate",
"create_barriers",
"mma_store",
"mma_fill",
6 changes: 6 additions & 0 deletions python/tvm/tir/__init__.py
@@ -73,6 +73,12 @@
ptx_wait_barrier,
create_barriers,
)
from .op import (
make_filled_simdgroup_matrix,
simdgroup_load,
simdgroup_multiply_accumulate,
simdgroup_store,
)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz