Skip to content

Commit

Permalink
Merge pull request #9754 from iree-org/benvanik-timepoint-to-hal
Browse files Browse the repository at this point in the history
Adding compiler/runtime support for lowering the asynchronous stream dialect ops into HAL ops, materializing a timeline (today just one, but multiple in the future), and passing through to the runtime HAL module. This allows for the removal of the existing placeholder submit_and_wait op and enables queue-ordered allocations to be implemented in the HAL.

This is likely not the final design, but it unblocks work on coroutines, queue-ordered allocations, webgpu, and plumbing fences through the user-facing API/native ABI. Future refinements may create overrides that use semaphores instead of fences to avoid fence heap allocations when not required, but for most single-function classic ML models, once we plumb fences through the ABI, no internal fences are required. The current timeline materialization also strictly orders all invocations, whereas we should be able to elide that ordering when there's no internal program state to protect.

Because the various HAL backends all need work (CUDA/ROCM in particular need massive work), nearly everything is synchronized exactly as it was before; however, that synchronization now happens in the IR we emit, so we can selectively start supporting async per target.

Progress on #1285 (just need to put fences on the ABI!).
Progress on #8093 (added yieldable fence waits).
Progress on #9572 (added compiler/runtime glue for queue-ordered allocs).
  • Loading branch information
benvanik authored Aug 2, 2022
2 parents d46c881 + b1688f4 commit 76bf2f3
Show file tree
Hide file tree
Showing 78 changed files with 3,345 additions and 444 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ void populateHALDeviceToVMPatterns(MLIRContext *context,
patterns.insert<DeviceQueryIntCastOpConversion>(context, typeConverter);
patterns.insert<DeviceQueryI64OpConversion>(
context, importSymbols, typeConverter, "hal.device.query.i64");

patterns.insert<VMImportOpConversion<IREE::HAL::DeviceQueueAllocaOp>>(
context, importSymbols, typeConverter, "hal.device.queue.alloca");
patterns.insert<VMImportOpConversion<IREE::HAL::DeviceQueueDeallocaOp>>(
context, importSymbols, typeConverter, "hal.device.queue.dealloca");
patterns.insert<VMImportOpConversion<IREE::HAL::DeviceQueueExecuteOp>>(
context, importSymbols, typeConverter, "hal.device.queue.execute");
patterns.insert<VMImportOpConversion<IREE::HAL::DeviceQueueFlushOp>>(
context, importSymbols, typeConverter, "hal.device.queue.flush");
}

} // namespace iree_compiler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ void populateHALExperimentalToVMPatterns(MLIRContext *context,
RewritePatternSet &patterns) {
patterns.insert<VMImportOpConversion<IREE::HAL::ExSharedDeviceOp>>(
context, importSymbols, typeConverter, "hal.ex.shared_device");
patterns.insert<VMImportOpConversion<IREE::HAL::ExSubmitAndWaitOp>>(
context, importSymbols, typeConverter, "hal.ex.submit_and_wait");
}

} // namespace iree_compiler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,82 @@ func.func @device_query_i1_default(%device: !hal.device) -> i1 {
// CHECK: return %[[OUT]]
return %value : i1
}

// -----

// Verifies that hal.device.queue.alloca is converted into a vm.call to
// @hal.device.queue.alloca. The device, affinity, and wait/signal fence
// arguments are expected to be forwarded unchanged, while the index-typed
// size (i32 after type conversion) is extended to i64 via vm.ext.i32.i64.s
// before being passed to the call.
// CHECK-LABEL: @device_queue_alloca
func.func @device_queue_alloca(
// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[AFFINITY:.+]]: i64,
%device: !hal.device, %affinity: i64,
// CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref<!hal.fence>, %[[SIGNAL_FENCE:.+]]: !vm.ref<!hal.fence>,
%wait_fence: !hal.fence, %signal_fence: !hal.fence,
// CHECK-SAME: %[[SIZE_I32:.+]]: i32)
%size: index) -> !hal.buffer {
// Pool identifier handed to pool(...) below.
%c100_i64 = arith.constant 100 : i64
// CHECK: %[[SIZE_I64:.+]] = vm.ext.i32.i64.s %[[SIZE_I32]]
// CHECK: = vm.call @hal.device.queue.alloca(
// CHECK-SAME: %[[DEVICE]], %[[AFFINITY]],
// CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]],
// NOTE(review): %c100 appears to be the pool constant; %c48 and %c3 are
// presumably the integer encodings of type(DeviceLocal) and usage(Transfer) -
// confirm against the HAL bitfield definitions.
// CHECK-SAME: %c100, %c48, %c3, %[[SIZE_I64]])
%buffer = hal.device.queue.alloca<%device : !hal.device>
affinity(%affinity)
wait(%wait_fence) signal(%signal_fence)
pool(%c100_i64)
type(DeviceLocal) usage(Transfer)
: !hal.buffer{%size}
return %buffer : !hal.buffer
}

