[LLVMGPU][TD] Don't apply the unaligned strategy for unsupported cases (#13450)

The current transform dialect strategy has some limitations with respect to the
optionality of pads. As a result, make sure we don't apply this strategy when
some of the pads may be folded away.

The existing code was already checking for that, but the condition was
slightly off.

Fix that, i.e., only apply the unaligned strategy when both M and K, or N and
K, are unaligned.

This fixes #13448
qcolombet authored May 8, 2023
1 parent a058f29 commit 32d7e10
Showing 2 changed files with 40 additions and 4 deletions.
@@ -260,3 +260,38 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
// generalization along this axis.
// CHECK-NOT: transform.sequence

// -----
hal.executable @matmul_parially_unaligned {
hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> {
hal.executable.export public @matmul ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @matmul_parially_unaligned() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2048x2044xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2044x1024xf32>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2048x1024xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 2044], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x2044xf32>> -> tensor<2048x2044xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2044, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2044x1024xf32>> -> tensor<2044x1024xf32>
%5 = tensor.empty() : tensor<2048x1024xf32>
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2048x1024xf32>) -> tensor<2048x1024xf32>
%7 = linalg.matmul ins(%3, %4 : tensor<2048x2044xf32>, tensor<2044x1024xf32>) outs(%6 : tensor<2048x1024xf32>) -> tensor<2048x1024xf32>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 1024], strides = [1, 1] : tensor<2048x1024xf32> -> !flow.dispatch.tensor<writeonly:tensor<2048x1024xf32>>
return
}
}
}
}

// CHECK-LABEL: func @matmul_parially_unaligned

// "Enough" of this matmul's dimensions are divisible by 64/64/16.
// We currently bail on such cases because at least one of the paddings involved
// in the strategy folds away, causing the strategy to fail to apply.
// In the future we should also support this case, but for now we are missing
// the generalization along this axis.
// CHECK-NOT: transform.sequence
@@ -525,10 +525,11 @@ static LogicalResult matchAndSetMatmulStrategy(func::FuncOp entryPoint,
// - n and k are not aligned to the tile sizes (conservatively, take 64, 16)
// Other cases currently result in folding and fall back to the default
// unaligned IREE strategy.
-  bool unsupportedAlignedCases =
-      (matmulSize[0] % 64 == 0 && matmulSize[2] % 16 == 0) ||
-      (matmulSize[1] % 64 == 0 && matmulSize[2] % 16 == 0);
-  if (unsupportedAlignedCases) {
+  bool supportedUnalignedCases =
+      (matmulSize[0] % 64 != 0 && matmulSize[2] % 16 != 0) ||
+      (matmulSize[1] % 64 != 0 && matmulSize[2] % 16 != 0);
+
+  if (!supportedUnalignedCases) {
LDBG("--Matmul strategy alignment check failed\n");
return failure();
}
