Find root by traversing the compute ops in reverse. (#10210)
Since most of the code generation uses tile + fuse, where the consumer
is tiled and the producer is fused with it, find the root by
traversing the compute ops in reverse.

Issue #10208
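
The intuition behind the change, as a minimal standalone C++ sketch (hypothetical Op struct and boolean flags, not the actual MLIR API; the real code queries linalg::LinalgOp and TilingInterface on Operation*): in a tile-and-fuse chain the consumer is the last compute op, so a reverse walk reaches it first and picks it as root.

#include <string>
#include <vector>

// Hypothetical stand-ins for MLIR operations and their queries.
struct Op {
  std::string name;
  bool isLinalgOp;
  bool allParallel;        // for linalg ops: every loop is parallel
  bool hasTilingInterface;
};

// Mirrors the reverse sweeps in getRootOperation below.
const Op *findRoot(const std::vector<Op> &computeOps) {
  // First sweep, in reverse: skip all-parallel linalg ops; take the first
  // hit (i.e. the last op in program order) that carries a reduction loop
  // or otherwise implements the tiling interface.
  for (auto it = computeOps.rbegin(); it != computeOps.rend(); ++it) {
    if (it->isLinalgOp) {
      if (it->allParallel) continue;
      return &*it;
    }
    if (it->hasTilingInterface) return &*it;
  }
  // Fallback sweep, also in reverse: any linalg op at all.
  for (auto it = computeOps.rbegin(); it != computeOps.rend(); ++it)
    if (it->isLinalgOp) return &*it;
  return nullptr;
}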
MaheshRavishankar authored Aug 26, 2022
1 parent 272ea37 commit cf5a5d5
Showing 2 changed files with 75 additions and 14 deletions.
21 changes: 7 additions & 14 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -1318,37 +1318,30 @@ static LogicalResult setVMVXRootConfigImpl(func::FuncOp entryPointFn,
 static FailureOr<Operation *> getRootOperation(
     ArrayRef<Operation *> computeOps) {
   Operation *rootOperation = nullptr;
-  auto updateRootOperation = [&](Operation *op) -> LogicalResult {
-    if (rootOperation) {
-      return op->emitOpError(
-          "unhandled multiple root operations in dispatch region");
-    }
-    rootOperation = op;
-    return success();
-  };
-  for (auto op : computeOps) {
+  for (auto op : llvm::reverse(computeOps)) {
     if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
       // Do not treat linalg ops that are all parallel as root operations in
       // this sweep.
       if (linalgOp.getNumLoops() == linalgOp.getNumParallelLoops()) continue;
 
       // All other linalg ops are root ops.
-      if (failed(updateRootOperation(op))) return failure();
-      continue;
+      rootOperation = op;
+      break;
     }
 
     if (isa<TilingInterface>(op)) {
       // All other operations that implement this interface are root ops.
-      if (failed(updateRootOperation(op))) return failure();
-      continue;
+      rootOperation = op;
+      break;
     }
   }
   if (rootOperation) return rootOperation;
 
   // If no root operation is found yet, look for linalg generic ops.
   for (auto op : llvm::reverse(computeOps)) {
     if (isa<linalg::LinalgOp>(op)) {
-      if (failed(updateRootOperation(op))) return failure();
+      rootOperation = op;
+      break;
     }
   }
   return rootOperation;
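A hypothetical driver for the sketch above, mirroring the producer-to-consumer chain in the new @multi_root test below (a fill, then a max-reduction generic, then the exp/sum generic that consumes it):

#include <cstdio>

int main() {
  // Both generics carry a reduction iterator, so the reverse sweep
  // reaches the exp/sum consumer first and makes it the root.
  std::vector<Op> ops = {
      {"linalg.fill", /*isLinalgOp=*/true, /*allParallel=*/true,
       /*hasTilingInterface=*/true},
      {"linalg.generic (max)", true, false, true},
      {"linalg.generic (exp+sum)", true, false, true},
  };
  const Op *root = findRoot(ops);
  std::printf("root: %s\n", root ? root->name.c_str() : "<none>");
  // Prints: root: linalg.generic (exp+sum)
}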
@@ -1270,3 +1270,71 @@ hal.executable private @transpose_8x8 {
 
 // CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64], [8, 8], []{{\]}}>
 // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
+
+// -----
+
+hal.executable private @multi_root {
+  hal.executable.variant public @embedded_elf_x86_64, target = <"llvm-cpu", "embedded-elf-x86_64", {
+      cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+      native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}> {
+    hal.executable.export public @multi_root ordinal(0)
+        layout(#hal.pipeline.layout<
+            push_constants = 0,
+            sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
+      %x, %y, %z = flow.dispatch.default_workgroup_count %arg1, %arg2, %arg3
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @multi_root() {
+        %c0 = arith.constant 0 : index
+        %c6144 = arith.constant 6144 : index
+        %c792576 = arith.constant 792576 : index
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64)
+            : !flow.dispatch.tensor<readonly:12x128x128xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64)
+            : !flow.dispatch.tensor<readonly:12x128xf32>
+        %3 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c792576) alignment(64)
+            : !flow.dispatch.tensor<writeonly:12x128xf32>
+        %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [12, 128, 128], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:12x128x128xf32> -> tensor<12x128x128xf32>
+        %5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [12, 128], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:12x128xf32> -> tensor<12x128xf32>
+        %7 = linalg.init_tensor [12, 128] : tensor<12x128xf32>
+        %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<12x128xf32>) -> tensor<12x128xf32>
+        %9 = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+            iterator_types = ["parallel", "parallel", "reduction"]}
+            ins(%4 : tensor<12x128x128xf32>) outs(%5 : tensor<12x128xf32>) {
+          ^bb0(%arg0: f32, %arg1: f32):
+            %11 = arith.maxf %arg0, %arg1 : f32
+            linalg.yield %11 : f32
+          } -> tensor<12x128xf32>
+        %10 = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>,
+                             affine_map<(d0, d1, d2) -> (d0, d1)>],
+            iterator_types = ["parallel", "parallel", "reduction"]}
+            ins(%4, %9 : tensor<12x128x128xf32>, tensor<12x128xf32>)
+            outs(%8 : tensor<12x128xf32>) {
+          ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
+            %11 = arith.subf %arg0, %arg1 : f32
+            %12 = math.exp %11 : f32
+            %13 = arith.addf %12, %arg3 : f32
+            linalg.yield %13 : f32
+          } -> tensor<12x128xf32>
+        flow.dispatch.tensor.store %10, %3, offsets = [0, 0], sizes = [12, 128], strides = [1, 1]
+            : tensor<12x128xf32> -> !flow.dispatch.tensor<writeonly:12x128xf32>
+        return
+      }
+    }
+  }
+}
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[4, 32, 0], [1, 4, 0], [0, 0, 4]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
+// CHECK: hal.executable.export public @multi_root
+// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK: linalg.generic
+// CHECK-NOT: lowering_config
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[CONFIG]]