// -----

// Verifies that hal.device.queue.dealloca is converted into a vm.call to
// @hal.device.queue.dealloca with the device, affinity, wait/signal fences,
// and the buffer being released passed through in that order.
// CHECK-LABEL: @device_queue_dealloca
func.func @device_queue_dealloca(
// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[AFFINITY:.+]]: i64,
%device: !hal.device, %affinity: i64,
// CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref<!hal.fence>, %[[SIGNAL_FENCE:.+]]: !vm.ref<!hal.fence>,
%wait_fence: !hal.fence, %signal_fence: !hal.fence,
// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>)
%buffer: !hal.buffer) {
// CHECK: vm.call @hal.device.queue.dealloca(
// CHECK-SAME: %[[DEVICE]], %[[AFFINITY]],
// CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]],
// CHECK-SAME: %[[BUFFER]])
hal.device.queue.dealloca<%device : !hal.device>
affinity(%affinity)
wait(%wait_fence) signal(%signal_fence)
buffer(%buffer : !hal.buffer)
return
}

// -----

// Verifies that hal.device.queue.execute is converted into a
// vm.call.variadic to @hal.device.queue.execute. The device, affinity, and
// wait/signal fences are expected as leading operands, and the two command
// buffers are expected to be grouped into a single bracketed variadic list.
// CHECK-LABEL: @device_queue_execute
func.func @device_queue_execute(
// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[AFFINITY:.+]]: i64,
%device: !hal.device, %affinity: i64,
// CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref<!hal.fence>, %[[SIGNAL_FENCE:.+]]: !vm.ref<!hal.fence>,
%wait_fence: !hal.fence, %signal_fence: !hal.fence,
// CHECK-SAME: %[[CMD0:.+]]: !vm.ref<!hal.command_buffer>, %[[CMD1:.+]]: !vm.ref<!hal.command_buffer>)
%cmd0: !hal.command_buffer, %cmd1: !hal.command_buffer) {
// CHECK: vm.call.variadic @hal.device.queue.execute(
// CHECK-SAME: %[[DEVICE]], %[[AFFINITY]],
// CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]],
// CHECK-SAME: [%[[CMD0]], %[[CMD1]]])
hal.device.queue.execute<%device : !hal.device>
affinity(%affinity)
wait(%wait_fence) signal(%signal_fence)
commands([%cmd0, %cmd1])
return
}

// -----

