Skip to content

Commit

Permalink
Plumbing support for nested command buffers and binding tables.
Browse files Browse the repository at this point in the history
Secondary command buffers can now be executed from primary command
buffers via iree_hal_command_buffer_execute_commands. During recording
of nested command buffers push descriptors can indirectly reference
slots in a binding table provided with each execution request. This
enables the same reusable command buffer to be executed many times
with unique bindings (even with prior execution in-flight), which is
a common pattern with queue-ordered allocations.

In the future we could allow the indirect bindings on primary command
buffers as well but that requires more work in each backend to support
and for now making it nested-only lets us turn on the feature
incrementally.

The compiler has the HAL ops modeled but nothing is lowering into them
yet; a pass that memoizes portions of streams and sets up the indirect
binding references is required.

Progress on #10144.
  • Loading branch information
benvanik committed Aug 23, 2022
1 parent 57ec69d commit aee7b7c
Show file tree
Hide file tree
Showing 43 changed files with 797 additions and 163 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,59 @@ namespace mlir {
namespace iree_compiler {
namespace {

// TODO(benvanik): import op handling of optional values.
// It'd be nice if the Optional<Index>:$binding_capacity could be emitted as 0
// when not present; today it'll be omitted entirely (as it's not in the operand
// set) but we need it for the fixed call signature.
// Rewrites hal.command_buffer.create into a vm.call of the named import.
//
// Call operands are assembled in the order: device, modes, command categories,
// binding capacity. The modes and command categories attributes are
// materialized as i32 operands. If the op has no binding capacity operand an
// i32 zero constant is passed instead so that the call always carries the same
// number of operands.
class CommandBufferCreateOpConversion
    : public OpConversionPattern<IREE::HAL::CommandBufferCreateOp> {
 public:
  // Resolves |importName| within |importSymbols|; the import is expected to
  // exist (checked with an assert only).
  CommandBufferCreateOpConversion(MLIRContext *context,
                                  SymbolTable &importSymbols,
                                  TypeConverter &typeConverter,
                                  StringRef importName)
      : OpConversionPattern(typeConverter, context),
        importOp(importSymbols.lookup<IREE::VM::ImportOp>(importName)) {
    assert(importOp);
  }

  LogicalResult matchAndRewrite(
      IREE::HAL::CommandBufferCreateOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto i32Type = rewriter.getI32Type();

    SmallVector<Value, 8> operands;
    operands.push_back(adaptor.getDevice());

    // Bitfield attributes become i32 operands; bail if either fails to
    // convert.
    auto modeOperands = detail::rewriteAttrToOperands(
        loc, adaptor.getModesAttr(), i32Type, rewriter);
    if (!modeOperands) return failure();
    operands.append(*modeOperands);

    auto categoryOperands = detail::rewriteAttrToOperands(
        loc, adaptor.getCommandCategoriesAttr(), i32Type, rewriter);
    if (!categoryOperands) return failure();
    operands.append(*categoryOperands);

    // Optional operand: cast it to i32 when present, otherwise substitute 0.
    Value bindingCapacity = adaptor.getBindingCapacity();
    if (bindingCapacity) {
      bindingCapacity = castToImportType(bindingCapacity, i32Type, rewriter);
    } else {
      bindingCapacity = rewriter.create<IREE::VM::ConstI32ZeroOp>(loc);
    }
    operands.push_back(bindingCapacity);

    auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
        op, SymbolRefAttr::get(importOp),
        importOp.getFunctionType().getResults(), operands);
    copyImportAttrs(importOp, callOp);
    return success();
  }

 private:
  // Import declaration that the generated vm.call targets. Mutable so it can
  // be queried from the const matchAndRewrite.
  mutable IREE::VM::ImportOp importOp;
};

class CommandBufferFillBufferOpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferFillBufferOp> {
public:
Expand Down Expand Up @@ -88,6 +141,23 @@ class CommandBufferPushDescriptorSetOpConversion
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();

Value zero;
auto getI32Zero = [&]() {
if (!zero) {
zero = rewriter.create<IREE::VM::ConstI32ZeroOp>(op.getLoc());
}
return zero;
};
Value null;
auto getNull = [&]() {
if (!null) {
null = rewriter.create<IREE::VM::ConstRefZeroOp>(
op.getLoc(),
IREE::VM::RefType::get(rewriter.getType<IREE::HAL::BufferType>()));
}
return null;
};

SmallVector<Value, 8> callOperands = {
adaptor.getCommandBuffer(),
adaptor.getPipelineLayout(),
Expand All @@ -102,7 +172,16 @@ class CommandBufferPushDescriptorSetOpConversion
};
for (size_t i = 0; i < adaptor.getBindingOrdinals().size(); ++i) {
callOperands.push_back(adaptor.getBindingOrdinals()[i]);
callOperands.push_back(adaptor.getBindingBuffers()[i]);
auto bindingBuffer = adaptor.getBindingBuffers()[i];
if (bindingBuffer.getType().isa<IREE::VM::RefType>()) {
// Buffer binding; pass 0 for table slot.
callOperands.push_back(getI32Zero());
callOperands.push_back(bindingBuffer);
} else {
// Binding table reference; pass null for the buffer.
callOperands.push_back(bindingBuffer);
callOperands.push_back(getNull());
}
callOperands.push_back(castToImportType(adaptor.getBindingOffsets()[i],
rewriter.getI64Type(), rewriter));
callOperands.push_back(castToImportType(adaptor.getBindingLengths()[i],
Expand All @@ -126,7 +205,7 @@ void populateHALCommandBufferToVMPatterns(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferCreateOp>>(
patterns.insert<CommandBufferCreateOpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.create");
patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferFinalizeOp>>(
context, importSymbols, typeConverter, "hal.command_buffer.finalize");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@

// CHECK-LABEL: @command_buffer_create
func.func @command_buffer_create(%arg0: !hal.device) {
// CHECK: %ref = vm.call @hal.command_buffer.create(%arg0, %c1, %c3) : (!vm.ref<!hal.device>, i32, i32) -> !vm.ref<!hal.command_buffer>
// CHECK: %ref = vm.call @hal.command_buffer.create(%arg0, %c1, %c3, %zero) : (!vm.ref<!hal.device>, i32, i32, i32) -> !vm.ref<!hal.command_buffer>
%cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
return
}

// -----

// Verifies that a binding capacity supplied via bindings(...) is forwarded as
// the trailing i32 operand of the hal.command_buffer.create import call.
// CHECK-LABEL: @command_buffer_create_bindings
func.func @command_buffer_create_bindings(%arg0: !hal.device, %arg1: index) {
// CHECK: %ref = vm.call @hal.command_buffer.create(%arg0, %c1, %c3, %arg1) : (!vm.ref<!hal.device>, i32, i32, i32) -> !vm.ref<!hal.command_buffer>
%cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode("OneShot") categories("Transfer|Dispatch") bindings(%arg1) : !hal.command_buffer
return
}

// -----

// CHECK-LABEL: @command_buffer_finalize
func.func @command_buffer_finalize(%arg0: !hal.command_buffer) {
// CHECK: vm.call @hal.command_buffer.finalize(%arg0) : (!vm.ref<!hal.command_buffer>) -> ()
Expand Down Expand Up @@ -107,6 +116,41 @@ func.func @command_buffer_copy_buffer(

// -----

// Verifies lowering of both binding forms to the variadic import call:
//  - a direct !hal.buffer binding is passed as (ordinal, 0, buffer, ...)
//  - a binding table slot (index) is passed as (ordinal, slot, null, ...)
// CHECK-LABEL: @command_buffer_push_descriptor_set
// CHECK-SAME: %[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
// CHECK-SAME: %[[LAYOUT:.+]]: !vm.ref<!hal.pipeline_layout>,
// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>,
// CHECK-SAME: %[[SLOT:.+]]: i32
func.func @command_buffer_push_descriptor_set(
%cmd: !hal.command_buffer,
%layout: !hal.pipeline_layout,
%buffer: !hal.buffer,
%slot: index
) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c4096 = arith.constant 4096 : index
%c8000 = arith.constant 8000 : index
// CHECK: %[[C0:.+]] = vm.const.i32.zero
// CHECK: %[[C1:.+]] = vm.const.i32 1
// CHECK: %[[NULL:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
// CHECK: vm.call.variadic @hal.command_buffer.push_descriptor_set
// CHECK-SAME: (%[[CMD]], %[[LAYOUT]], %c1, [
// CHECK-SAME: (%[[C0]], %[[C0]], %[[BUFFER]], %c4096, %c8000),
// CHECK-SAME: (%[[C1]], %[[SLOT]], %[[NULL]], %c4, %c4096)
// CHECK-SAME: ]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.pipeline_layout>, i32, tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
hal.command_buffer.push_descriptor_set<%cmd : !hal.command_buffer>
layout(%layout : !hal.pipeline_layout)[%c1]
bindings([
%c0 = (%buffer : !hal.buffer)[%c4096, %c8000],
%c1 = (%slot : index)[%c4, %c4096]
])
return
}

// -----

// CHECK-LABEL: @command_buffer_dispatch
func.func @command_buffer_dispatch(
%arg0: !hal.command_buffer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ struct CmdExecuteOpPattern
rewriter
.create<IREE::HAL::CommandBufferCreateOp>(
loc, rewriter.getType<IREE::HAL::CommandBufferType>(), device,
modes, commandCategories)
modes, commandCategories, /*binding_capacity=*/Value{})
.getResult();
mapping->mapCommandBuffer(executeOp, commandBuffer);

Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ def HAL_BufferUsageBitfieldAttr :

def HAL_CommandBufferMode_None : I32BitEnumAttrCase<"None", 0x0000>;
def HAL_CommandBufferMode_OneShot : I32BitEnumAttrCase<"OneShot", 0x0001>;
def HAL_CommandBufferMode_Nested : I32BitEnumAttrCase<"Nested", 0x0002>;
def HAL_CommandBufferMode_AllowInlineExecution : I32BitEnumAttrCase<"AllowInlineExecution", 0x0010>;
def HAL_CommandBufferModeBitfieldAttr :
I32BitEnumAttr<"CommandBufferModeBitfield", "valid CommandBufferMode", [
HAL_CommandBufferMode_None,
HAL_CommandBufferMode_OneShot,
HAL_CommandBufferMode_Nested,
HAL_CommandBufferMode_AllowInlineExecution,
]> {
let cppNamespace = "mlir::iree_compiler::IREE::HAL";
Expand Down
8 changes: 6 additions & 2 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,8 @@ def HAL_CommandBufferCreateOp : HAL_Op<"command_buffer.create", [
let arguments = (ins
HAL_Device:$device,
HAL_CommandBufferModeBitfieldAttr:$modes,
HAL_CommandCategoryBitfieldAttr:$command_categories
HAL_CommandCategoryBitfieldAttr:$command_categories,
Optional<Index>:$binding_capacity
);
let results = (outs
HAL_CommandBuffer:$result
Expand All @@ -673,6 +674,7 @@ def HAL_CommandBufferCreateOp : HAL_Op<"command_buffer.create", [
`device` `(` $device `:` type($device) `)`
`mode` `(` $modes `)`
`categories` `(` $command_categories `)`
(`bindings` `(` $binding_capacity^ `)`)?
`:` type($result)
attr-dict-with-keyword
}];
Expand Down Expand Up @@ -870,14 +872,16 @@ def HAL_CommandBufferPushDescriptorSetOp :
let summary = [{command buffer descriptor set push binding operation}];
let description = [{
Pushes an inline-defined descriptor set to the command buffer.
The provided buffers may either be HAL buffers or indirect references into
the command buffer binding table.
}];

let arguments = (ins
HAL_CommandBuffer:$command_buffer,
HAL_PipelineLayout:$pipeline_layout,
Index:$set,
Variadic<Index>:$binding_ordinals,
Variadic<HAL_BufferType>:$binding_buffers,
Variadic<AnyTypeOf<[Index, HAL_BufferType]>>:$binding_buffers,
Variadic<HAL_DeviceSize>:$binding_offsets,
Variadic<HAL_DeviceSize>:$binding_lengths
);
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ struct SemaphoreType : public Type::TypeBase<SemaphoreType, Type, TypeStorage> {
//===----------------------------------------------------------------------===//

// A tuple containing runtime values for a descriptor set binding.
// The buffer specified may be either a !hal.buffer or an index of a binding
// table slot to source the buffer from.
struct DescriptorSetBindingValue {
Value ordinal;
Value buffer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,38 @@ func.func @command_buffer_copy_buffer(

// -----

// Checks the printed form of hal.command_buffer.push_descriptor_set when its
// bindings mix a direct !hal.buffer with an index-typed binding table slot.
// Each expected output line is placed directly above the matching input line.
// CHECK-LABEL: @command_buffer_push_descriptor_set
// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer,
// CHECK-SAME: %[[LAYOUT:.+]]: !hal.pipeline_layout,
// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer,
// CHECK-SAME: %[[SLOT:.+]]: index
func.func @command_buffer_push_descriptor_set(
%cmd: !hal.command_buffer,
%layout: !hal.pipeline_layout,
%buffer: !hal.buffer,
%slot: index
) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c4096 = arith.constant 4096 : index
%c8000 = arith.constant 8000 : index
// CHECK: hal.command_buffer.push_descriptor_set<%[[CMD]] : !hal.command_buffer>
hal.command_buffer.push_descriptor_set<%cmd : !hal.command_buffer>
// CHECK-SAME: layout(%[[LAYOUT]] : !hal.pipeline_layout)[%c1]
layout(%layout : !hal.pipeline_layout)[%c1]
// CHECK-SAME: bindings([
bindings([
// CHECK-NEXT: %c0 = (%[[BUFFER]] : !hal.buffer)[%c4096, %c8000]
%c0 = (%buffer : !hal.buffer)[%c4096, %c8000],
// CHECK-NEXT: %c1 = (%[[SLOT]] : index)[%c4, %c4096]
%c1 = (%slot : index)[%c4, %c4096]
])
return
}

// -----

hal.executable @ex {
hal.executable.variant @backend, target = <"backend", "format"> {
hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ static void appendDispatchBenchmark(IREE::HAL::ExecutableOp executableOp,
funcBuilder
.create<IREE::HAL::CommandBufferCreateOp>(
loc, funcBuilder.getType<IREE::HAL::CommandBufferType>(), device,
commandBufferModes, IREE::HAL::CommandCategoryBitfield::Dispatch)
commandBufferModes, IREE::HAL::CommandCategoryBitfield::Dispatch,
/*binding_capacity=*/Value{})
.getResult();

// Get the layout required to set up the dispatches.
Expand Down
16 changes: 12 additions & 4 deletions compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ vm.import @buffer_view.trace(
vm.import @command_buffer.create(
%device : !vm.ref<!hal.device>,
%modes : i32,
%command_categories : i32
%command_categories : i32,
%binding_capacity : i32
) -> !vm.ref<!hal.command_buffer>

// Finalizes recording into the command buffer and prepares it for submission.
Expand Down Expand Up @@ -231,10 +232,9 @@ vm.import @command_buffer.push_descriptor_set(
%command_buffer : !vm.ref<!hal.command_buffer>,
%pipeline_layout : !vm.ref<!hal.pipeline_layout>,
%set : i32,
// <binding, buffer, offset, length>
%bindings : tuple<i32, !vm.ref<!hal.buffer>, i64, i64>...
// <binding, slot, buffer, offset, length>
%bindings : tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64>...
)

// Dispatches an execution request.
vm.import @command_buffer.dispatch(
%command_buffer : !vm.ref<!hal.command_buffer>,
Expand All @@ -255,6 +255,14 @@ vm.import @command_buffer.dispatch.indirect(
%workgroups_offset : i64
)

// Executes a secondary command buffer with the given binding table.
// |command_buffer| is the command buffer the execution is issued on and
// |commands| is the secondary command buffer to execute.
// |bindings| is a variadic list with one <buffer, offset, length> tuple per
// binding table entry.
// NOTE(review): presumably the tuple position is the slot index referenced by
// push_descriptor_set bindings during recording of |commands| - confirm.
vm.import @command_buffer.execute.commands(
%command_buffer : !vm.ref<!hal.command_buffer>,
%commands : !vm.ref<!hal.command_buffer>,
// <buffer, offset, length>
%bindings : tuple<!vm.ref<!hal.buffer>, i64, i64>...
)

//===----------------------------------------------------------------------===//
// iree_hal_descriptor_set_layout_t
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 24 additions & 5 deletions experimental/rocm/direct_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,19 @@ iree_status_t iree_hal_rocm_direct_command_buffer_create(
iree_hal_device_t* device, iree_hal_rocm_context_wrapper_t* context,
iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
iree_hal_queue_affinity_t queue_affinity,
iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
iree_arena_block_pool_t* block_pool,
iree_hal_command_buffer_t** out_command_buffer) {
IREE_ASSERT_ARGUMENT(context);
IREE_ASSERT_ARGUMENT(block_pool);
IREE_ASSERT_ARGUMENT(out_command_buffer);
*out_command_buffer = NULL;

if (binding_capacity > 0) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"indirect command buffers not yet implemented");
}

IREE_TRACE_ZONE_BEGIN(z0);

iree_hal_rocm_direct_command_buffer_t* command_buffer = NULL;
Expand All @@ -66,7 +73,7 @@ iree_status_t iree_hal_rocm_direct_command_buffer_create(
context->host_allocator, total_size, (void**)&command_buffer);
if (iree_status_is_ok(status)) {
iree_hal_command_buffer_initialize(
device, mode, command_categories, queue_affinity,
device, mode, command_categories, queue_affinity, binding_capacity,
&iree_hal_rocm_direct_command_buffer_vtable, &command_buffer->base);
command_buffer->context = context;
command_buffer->block_pool = block_pool;
Expand Down Expand Up @@ -310,9 +317,11 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_push_descriptor_set(
for (iree_host_size_t i = 0; i < binding_count; i++) {
iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index];
hipDeviceptr_t device_ptr =
iree_hal_rocm_buffer_device_pointer(
iree_hal_buffer_allocated_buffer(binding.buffer)) +
iree_hal_buffer_byte_offset(binding.buffer) + binding.offset;
binding.buffer
? (iree_hal_rocm_buffer_device_pointer(
iree_hal_buffer_allocated_buffer(binding.buffer)) +
iree_hal_buffer_byte_offset(binding.buffer) + binding.offset)
: NULL;
*((hipDeviceptr_t*)command_buffer->current_descriptor[i + base_binding]) =
device_ptr;
}
Expand Down Expand Up @@ -363,6 +372,14 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch_indirect(
"need rocm implementation");
}

// Stub for executing |base_commands| from |base_command_buffer|.
// No argument is inspected; every call fails with IREE_STATUS_UNIMPLEMENTED.
static iree_status_t iree_hal_rocm_direct_command_buffer_execute_commands(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_command_buffer_t* base_commands,
    iree_hal_buffer_binding_table_t binding_table) {
  iree_status_t unimplemented_status = iree_make_status(
      IREE_STATUS_UNIMPLEMENTED, "indirect command buffers not yet implemented");
  return unimplemented_status;
}

static const iree_hal_command_buffer_vtable_t
iree_hal_rocm_direct_command_buffer_vtable = {
.destroy = iree_hal_rocm_direct_command_buffer_destroy,
Expand All @@ -387,4 +404,6 @@ static const iree_hal_command_buffer_vtable_t
.dispatch = iree_hal_rocm_direct_command_buffer_dispatch,
.dispatch_indirect =
iree_hal_rocm_direct_command_buffer_dispatch_indirect,
.execute_commands =
iree_hal_rocm_direct_command_buffer_execute_commands,
};
Loading

0 comments on commit aee7b7c

Please sign in to comment.