Skip to content

Commit

Permalink
Initial affinity assignment plumbing through flow/stream/hal. (#10833)
Browse files Browse the repository at this point in the history
The stream affinity attr is changed to an interface and a placeholder
`#hal.affinity.queue<...>` implementation is defined to support
single-device multi-queue assignment. There's still a decent number of
things that will need to change for heterogeneous devices where we'll
want affinities to specify `#hal.device.target` values but for now
queues are enough. The queues specified are carried all the way to the
runtime device queue affinity masks.

Example:
```mlir
// run only on queue 0
%0 = flow.dispatch @ex::@entry0(%input) {stream.affinity = #hal.affinity.queue<[0]>} : ...
// run on either queue 1 or 2
%1 = flow.dispatch @ex::@Entry1(%input) {stream.affinity = #hal.affinity.queue<[1, 2]>} : ...
// run on any queue (attr can also be omitted)
%2 = flow.dispatch @ex::@entry2(%0, %1) {stream.affinity = #hal.affinity.queue<*>} : ...
```
->
```mlir
hal.device.queue.execute<%device : !hal.device> affinity(%c1)   // 0b001
hal.device.queue.execute<%device : !hal.device> affinity(%c6)   // 0b110
hal.device.queue.execute<%device : !hal.device> affinity(%c-1)  // 0b111...
```

Currently the affinities are picked up starting from the
`flow.dispatch.region`/`flow.dispatch.workgroups` ops. To specify
affinities before then the various dispatch region formation code (I
guess 3 different ways now?) will need to be updated to respect user
affinity specification (only fuse compatible things) and put the
affinity attribute on the resulting dispatch region.

Affinities attributes must implement the interface and as such are at
this layer IREE-specific. Before dispatch region formation (or as part
of it) anything else could be used, with the only requirement being that
before the stream dialect transformation pipeline is entered they are
all turned in to `stream.affinity` attributes. For example a frontend
could say `my.device_placement = "please run this on queue 4 🙇"` so long
as it can later be mapped to `stream.affinity =
#hal.affinity.queue<[4]>`.

Progress on #10765.
  • Loading branch information
benvanik authored Oct 24, 2022
2 parents e46d277 + 962c073 commit 87c5f3f
Show file tree
Hide file tree
Showing 29 changed files with 575 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
auto workgroupsOp = rewriter.create<IREE::Flow::DispatchWorkgroupsOp>(
loc, workload, regionOp.getResultTypes(), regionOp.getResultDims(),
arguments, argumentDims, tiedArguments);
workgroupsOp->setDialectAttrs(regionOp->getDialectAttrs());
BlockAndValueMapping bvm;
bvm.map(arguments, workgroupsOp.getInputBlockArguments());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ static LogicalResult convertToDispatchOp(DispatchWorkgroupsOp regionOp,
regionOp.getResultTypes(), regionOp.getResultDims(),
regionOp.getArguments(), regionOp.getArgumentDims(),
regionOp.getTiedOperandsAttr());
dispatchOp->setDialectAttrs(regionOp->getDialectAttrs());

// Replace uses of the existing results with the new results.
for (int i = 0; i < regionOp.getNumResults(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ iree_lit_test_suite(
"collapse_reduction.mlir",
"conv1x1_to_matmul.mlir",
"conv2d_to_img2col.mlir",
"convert_region_to_workgroups.mlir",
"deduplicate_executables.mlir",
"detach_elementwise_from_named_ops.mlir",
"dispatch_linalg_on_tensors.mlir",
Expand All @@ -41,7 +42,6 @@ iree_lit_test_suite(
"outline_dispatch_regions.mlir",
"pad_linalg_ops.mlir",
"tensor_pad_to_tensor_insert_slice.mlir",
"region_to_workgroups.mlir",
"strip_and_splat_constant_variables.mlir",
"strip_signedness.mlir",
"transform_dispatch_region_formation.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ iree_lit_test_suite(
"collapse_reduction.mlir"
"conv1x1_to_matmul.mlir"
"conv2d_to_img2col.mlir"
"convert_region_to_workgroups.mlir"
"deduplicate_executables.mlir"
"detach_elementwise_from_named_ops.mlir"
"dispatch_linalg_on_tensors.mlir"
Expand All @@ -38,7 +39,6 @@ iree_lit_test_suite(
"optimize_numerics.mlir"
"outline_dispatch_regions.mlir"
"pad_linalg_ops.mlir"
"region_to_workgroups.mlir"
"strip_and_splat_constant_variables.mlir"
"strip_signedness.mlir"
"tensor_pad_to_tensor_insert_slice.mlir"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ func.func @foo(%argA: tensor<?x?xf32>, %argB: tensor<5x10xf32>, %argC: tensor<10
flow.return %argA : tensor<?x?xf32>
}

// CHECK: %[[r1:.*]] = flow.dispatch.workgroups(%[[argB]], %[[argC]]) : (tensor<5x10xf32>, tensor<10x11xf32>) -> tensor<5x11xf32> =
// CHECK: %[[r1:.*]] = flow.dispatch.workgroups(%[[argB]], %[[argC]]) : (tensor<5x10xf32>, tensor<10x11xf32>) -> tensor<5x11xf32>
// CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]>
// CHECK-NEXT: (%[[arg3:.*]]: !flow.dispatch.tensor<readonly:5x10xf32>, %[[arg4:.*]]: !flow.dispatch.tensor<readonly:10x11xf32>, %[[arg5:.*]]: !flow.dispatch.tensor<writeonly:5x11xf32>)
// CHECK-DAG: %[[loadB:.*]] = flow.dispatch.tensor.load %[[arg3]], offsets = [0, 0], sizes = [5, 10], strides = [1, 1] : !flow.dispatch.tensor<readonly:5x10xf32> -> tensor<5x10xf32>
// CHECK-DAG: %[[loadC:.*]] = flow.dispatch.tensor.load %[[arg4]], offsets = [0, 0], sizes = [10, 11], strides = [1, 1] : !flow.dispatch.tensor<readonly:10x11xf32> -> tensor<10x11xf32>
Expand All @@ -31,7 +32,7 @@ func.func @foo(%argA: tensor<?x?xf32>, %argB: tensor<5x10xf32>, %argC: tensor<10
// CHECK: flow.dispatch.tensor.store %[[matmul]], %[[arg5]], offsets = [0, 0], sizes = [5, 11], strides = [1, 1] : tensor<5x11xf32> -> !flow.dispatch.tensor<writeonly:5x11xf32>
// CHECK: flow.return
// CHECK: }
%r1 = flow.dispatch.region -> (tensor<5x11xf32>) {
%r1 = flow.dispatch.region {stream.affinity = #hal.affinity.queue<[0]>} -> (tensor<5x11xf32>) {
%zero = arith.constant 0.0 : f32
%0 = tensor.empty() : tensor<5x11xf32>
%1 = linalg.fill ins(%zero : f32) outs(%0 : tensor<5x11xf32>) -> tensor<5x11xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ func.func @dispatchFn1(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> {
%x = arith.constant 100 : index
%y = arith.constant 50 : index
// CHECK: flow.dispatch @dispatchFn1_dispatch_0::@dispatchFn1_dispatch_0
%0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = (
// CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]>
%0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) attributes {
stream.affinity = #hal.affinity.queue<[0]>
} = (
%arg: !flow.dispatch.tensor<readonly:8x4xf32>, %ret: !flow.dispatch.tensor<writeonly:4x8xf32>
) {
flow.return
Expand All @@ -100,7 +103,10 @@ func.func @dispatchFn2(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> {
%x = arith.constant 100 : index
%y = arith.constant 50 : index
// CHECK: flow.dispatch @dispatchFn2_dispatch_0::@dispatchFn2_dispatch_0
%0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = (
// CHECK-SAME: stream.affinity = #hal.affinity.queue<[1]>
%0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) attributes {
stream.affinity = #hal.affinity.queue<[1]>
} = (
%arg: !flow.dispatch.tensor<readonly:8x4xf32>, %ret: !flow.dispatch.tensor<writeonly:4x8xf32>
) {
flow.return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,39 @@ static Value lookupDeviceFor(Operation *op, OpBuilder &builder) {
return lookupOp.getResult();
}

// Returns the device queue affinity mask indicating which device queues the
// operations are allowed to execute on.
static Value buildQueueAffinityMaskFor(Operation *op, Value device,
OpBuilder &builder) {
// Try to find a specified affinity. This may be on the op provided or one of
// its parent regions.
auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op);
if (auto queueAffinityAttr =
affinityAttr.dyn_cast_or_null<IREE::HAL::AffinityQueueAttr>()) {
return builder.create<arith::ConstantIntOp>(
op->getLoc(), queueAffinityAttr.getMask(), 64);
}

// No affinity specified; use default (any) affinity.
return builder.create<arith::ConstantIntOp>(op->getLoc(), -1, 64);
}

static std::tuple<Value, Value> lookupDeviceAndQueueAffinityFor(
Operation *op, OpBuilder &builder) {
// NOTE: we have this combined method so that we can reuse any expensive
// lookups we need to do. Today we aren't duplicating the lookups and don't
// bother.

// Get a device handle used to create resources and schedule work.
// It may be shared across many mutually-exclusive devices at runtime.
Value device = lookupDeviceFor(op, builder);

// Derive the queue affinity mask from the op and device combination.
Value queueAffinity = buildQueueAffinityMaskFor(op, device, builder);

return std::make_tuple(device, queueAffinity);
}

static Value lookupAllocatorFor(Operation *op, OpBuilder &builder) {
auto device = lookupDeviceFor(op, builder);
auto allocatorOp =
Expand Down Expand Up @@ -295,7 +328,8 @@ struct ResourceAllocaOpPattern
IREE::Stream::ResourceAllocaOp allocaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = allocaOp.getLoc();
auto device = lookupDeviceFor(allocaOp, rewriter);
auto [device, queueAffinity] =
lookupDeviceAndQueueAffinityFor(allocaOp, rewriter);
auto bufferType = rewriter.getType<IREE::HAL::BufferType>();

// Transient allocations are device-local. Copies are required to get their
Expand All @@ -316,7 +350,6 @@ struct ResourceAllocaOpPattern
loc, device, allocaOp.getResultTimepoint(), rewriter);

// Queue allocation.
auto queueAffinity = rewriter.create<arith::ConstantIntOp>(loc, -1, 64);
auto pool = rewriter.create<arith::ConstantIntOp>(loc, 0, 64);
auto allocateOp = rewriter.create<IREE::HAL::DeviceQueueAllocaOp>(
loc, bufferType, device, queueAffinity, waitFence, signalFence, pool,
Expand All @@ -334,7 +367,8 @@ struct ResourceDeallocaOpPattern
IREE::Stream::ResourceDeallocaOp deallocaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = deallocaOp.getLoc();
auto device = lookupDeviceFor(deallocaOp, rewriter);
auto [device, queueAffinity] =
lookupDeviceAndQueueAffinityFor(deallocaOp, rewriter);

// Gather wait/signal fence, which are optional.
Value waitFence =
Expand All @@ -343,7 +377,6 @@ struct ResourceDeallocaOpPattern
loc, device, deallocaOp.getResultTimepoint(), rewriter);

// Queue allocation.
auto queueAffinity = rewriter.create<arith::ConstantIntOp>(loc, -1, 64);
rewriter.create<IREE::HAL::DeviceQueueDeallocaOp>(
loc, device, queueAffinity, waitFence, signalFence,
adaptor.getOperand());
Expand Down Expand Up @@ -873,7 +906,8 @@ struct CmdExecuteOpPattern
IREE::Stream::CmdExecuteOp executeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = executeOp.getLoc();
auto device = lookupDeviceFor(executeOp, rewriter);
auto [device, queueAffinity] =
lookupDeviceAndQueueAffinityFor(executeOp, rewriter);

// If there are any wait timepoints it means there's prior queued execution
// that we may need to wait behind and we can't execute inline. HAL
Expand Down Expand Up @@ -918,7 +952,6 @@ struct CmdExecuteOpPattern
loc, device, executeOp.getResultTimepoint(), rewriter);

// Queue execution.
auto queueAffinity = rewriter.create<arith::ConstantIntOp>(loc, -1, 64);
rewriter.create<IREE::HAL::DeviceQueueExecuteOp>(loc, device, queueAffinity,
waitFence, signalFence,
ValueRange{commandBuffer});
Expand Down Expand Up @@ -1024,9 +1057,8 @@ struct TimepointChainExternalOpPattern
return rewriter.notifyMatchFailure(
exportOp, "only exports to HAL fences are supported");
}
auto device = lookupDeviceFor(exportOp, rewriter);
auto queueAffinity =
rewriter.create<arith::ConstantIntOp>(exportOp.getLoc(), -1, 64);
auto [device, queueAffinity] =
lookupDeviceAndQueueAffinityFor(exportOp, rewriter);
rewriter.replaceOpWithNewOp<IREE::HAL::DeviceQueueExecuteOp>(
exportOp, device, queueAffinity,
/*wait_fence=*/adaptor.getAwaitTimepoint(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,23 @@ func.func @cmdDispatch(%arg0: !stream.resource<transient>, %arg1: index, %arg2:
// CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]]
return %0 : !stream.timepoint
}

// -----

// Tests that an operation specified to run on multiple queues ends up with the
// appropriate queue affinity mask. The final affinity is the result of ORing
// the target affinities (0b01 | 0b10 = 0b11 = 3).

// CHECK-LABEL: @cmdExecuteAffinities
func.func @cmdExecuteAffinities(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<staging>, %arg3: index, %arg4: !stream.timepoint) -> !stream.timepoint {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
// CHECK: %[[CMD:.+]] = hal.command_buffer.create
%0 = stream.cmd.execute on(#hal.affinity.queue<[0, 1]>) await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) {
stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3}
} => !stream.timepoint
// CHECK: hal.device.queue.execute
// CHECK-SAME: affinity(%c3_i64)
// CHECK-SAME: commands([%[[CMD]]])
return %0 : !stream.timepoint
}
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ td_library(
include = ["*.td"],
),
deps = [
"//compiler/src/iree/compiler/Dialect/Stream/IR:td_files",
"//compiler/src/iree/compiler/Dialect/Util/IR:td_files",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:FuncTdFiles",
Expand Down Expand Up @@ -71,6 +72,7 @@ iree_compiler_cc_library(
":HALInterfacesGen",
":HALOpsGen",
":HALTypesGen",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Utils",
"@llvm-project//llvm:Support",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ iree_cc_library(
MLIRSupport
MLIRTransformUtils
MLIRViewLikeInterface
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Utils
PUBLIC
Expand Down
44 changes: 44 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

include "iree/compiler/Dialect/HAL/IR/HALDialect.td"
include "iree/compiler/Dialect/HAL/IR/HALInterfaces.td"
include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.td"
include "iree/compiler/Dialect/Util/IR/UtilTypes.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"
Expand Down Expand Up @@ -597,11 +598,13 @@ def HAL_ExecutableTargetAttr :
one of many different executable targets. Assume an N:M mapping between the
two in all cases.
}];

let parameters = (ins
AttrParameter<"StringAttr", "">:$backend,
AttrParameter<"StringAttr", "">:$format,
AttrParameter<"DictionaryAttr", "">:$configuration
);

let builders = [
AttrBuilder<(ins "StringRef":$backend, "StringRef":$format)>,
];
Expand All @@ -615,6 +618,47 @@ def HAL_ExecutableTargetAttr :
// device that can load an executable of this target.
Attribute getMatchExpression();
}];

let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// #hal.affinity.queue<*>
//===----------------------------------------------------------------------===//

def HAL_AffinityQueueAttr : AttrDef<HAL_Dialect, "AffinityQueue", [
DeclareAttrInterfaceMethods<Stream_AffinityAttr, [
"isExecutableWith",
"joinOR",
"joinAND",
]>,
]> {
let mnemonic = "affinity.queue";
let summary = [{specifies a set of allowed queues for an operation}];
let description = [{
WIP; see https://github.com/iree-org/iree/issues/10765.
This may change in the future to either be a nested attribute on a larger
affinity struct or be defined by an implementation of the affinity attr
interface. For now this allows higher levels of the stack to specify
queues such that the stream dialect can understand them and they can be
lowered into the HAL dialect.

Specifies that an annotated operation or scope is only allowed to execute on
the set of queues (0-64) provided. Operations will not run on other queues.

Example:
```mlir
// any queue
#hal.affinity.queue<*>
// queues 4 and 5
#hal.affinity.queue<[4, 5]>
```
}];

let parameters = (ins
AttrParameter<"int64_t", "">:$mask
);

let hasCustomAssemblyFormat = 1;
}

Expand Down
Loading

0 comments on commit 87c5f3f

Please sign in to comment.