// Verifies that hal.device.queue.flush is converted into a vm.call to
// @hal.device.queue.flush taking only the device and the queue affinity.
// CHECK-LABEL: @device_queue_flush
func.func @device_queue_flush(
// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[AFFINITY:.+]]: i64)
%device: !hal.device, %affinity: i64) {
// CHECK: vm.call @hal.device.queue.flush(%[[DEVICE]], %[[AFFINITY]])
hal.device.queue.flush<%device : !hal.device>
affinity(%affinity)
return
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ static Value lookupAllocatorFor(Operation *op, OpBuilder &builder) {
return allocatorOp.getResult();
}

// Yields a fence value to wait on: |timepointFence| itself when it is
// non-null, otherwise a newly created util.null op of !hal.fence type at
// |loc| (inserted via |builder|).
static Value getOrCreateWaitFence(Location loc, Value timepointFence,
                                  OpBuilder &builder) {
  if (!timepointFence) {
    // No fence was provided; materialize a null fence instead.
    auto fenceType = builder.getType<IREE::HAL::FenceType>();
    return builder.create<IREE::Util::NullOp>(loc, fenceType);
  }
  return timepointFence;
}

// Scans all of the stream.cmd.* ops in the region to derive a command category.
static IREE::HAL::CommandCategoryBitfield deriveCommandCategories(
Region &region) {
Expand Down Expand Up @@ -229,7 +237,8 @@ struct ResourceAllocaOpPattern
LogicalResult matchAndRewrite(
IREE::Stream::ResourceAllocaOp allocaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto allocator = lookupAllocatorFor(allocaOp, rewriter);
auto loc = allocaOp.getLoc();
auto device = lookupDeviceFor(allocaOp, rewriter);
auto bufferType = rewriter.getType<IREE::HAL::BufferType>();

// Transient allocations are device-local. Copies are required to get their
Expand All @@ -243,16 +252,20 @@ struct ResourceAllocaOpPattern
auto bufferUsage = IREE::HAL::BufferUsageBitfield::Transfer |
IREE::HAL::BufferUsageBitfield::DispatchStorage;

auto allocateOp = rewriter.create<IREE::HAL::AllocatorAllocateOp>(
allocaOp.getLoc(), bufferType, allocator, memoryTypes, bufferUsage,
adaptor.getStorageSize());
// Gather wait/signal fence, which are optional.
Value waitFence =
getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter);
Value signalFence = rewriter.create<IREE::HAL::TimelineAdvanceOp>(
loc, rewriter.getType<IREE::HAL::FenceType>());

// TODO(benvanik): stream ordered allocations.
auto resolvedTimepoint =
rewriter.create<arith::ConstantIntOp>(allocaOp.getLoc(), 0, 64)
.getResult();
// Queue allocation.
auto queueAffinity = rewriter.create<arith::ConstantIntOp>(loc, -1, 64);
auto pool = rewriter.create<arith::ConstantIntOp>(loc, 0, 64);
auto allocateOp = rewriter.create<IREE::HAL::DeviceQueueAllocaOp>(
loc, bufferType, device, queueAffinity, waitFence, signalFence, pool,
memoryTypes, bufferUsage, adaptor.getStorageSize());

rewriter.replaceOp(allocaOp, {allocateOp.getResult(), resolvedTimepoint});
rewriter.replaceOp(allocaOp, {allocateOp.getResult(), signalFence});
return success();
}
};
Expand All @@ -263,11 +276,22 @@ struct ResourceDeallocaOpPattern
LogicalResult matchAndRewrite(
IREE::Stream::ResourceDeallocaOp deallocaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO(benvanik): stream ordered allocations.
auto resolvedTimepoint =
rewriter.create<arith::ConstantIntOp>(deallocaOp.getLoc(), 0, 64)
.getResult();
rewriter.replaceOp(deallocaOp, {resolvedTimepoint});
auto loc = deallocaOp.getLoc();
auto device = lookupDeviceFor(deallocaOp, rewriter);

// Gather wait/signal fence, which are optional.
Value waitFence =
getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter);
Value signalFence = rewriter.create<IREE::HAL::TimelineAdvanceOp>(
loc, rewriter.getType<IREE::HAL::FenceType>());

// Queue deallocation.
auto queueAffinity = rewriter.create<arith::ConstantIntOp>(loc, -1, 64);
rewriter.create<IREE::HAL::DeviceQueueDeallocaOp>(
loc, device, queueAffinity, waitFence, signalFence,
adaptor.getOperand());

rewriter.replaceOp(deallocaOp, {signalFence});
return success();
}
};
Expand Down Expand Up @@ -792,18 +816,22 @@ struct CmdExecuteOpPattern
auto loc = executeOp.getLoc();
auto device = lookupDeviceFor(executeOp, rewriter);

// TODO(benvanik): disable inline execution once we have semaphores.
// We can look ahead to see if there's an await immediately to trigger the
// inline execution.
auto modes = IREE::HAL::CommandBufferModeBitfield::OneShot |
IREE::HAL::CommandBufferModeBitfield::AllowInlineExecution;
// If there are any wait timepoints it means there's prior queued execution
// that we may need to wait behind and we can't execute inline. HAL
// implementations may be able to flush eagerly if they are able to tell
// that all conditions are met during recording but we leave that to them.
auto modes = IREE::HAL::CommandBufferModeBitfield::OneShot;
if (!executeOp.getAwaitTimepoint()) {
modes =
modes | IREE::HAL::CommandBufferModeBitfield::AllowInlineExecution;
}

// Derive the command buffer type based on the kind of operations present.
// This can help the submission get routed to appropriate hardware queues
// (like dedicated DMA controllers).
auto commandCategories = deriveCommandCategories(executeOp.getBody());

// Create a new command buffer for recording. If we were
// Create a new command buffer for recording.
auto commandBuffer =
rewriter
.create<IREE::HAL::CommandBufferCreateOp>(
Expand All @@ -824,14 +852,19 @@ struct CmdExecuteOpPattern
rewriter.mergeBlockBefore(&executeOp.getBody().front(), endOp,
adaptor.getResourceOperands());

// TODO(benvanik): we should queue a submit here with the semaphore instead.
rewriter.create<IREE::HAL::ExSubmitAndWaitOp>(loc, device, commandBuffer);
// Gather wait/signal fence, which are optional.
Value waitFence =
getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter);
Value signalFence = rewriter.create<IREE::HAL::TimelineAdvanceOp>(
loc, rewriter.getType<IREE::HAL::FenceType>());

// TODO(benvanik): propagate semaphore information.
auto resolvedTimepoint =
rewriter.create<arith::ConstantIntOp>(loc, 0, 64).getResult();
// Queue execution.
auto queueAffinity = rewriter.create<arith::ConstantIntOp>(loc, -1, 64);
rewriter.create<IREE::HAL::DeviceQueueExecuteOp>(loc, device, queueAffinity,
waitFence, signalFence,
ValueRange{commandBuffer});

rewriter.replaceOp(executeOp, resolvedTimepoint);
rewriter.replaceOp(executeOp, signalFence);
return success();
}
};
Expand Down Expand Up @@ -877,8 +910,8 @@ struct TimepointImmediateOpPattern
LogicalResult matchAndRewrite(
IREE::Stream::TimepointImmediateOp immediateOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO(benvanik): model timepoints as semaphores.
rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(immediateOp, 0, 64);
rewriter.replaceOpWithNewOp<IREE::Util::NullOp>(
immediateOp, rewriter.getType<IREE::HAL::FenceType>());
return success();
}
};
Expand All @@ -889,25 +922,24 @@ struct TimepointImportOpPattern
LogicalResult matchAndRewrite(
IREE::Stream::TimepointImportOp importOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only handle imports from HAL semaphores.
auto operands = adaptor.operands();
if (operands.size() != 2 ||
!operands[0].getType().isa<IREE::HAL::SemaphoreType>() ||
!operands[1].getType().isIntOrIndex()) {
// Only handle imports from HAL semaphores _or_ fences.
auto operands = adaptor.getOperands();
if (operands.size() == 1 &&
operands[0].getType().isa<IREE::HAL::FenceType>()) {
rewriter.replaceOp(importOp, operands[0]);
return success();
} else if (operands.size() == 2 &&
operands[0].getType().isa<IREE::HAL::SemaphoreType>() &&
operands[1].getType().isIntOrIndex()) {
rewriter.replaceOpWithNewOp<IREE::HAL::FenceCreateOp>(
importOp, rewriter.getType<IREE::HAL::FenceType>(),
ValueRange{operands[0]}, ValueRange{operands[1]});
return success();
} else {
return rewriter.notifyMatchFailure(importOp,
"only imports from HAL semaphore + "
"sequence value tuples are supported");
}

// TODO(benvanik): model timepoints as semaphores.
// For now we just block on the semaphore.
auto awaitOp = rewriter.create<IREE::HAL::SemaphoreAwaitOp>(
importOp.getLoc(), rewriter.getI32Type(), operands[0], operands[1]);
rewriter.create<IREE::Util::StatusCheckOkOp>(
importOp.getLoc(), awaitOp.getStatus(),
"failed to wait on imported semaphore");
rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(importOp, 0, 64);
return success();
}
};

Expand All @@ -917,24 +949,13 @@ struct TimepointExportOpPattern
LogicalResult matchAndRewrite(
IREE::Stream::TimepointExportOp exportOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only handle exports into HAL semaphores.
if (exportOp.getNumResults() != 2 ||
!exportOp.getResult(0).getType().isa<IREE::HAL::SemaphoreType>() ||
!exportOp.getResult(1).getType().isIntOrIndex()) {
return rewriter.notifyMatchFailure(exportOp,
"only exports to HAL semaphore + "
"sequence value tuples are supported");
// Only handle exports into HAL fences.
if (exportOp.getNumResults() != 1 ||
!exportOp.getResult(0).getType().isa<IREE::HAL::FenceType>()) {
return rewriter.notifyMatchFailure(
exportOp, "only exports to HAL fences are supported");
}

auto loc = exportOp.getLoc();
auto device = lookupDeviceFor(exportOp, rewriter);

// TODO(benvanik): model timepoints as semaphores.
// For now we just create a signaled semaphore.
auto exportValue = rewriter.create<arith::ConstantIntOp>(loc, 0, 64);
auto exportSemaphore = rewriter.create<IREE::HAL::SemaphoreCreateOp>(
loc, rewriter.getType<IREE::HAL::SemaphoreType>(), device, exportValue);
rewriter.replaceOp(exportOp, {exportSemaphore, exportValue});
rewriter.replaceOp(exportOp, adaptor.getAwaitTimepoint());
return success();
}
};
Expand All @@ -945,11 +966,9 @@ struct TimepointJoinOpPattern
LogicalResult matchAndRewrite(
IREE::Stream::TimepointJoinOp joinOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO(benvanik): model timepoints as semaphores.
// This should be a max() of the operand timepoints. Could be done with
// affine expressions, but since everything is always 0 we just max(0,0)=0
// here :)
rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(joinOp, 0, 64);
rewriter.replaceOpWithNewOp<IREE::HAL::FenceJoinOp>(
joinOp, rewriter.getType<IREE::HAL::FenceType>(),
adaptor.getAwaitTimepoints());
return success();
}
};
Expand All @@ -960,7 +979,16 @@ struct TimepointAwaitOpPattern
LogicalResult matchAndRewrite(
IREE::Stream::TimepointAwaitOp awaitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO(benvanik): model timepoints as semaphores.
auto loc = awaitOp.getLoc();

// Perform the blocking wait.
Value timeoutMillis = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
auto fenceOp = rewriter.create<IREE::HAL::FenceAwaitOp>(
loc, rewriter.getI32Type(), timeoutMillis, adaptor.getAwaitTimepoint());
rewriter.create<IREE::Util::StatusCheckOkOp>(loc, fenceOp.getStatus(),
"failed to wait on timepoint");

// Pass along operands.
rewriter.replaceOp(awaitOp, adaptor.getResourceOperands());
return success();
}
Expand Down Expand Up @@ -988,8 +1016,7 @@ struct GlobalTimepointConversionPattern
auto initialValue = op.getInitialValue();
if (!initialValue.hasValue()) return failure();
if (!initialValue->isa<IREE::Stream::TimepointAttr>()) return failure();
rewriter.updateRootInPlace(
op, [&]() { op.setInitialValueAttr(rewriter.getI64IntegerAttr(0)); });
rewriter.updateRootInPlace(op, [&]() { op.removeInitial_valueAttr(); });
return success();
}
};
Expand All @@ -1011,11 +1038,7 @@ void populateStreamToHALPatterns(MLIRContext *context,

typeConverter.addConversion(
[=](IREE::Stream::TimepointType type, SmallVectorImpl<Type> &results) {
// TODO(benvanik): model timepoints as semaphores.
// This may become a !hal.semaphore + index, or some !hal.timepoint that
// we then do more analysis on once we know what devices are in use
// where.
results.push_back(IntegerType::get(context, 64));
results.push_back(IREE::HAL::FenceType::get(context));
return success();
});

Expand Down
Loading

0 comments on commit 76bf2f3

Please sign in to comment.