diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp index e0900defc840..afc1d4e75068 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp @@ -120,6 +120,15 @@ void populateHALDeviceToVMPatterns(MLIRContext *context, patterns.insert(context, typeConverter); patterns.insert( context, importSymbols, typeConverter, "hal.device.query.i64"); + + patterns.insert>( + context, importSymbols, typeConverter, "hal.device.queue.alloca"); + patterns.insert>( + context, importSymbols, typeConverter, "hal.device.queue.dealloca"); + patterns.insert>( + context, importSymbols, typeConverter, "hal.device.queue.execute"); + patterns.insert>( + context, importSymbols, typeConverter, "hal.device.queue.flush"); } } // namespace iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp index 08609db18f01..93b3f1481de3 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp @@ -17,8 +17,6 @@ void populateHALExperimentalToVMPatterns(MLIRContext *context, RewritePatternSet &patterns) { patterns.insert>( context, importSymbols, typeConverter, "hal.ex.shared_device"); - patterns.insert>( - context, importSymbols, typeConverter, "hal.ex.submit_and_wait"); } } // namespace iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir index 56056db7e6c5..643e7bb40a3e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir @@ -94,3 +94,82 @@ func.func @device_query_i1_default(%device: !hal.device) -> i1 { // CHECK: return %[[OUT]] return %value : i1 } + +// ----- + +// CHECK-LABEL: @device_queue_alloca +func.func @device_queue_alloca( + // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref, %[[SIGNAL_FENCE:.+]]: !vm.ref, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[SIZE_I32:.+]]: i32) + %size: index) -> !hal.buffer { + %c100_i64 = arith.constant 100 : i64 + // CHECK: %[[SIZE_I64:.+]] = vm.ext.i32.i64.s %[[SIZE_I32]] + // CHECK: = vm.call @hal.device.queue.alloca( + // CHECK-SAME: %[[DEVICE]], %[[AFFINITY]], + // CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]], + // CHECK-SAME: %c100, %c48, %c3, %[[SIZE_I64]]) + %buffer = hal.device.queue.alloca<%device : !hal.device> + affinity(%affinity) + wait(%wait_fence) signal(%signal_fence) + pool(%c100_i64) + type(DeviceLocal) usage(Transfer) + : !hal.buffer{%size} + return %buffer : !hal.buffer +} + +// ----- + +// CHECK-LABEL: @device_queue_dealloca +func.func @device_queue_dealloca( + // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref, %[[SIGNAL_FENCE:.+]]: !vm.ref, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[BUFFER:.+]]: !vm.ref) + %buffer: !hal.buffer) { + // CHECK: vm.call 
@hal.device.queue.dealloca( + // CHECK-SAME: %[[DEVICE]], %[[AFFINITY]], + // CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]], + // CHECK-SAME: %[[BUFFER]]) + hal.device.queue.dealloca<%device : !hal.device> + affinity(%affinity) + wait(%wait_fence) signal(%signal_fence) + buffer(%buffer : !hal.buffer) + return +} + +// ----- + +// CHECK-LABEL: @device_queue_execute +func.func @device_queue_execute( + // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref, %[[SIGNAL_FENCE:.+]]: !vm.ref, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[CMD0:.+]]: !vm.ref, %[[CMD1:.+]]: !vm.ref) + %cmd0: !hal.command_buffer, %cmd1: !hal.command_buffer) { + // CHECK: vm.call.variadic @hal.device.queue.execute( + // CHECK-SAME: %[[DEVICE]], %[[AFFINITY]], + // CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]], + // CHECK-SAME: [%[[CMD0]], %[[CMD1]]]) + hal.device.queue.execute<%device : !hal.device> + affinity(%affinity) + wait(%wait_fence) signal(%signal_fence) + commands([%cmd0, %cmd1]) + return +} + +// ----- + +// CHECK-LABEL: @device_queue_flush +func.func @device_queue_flush( + // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64) + %device: !hal.device, %affinity: i64) { + // CHECK: vm.call @hal.device.queue.flush(%[[DEVICE]], %[[AFFINITY]]) + hal.device.queue.flush<%device : !hal.device> + affinity(%affinity) + return +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp index 03e56fcdd19a..0f2b2b55c91d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp @@ -36,6 +36,14 @@ static Value lookupAllocatorFor(Operation *op, OpBuilder &builder) { return allocatorOp.getResult(); } +// Returns the |timepointFence| or a util.null. +static Value getOrCreateWaitFence(Location loc, Value timepointFence, + OpBuilder &builder) { + if (timepointFence) return timepointFence; + return builder.create( + loc, builder.getType()); +} + // Scans all of the stream.cmd.* ops in the region to derive a command category. static IREE::HAL::CommandCategoryBitfield deriveCommandCategories( Region ®ion) { @@ -229,7 +237,8 @@ struct ResourceAllocaOpPattern LogicalResult matchAndRewrite( IREE::Stream::ResourceAllocaOp allocaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto allocator = lookupAllocatorFor(allocaOp, rewriter); + auto loc = allocaOp.getLoc(); + auto device = lookupDeviceFor(allocaOp, rewriter); auto bufferType = rewriter.getType(); // Transient allocations are device-local. Copies are required to get their @@ -243,16 +252,20 @@ struct ResourceAllocaOpPattern auto bufferUsage = IREE::HAL::BufferUsageBitfield::Transfer | IREE::HAL::BufferUsageBitfield::DispatchStorage; - auto allocateOp = rewriter.create( - allocaOp.getLoc(), bufferType, allocator, memoryTypes, bufferUsage, - adaptor.getStorageSize()); + // Gather wait/signal fence, which are optional. + Value waitFence = + getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter); + Value signalFence = rewriter.create( + loc, rewriter.getType()); - // TODO(benvanik): stream ordered allocations. - auto resolvedTimepoint = - rewriter.create(allocaOp.getLoc(), 0, 64) - .getResult(); + // Queue allocation. 
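+ // Queue affinity -1 (all bits set) lets the runtime place the allocation on any queue; the pool id is currently hardwired to 0.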
+ auto queueAffinity = rewriter.create(loc, -1, 64); + auto pool = rewriter.create(loc, 0, 64); + auto allocateOp = rewriter.create( + loc, bufferType, device, queueAffinity, waitFence, signalFence, pool, + memoryTypes, bufferUsage, adaptor.getStorageSize()); - rewriter.replaceOp(allocaOp, {allocateOp.getResult(), resolvedTimepoint}); + rewriter.replaceOp(allocaOp, {allocateOp.getResult(), signalFence}); return success(); } }; @@ -263,11 +276,22 @@ struct ResourceDeallocaOpPattern LogicalResult matchAndRewrite( IREE::Stream::ResourceDeallocaOp deallocaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // TODO(benvanik): stream ordered allocations. - auto resolvedTimepoint = - rewriter.create(deallocaOp.getLoc(), 0, 64) - .getResult(); - rewriter.replaceOp(deallocaOp, {resolvedTimepoint}); + auto loc = deallocaOp.getLoc(); + auto device = lookupDeviceFor(deallocaOp, rewriter); + + // Gather wait/signal fence, which are optional. + Value waitFence = + getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter); + Value signalFence = rewriter.create( + loc, rewriter.getType()); + + // Queue allocation. + auto queueAffinity = rewriter.create(loc, -1, 64); + rewriter.create( + loc, device, queueAffinity, waitFence, signalFence, + adaptor.getOperand()); + + rewriter.replaceOp(deallocaOp, {signalFence}); return success(); } }; @@ -792,18 +816,22 @@ struct CmdExecuteOpPattern auto loc = executeOp.getLoc(); auto device = lookupDeviceFor(executeOp, rewriter); - // TODO(benvanik): disable inline execution once we have semaphores. - // We can look ahead to see if there's an await immediately to trigger the - // inline execution. - auto modes = IREE::HAL::CommandBufferModeBitfield::OneShot | - IREE::HAL::CommandBufferModeBitfield::AllowInlineExecution; + // If there are any wait timepoints it means there's prior queued execution + // that we may need to wait behind and we can't execute inline. HAL + // implementations may be able to flush eagerly if they are able to tell + // that all conditions are met during recording but we leave that to them. + auto modes = IREE::HAL::CommandBufferModeBitfield::OneShot; + if (!executeOp.getAwaitTimepoint()) { + modes = + modes | IREE::HAL::CommandBufferModeBitfield::AllowInlineExecution; + } // Derive the command buffer type based on the kind of operations present. // This can help the submission get routed to appropriate hardware queues // (like dedicated DMA controllers). auto commandCategories = deriveCommandCategories(executeOp.getBody()); - // Create a new command buffer for recording. If we were + // Create a new command buffer for recording. auto commandBuffer = rewriter .create( @@ -824,14 +852,19 @@ struct CmdExecuteOpPattern rewriter.mergeBlockBefore(&executeOp.getBody().front(), endOp, adaptor.getResourceOperands()); - // TODO(benvanik): we should queue a submit here with the semaphore instead. - rewriter.create(loc, device, commandBuffer); + // Gather wait/signal fence, which are optional. + Value waitFence = + getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter); + Value signalFence = rewriter.create( + loc, rewriter.getType()); - // TODO(benvanik): propagate semaphore information. - auto resolvedTimepoint = - rewriter.create(loc, 0, 64).getResult(); + // Queue execution. 
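+ // As with allocation, affinity -1 leaves queue selection to the runtime; the signal fence produced by hal.timeline.advance stands in for the op's resolved timepoint result.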
+ auto queueAffinity = rewriter.create(loc, -1, 64); + rewriter.create(loc, device, queueAffinity, + waitFence, signalFence, + ValueRange{commandBuffer}); - rewriter.replaceOp(executeOp, resolvedTimepoint); + rewriter.replaceOp(executeOp, signalFence); return success(); } }; @@ -877,8 +910,8 @@ struct TimepointImmediateOpPattern LogicalResult matchAndRewrite( IREE::Stream::TimepointImmediateOp immediateOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // TODO(benvanik): model timepoints as semaphores. - rewriter.replaceOpWithNewOp(immediateOp, 0, 64); + rewriter.replaceOpWithNewOp( + immediateOp, rewriter.getType()); return success(); } }; @@ -889,25 +922,24 @@ struct TimepointImportOpPattern LogicalResult matchAndRewrite( IREE::Stream::TimepointImportOp importOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Only handle imports from HAL semaphores. - auto operands = adaptor.operands(); - if (operands.size() != 2 || - !operands[0].getType().isa() || - !operands[1].getType().isIntOrIndex()) { + // Only handle imports from HAL semaphores _or_ fences. + auto operands = adaptor.getOperands(); + if (operands.size() == 1 && + operands[0].getType().isa()) { + rewriter.replaceOp(importOp, operands[0]); + return success(); + } else if (operands.size() == 2 && + operands[0].getType().isa() && + operands[1].getType().isIntOrIndex()) { + rewriter.replaceOpWithNewOp( + importOp, rewriter.getType(), + ValueRange{operands[0]}, ValueRange{operands[1]}); + return success(); + } else { return rewriter.notifyMatchFailure(importOp, "only imports from HAL semaphore + " "sequence value tuples are supported"); } - - // TODO(benvanik): model timepoints as semaphores. - // For now we just block on the semaphore. - auto awaitOp = rewriter.create( - importOp.getLoc(), rewriter.getI32Type(), operands[0], operands[1]); - rewriter.create( - importOp.getLoc(), awaitOp.getStatus(), - "failed to wait on imported semaphore"); - rewriter.replaceOpWithNewOp(importOp, 0, 64); - return success(); } }; @@ -917,24 +949,13 @@ struct TimepointExportOpPattern LogicalResult matchAndRewrite( IREE::Stream::TimepointExportOp exportOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Only handle exports into HAL semaphores. - if (exportOp.getNumResults() != 2 || - !exportOp.getResult(0).getType().isa() || - !exportOp.getResult(1).getType().isIntOrIndex()) { - return rewriter.notifyMatchFailure(exportOp, - "only exports to HAL semaphore + " - "sequence value tuples are supported"); + // Only handle exports into HAL fences. + if (exportOp.getNumResults() != 1 || + !exportOp.getResult(0).getType().isa()) { + return rewriter.notifyMatchFailure( + exportOp, "only exports to HAL fences are supported"); } - - auto loc = exportOp.getLoc(); - auto device = lookupDeviceFor(exportOp, rewriter); - - // TODO(benvanik): model timepoints as semaphores. - // For now we just create a signaled semaphore. - auto exportValue = rewriter.create(loc, 0, 64); - auto exportSemaphore = rewriter.create( - loc, rewriter.getType(), device, exportValue); - rewriter.replaceOp(exportOp, {exportSemaphore, exportValue}); + rewriter.replaceOp(exportOp, adaptor.getAwaitTimepoint()); return success(); } }; @@ -945,11 +966,9 @@ struct TimepointJoinOpPattern LogicalResult matchAndRewrite( IREE::Stream::TimepointJoinOp joinOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // TODO(benvanik): model timepoints as semaphores. 
- // This should be a max() of the operand timepoints. Could be done with - // affine expressions, but since everything is always 0 we just max(0,0)=0 - // here :) - rewriter.replaceOpWithNewOp(joinOp, 0, 64); + rewriter.replaceOpWithNewOp( + joinOp, rewriter.getType(), + adaptor.getAwaitTimepoints()); return success(); } }; @@ -960,7 +979,16 @@ struct TimepointAwaitOpPattern LogicalResult matchAndRewrite( IREE::Stream::TimepointAwaitOp awaitOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // TODO(benvanik): model timepoints as semaphores. + auto loc = awaitOp.getLoc(); + + // Perform the blocking wait. + Value timeoutMillis = rewriter.create(loc, -1, 32); + auto fenceOp = rewriter.create( + loc, rewriter.getI32Type(), timeoutMillis, adaptor.getAwaitTimepoint()); + rewriter.create(loc, fenceOp.getStatus(), + "failed to wait on timepoint"); + + // Pass along operands. rewriter.replaceOp(awaitOp, adaptor.getResourceOperands()); return success(); } @@ -988,8 +1016,7 @@ struct GlobalTimepointConversionPattern auto initialValue = op.getInitialValue(); if (!initialValue.hasValue()) return failure(); if (!initialValue->isa()) return failure(); - rewriter.updateRootInPlace( - op, [&]() { op.setInitialValueAttr(rewriter.getI64IntegerAttr(0)); }); + rewriter.updateRootInPlace(op, [&]() { op.removeInitial_valueAttr(); }); return success(); } }; @@ -1011,11 +1038,7 @@ void populateStreamToHALPatterns(MLIRContext *context, typeConverter.addConversion( [=](IREE::Stream::TimepointType type, SmallVectorImpl &results) { - // TODO(benvanik): model timepoints as semaphores. - // This may become a !hal.semaphore + index, or some !hal.timepoint that - // we then do more analysis on once we know what devices are in use - // where. - results.push_back(IntegerType::get(context, 64)); + results.push_back(IREE::HAL::FenceType::get(context)); return success(); }); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir index f667c16c6130..abc3695311e1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir @@ -17,7 +17,6 @@ func.func @cmdMemoryControl(%arg0: !stream.resource, %arg1: index) -> stream.cmd.discard %arg2[%c0 for %c128] : !stream.resource{%arg1} } => !stream.timepoint // CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]] - // CHECK-NEXT: hal.ex.submit_and_wait return %0 : !stream.timepoint } @@ -37,7 +36,6 @@ func.func @cmdFill(%arg0: !stream.resource, %arg1: index) -> !stream. // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] } => !stream.timepoint // CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]] - // CHECK-NEXT: hal.ex.submit_and_wait return %0 : !stream.timepoint } @@ -57,7 +55,6 @@ func.func @cmdCopy(%arg0: !stream.resource, %arg1: index, %arg2: !str // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] } => !stream.timepoint // CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]] - // CHECK-NEXT: hal.ex.submit_and_wait return %0 : !stream.timepoint } @@ -93,7 +90,13 @@ func.func @cmdExecute(%arg0: !stream.resource, %arg1: index, %arg2: ! 
} } => !stream.timepoint // CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]] - // CHECK-NEXT: hal.ex.submit_and_wait + // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance + // CHECK: hal.device.queue.execute + // CHECK-SAME: affinity(%c-1 + // CHECK-SAME: wait(%arg4) + // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) + // CHECK-SAME: commands([%[[CMD]]]) + // CHECK: return %[[SIGNAL_FENCE]] return %0 : !stream.timepoint } @@ -179,6 +182,5 @@ func.func @cmdDispatch(%arg0: !stream.resource, %arg1: index, %arg2: // CHECK: hal.command_buffer.execution_barrier<%[[CMD]] } => !stream.timepoint // CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]] - // CHECK-NEXT: hal.ex.submit_and_wait return %0 : !stream.timepoint } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir index 139f1e56044e..6a1f1ee087a6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir @@ -17,45 +17,57 @@ func.func @resourceAlloc(%arg0: index, %arg1: index) -> (!stream.resource (!stream.resource, !stream.timepoint) { - // CHECK: %[[RET0:.+]] = hal.allocator.allocate +// CHECK-SAME: (%[[SIZE:.+]]: index) +func.func @resourceAlloca(%size: index) -> (!stream.resource, !stream.timepoint) { + // CHECK: %[[WAIT_FENCE:.+]] = util.null : !hal.fence + // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance + // CHECK: %[[RET0:.+]] = hal.device.queue.alloca + // CHECK-SAME: affinity(%c-1 + // CHECK-SAME: wait(%[[WAIT_FENCE]]) + // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) + // CHECK-SAME: pool(%c0 // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") - // CHECK-SAME: : !hal.buffer{%arg0} - %0:2 = stream.resource.alloca uninitialized : !stream.resource{%arg0} => !stream.timepoint - // CHECK: %[[IMMEDIATE:.+]] = arith.constant 0 : i64 - // CHECK: return %[[RET0]], %[[IMMEDIATE]] + // CHECK-SAME: : !hal.buffer{%[[SIZE]]} + %0:2 = stream.resource.alloca uninitialized : !stream.resource{%size} => !stream.timepoint + // CHECK: return %[[RET0]], %[[SIGNAL_FENCE]] return %0#0, %0#1 : !stream.resource, !stream.timepoint } // ----- -// TODO(#9572): implement stream ordered allocations. 
- // CHECK-LABEL: @resourceAllocaAwait -func.func @resourceAllocaAwait(%arg0: index, %await_timepoint: !stream.timepoint) -> (!stream.resource, !stream.timepoint) { - // CHECK: %[[RET0:.+]] = hal.allocator.allocate +// CHECK-SAME: (%[[SIZE:.+]]: index, %[[WAIT_FENCE:.+]]: !hal.fence) +func.func @resourceAllocaAwait(%size: index, %await_timepoint: !stream.timepoint) -> (!stream.resource, !stream.timepoint) { + // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance + // CHECK: %[[RET0:.+]] = hal.device.queue.alloca + // CHECK-SAME: affinity(%c-1 + // CHECK-SAME: wait(%[[WAIT_FENCE]]) + // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) + // CHECK-SAME: pool(%c0 // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") - // CHECK-SAME: : !hal.buffer{%arg0} - %0:2 = stream.resource.alloca uninitialized await(%await_timepoint) => !stream.resource{%arg0} => !stream.timepoint - // CHECK: %[[IMMEDIATE:.+]] = arith.constant 0 : i64 - // CHECK: return %[[RET0]], %[[IMMEDIATE]] + // CHECK-SAME: : !hal.buffer{%[[SIZE]]} + %0:2 = stream.resource.alloca uninitialized await(%await_timepoint) => !stream.resource{%size} => !stream.timepoint + // CHECK: return %[[RET0]], %[[SIGNAL_FENCE]] return %0#0, %0#1 : !stream.resource, !stream.timepoint } // ----- -// TODO(#9572): implement stream ordered allocations. - // CHECK-LABEL: @resourceDealloca -func.func @resourceDealloca(%arg0: index, %arg1: !stream.resource, %arg2: !stream.timepoint) -> !stream.timepoint { - %0 = stream.resource.dealloca %arg1 : !stream.resource{%arg0} => !stream.timepoint - // CHECK: %[[IMMEDIATE:.+]] = arith.constant 0 : i64 - // CHECK: return %[[IMMEDIATE]] +// CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer) +func.func @resourceDealloca(%size: index, %resource: !stream.resource) -> !stream.timepoint { + // CHECK: %[[WAIT_FENCE:.+]] = util.null : !hal.fence + // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance + // CHECK: hal.device.queue.dealloca + // CHECK-SAME: affinity(%c-1 + // CHECK-SAME: wait(%[[WAIT_FENCE]]) + // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) + // CHECK-SAME: buffer(%[[RESOURCE]] : !hal.buffer) + %0 = stream.resource.dealloca %resource : !stream.resource{%size} => !stream.timepoint + // CHECK: return %[[SIGNAL_FENCE]] return %0 : !stream.timepoint } @@ -64,10 +76,16 @@ func.func @resourceDealloca(%arg0: index, %arg1: !stream.resource, %arg // TODO(#9572): implement stream ordered allocations. 
// CHECK-LABEL: @resourceDeallocaAwait -func.func @resourceDeallocaAwait(%arg0: index, %arg1: !stream.resource, %arg2: !stream.timepoint) -> !stream.timepoint { - %0 = stream.resource.dealloca await(%arg2) => %arg1 : !stream.resource{%arg0} => !stream.timepoint - // CHECK: %[[IMMEDIATE:.+]] = arith.constant 0 : i64 - // CHECK: return %[[IMMEDIATE]] +// CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer, %[[WAIT_FENCE:.+]]: !hal.fence) +func.func @resourceDeallocaAwait(%size: index, %resource: !stream.resource, %await_timepoint: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance + // CHECK: hal.device.queue.dealloca + // CHECK-SAME: affinity(%c-1 + // CHECK-SAME: wait(%[[WAIT_FENCE]]) + // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) + // CHECK-SAME: buffer(%[[RESOURCE]] : !hal.buffer) + %0 = stream.resource.dealloca await(%await_timepoint) => %resource : !stream.resource{%size} => !stream.timepoint + // CHECK: return %[[SIGNAL_FENCE]] return %0 : !stream.timepoint } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir index 5fc9bd5b7112..80ae08118972 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir @@ -1,12 +1,8 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s -// TODO(#1285): implement timepoint lowering into HAL semaphores. -// For now all timepoints turn into ints and are mostly ignored. - -// CHECK-LABEL: @rwTimepoint -// CHECK-SAME: = 0 : i64 +// CHECK-LABEL: util.global private mutable @rwTimepoint : !hal.fence util.global private mutable @rwTimepoint = #stream.timepoint -// CHECK: func.func @globalTimepoint(%arg0: i64) -> i64 +// CHECK: func.func @globalTimepoint(%arg0: !hal.fence) -> !hal.fence func.func @globalTimepoint(%arg0: !stream.timepoint) -> !stream.timepoint { // CHECK: util.global.store %arg0, @rwTimepoint util.global.store %arg0, @rwTimepoint : !stream.timepoint @@ -20,42 +16,47 @@ func.func @globalTimepoint(%arg0: !stream.timepoint) -> !stream.timepoint { // CHECK-LABEL: @timepointImmediate func.func @timepointImmediate() -> !stream.timepoint { - // CHECK: %[[TIMEPOINT:.+]] = arith.constant 0 + // CHECK: %[[FENCE:.+]] = util.null : !hal.fence %0 = stream.timepoint.immediate => !stream.timepoint - // CHECK: return %[[TIMEPOINT]] + // CHECK: return %[[FENCE]] + return %0 : !stream.timepoint +} + +// ----- + +// CHECK-LABEL: @timepointImportFence +func.func @timepointImportFence(%arg0: !hal.fence) -> !stream.timepoint { + %0 = stream.timepoint.import %arg0 : (!hal.fence) => !stream.timepoint + // CHECK: return %arg0 return %0 : !stream.timepoint } // ----- -// CHECK-LABEL: @timepointImport -func.func @timepointImport(%arg0: !hal.semaphore, %arg1: i64) -> !stream.timepoint { - // CHECK: %[[WAIT_OK:.+]] = hal.semaphore.await<%arg0 : !hal.semaphore> until(%arg1) : i32 - // CHECK: util.status.check_ok %[[WAIT_OK]] - // CHECK: %[[TIMEPOINT:.+]] = arith.constant 0 +// CHECK-LABEL: @timepointImportSemaphore +func.func @timepointImportSemaphore(%arg0: !hal.semaphore, %arg1: i64) -> !stream.timepoint { + // CHECK: %[[FENCE:.+]] = hal.fence.create at<%arg0 : !hal.semaphore>(%arg1) -> !hal.fence %0 = stream.timepoint.import %arg0, %arg1 : (!hal.semaphore, i64) => !stream.timepoint - // CHECK: return %[[TIMEPOINT]] + // CHECK: 
return %[[FENCE]] return %0 : !stream.timepoint } // ----- -// CHECK-LABEL: @timepointExport -func.func @timepointExport(%arg0: !stream.timepoint) -> (!hal.semaphore, i64) { - // CHECK: %[[TIMEPOINT:.+]] = arith.constant 0 - // CHECK: %[[SEMAPHORE:.+]] = hal.semaphore.create device(%device : !hal.device) initial(%[[TIMEPOINT]]) : !hal.semaphore - %0:2 = stream.timepoint.export %arg0 => (!hal.semaphore, i64) - // CHECK: return %[[SEMAPHORE]], %[[TIMEPOINT]] - return %0#0, %0#1 : !hal.semaphore, i64 +// CHECK-LABEL: @timepointExportFence +func.func @timepointExportFence(%arg0: !stream.timepoint) -> !hal.fence { + %0 = stream.timepoint.export %arg0 => (!hal.fence) + // CHECK: return %arg0 + return %0 : !hal.fence } // ----- // CHECK-LABEL: @timepointJoin func.func @timepointJoin(%arg0: !stream.timepoint, %arg1: !stream.timepoint) -> !stream.timepoint { - // CHECK: %[[TIMEPOINT:.+]] = arith.constant 0 + // CHECK: %[[FENCE:.+]] = hal.fence.join at([%arg0, %arg1]) -> !hal.fence %0 = stream.timepoint.join max(%arg0, %arg1) => !stream.timepoint - // CHECK: return %[[TIMEPOINT]] + // CHECK: return %[[FENCE]] return %0 : !stream.timepoint } @@ -65,6 +66,8 @@ func.func @timepointJoin(%arg0: !stream.timepoint, %arg1: !stream.timepoint) -> func.func @timepointAwait(%arg0: !stream.timepoint, %arg1: !stream.resource, %arg2: !stream.resource<*>) -> (!stream.resource, !stream.resource<*>) { %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index + // CHECK: %[[WAIT_OK:.+]] = hal.fence.await until([%arg0]) timeout_millis(%c-1_i32) : i32 + // CHECK-NEXT: util.status.check_ok %[[WAIT_OK]] %0:2 = stream.timepoint.await %arg0 => %arg1, %arg2 : !stream.resource{%c100}, !stream.resource<*>{%c200} // CHECK: return %arg1, %arg2 return %0#0, %0#1 : !stream.resource, !stream.resource<*> diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td index 4ffbda7db60a..ea5ea9f2253d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td @@ -419,15 +419,13 @@ def HAL_WorkgroupSizeAttr : TypedArrayAttrBase< let constBuilderCall = "$_builder.getIndexArrayAttr($0)"; } -def HAL_CommandQueueAffinityAttr : SignlessIntElementsAttr<32> { - // TODO(b/143184519): add typeDescription support to other things. - // let description = [{ - // A bitmask defining which queues an operation is allowed to execute on. - // The selection is wrapped to the total number of available queues, so 0b0101 - // would enable queues 0 and 2 if there were four queues or queue 0 if there - // were two queues. - // }]; -} +// A bitmask defining which queues an operation is allowed to execute on. +// The selection is wrapped to the total number of available queues, so 0b0101 +// would enable queues 0 and 2 if there were four queues or queue 0 if there +// were two queues. +def HAL_DeviceQueueAffinity : TypeAlias; + +def HAL_DeviceQueuePool : TypeAlias; def HAL_DurationMillisAttr : SignlessIntElementsAttr<32> { // TODO(b/143184519): add typeDescription support to other things. @@ -568,6 +566,10 @@ def HAL_DeviceTargetAttr : // target device. Attribute getMatchExpression(); + // Returns true if there's an attribute with the given name in the + // configuration dictionary. + bool hasConfigurationAttr(StringRef name); + // Returns zero or more executable targets that this device supports. 
SmallVector getExecutableTargets(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 21bfcd8ead82..075621ff6e44 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -654,6 +654,21 @@ LogicalResult DeviceSwitchOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// hal.device.queue.* +//===----------------------------------------------------------------------===// + +void DeviceQueueAllocaOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "transient_buffer"); +} + +Value DeviceQueueAllocaOp::getOperandSize(unsigned idx) { return {}; } + +Value DeviceQueueAllocaOp::getResultSize(unsigned idx) { + return getResultSize(); +} + //===----------------------------------------------------------------------===// // hal.executable //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 02133354ef7d..67fd07ad1adb 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -45,15 +45,6 @@ def HAL_ExSharedDeviceOp : HAL_PureOp<"ex.shared_device", [ ]; } -def HAL_ExSubmitAndWaitOp : HAL_Op<"ex.submit_and_wait", [Util_YieldPoint]> { - let arguments = (ins - HAL_Device:$device, - HAL_CommandBuffer:$command_buffer - ); - - let assemblyFormat = "$device `,` $command_buffer attr-dict"; -} - //===----------------------------------------------------------------------===// // Pseudo ops for conversion support //===----------------------------------------------------------------------===// @@ -180,6 +171,34 @@ def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ let hasFolder = 1; } +// NOTE: this has side-effects as it is mutating the global timeline. +// Eventually we'll probably want a dedicated hal.timeline type instead. +def HAL_TimelineAdvanceOp : HAL_Op<"timeline.advance"> { + let summary = [{advances a program timeline by one step}]; + let description = [{ + Returns a fence indicating when the timeline has been advanced one step. + This fence can be used to wait until the timeline reaches or exceeds the + timepoint or used to signal the that it has. + + This is a pseudo-op that is expanded into a semaphore and target value + pair during timeline materialization. The op represents when the advancement + should occur in program order but not what the actual live timepoint would + be. + }]; + + // TODO(benvanik): discriminator when multiple devices or timelines are + // present. Today we only have a single timeline. + let arguments = (ins); + let results = (outs + HAL_Fence:$fence + ); + + let assemblyFormat = [{ + `:` type($fence) + attr-dict-with-keyword + }]; +} + //===----------------------------------------------------------------------===// // !hal.allocator / iree_hal_allocator_t //===----------------------------------------------------------------------===// @@ -1381,6 +1400,132 @@ def HAL_DeviceQueryOp : let hasVerifier = 1; } +def HAL_DeviceQueueAllocaOp : HAL_Op<"device.queue.alloca", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + ]> { + let summary = [{allocates a queue-ordered transient buffer}]; + let description = [{ + Returns a queue-ordered transient buffer that will be available for use when + the signal fence is reached. 
The allocation will not be made until the + wait fence has been reached. + + The size of the buffer returned may be larger than the requested size if the + allocator has specific alignment requirements or minimum allocation sizes. + + The buffer handle will remain live so long as there are retainers but the + contents are undefined before the allocation signal fence has been signaled + and after the deallocation wait fence has been reached. + }]; + + let arguments = (ins + HAL_Device:$device, + HAL_DeviceQueueAffinity:$queue_affinity, + HAL_Fence:$wait_fence, + HAL_Fence:$signal_fence, + HAL_DeviceQueuePool:$pool, + HAL_MemoryTypeBitfieldAttr:$memory_types, + HAL_BufferUsageBitfieldAttr:$buffer_usage, + HAL_DeviceSize:$result_size + ); + let results = (outs + HAL_Buffer:$result + ); + + // TODO(benvanik): change type/usage to ref params. + let assemblyFormat = [{ + `<` $device `:` type($device) `>` + `affinity` `(` $queue_affinity `)` + `wait` `(` $wait_fence `)` + `signal` `(` $signal_fence `)` + `pool` `(` $pool `)` + `type` `(` $memory_types `)` + `usage` `(` $buffer_usage `)` + `:` custom(type($result), $result_size) + attr-dict-with-keyword + }]; +} + +def HAL_DeviceQueueDeallocaOp : HAL_Op<"device.queue.dealloca"> { + let summary = [{deallocates a queue-ordered transient buffer}]; + let description = [{ + Deallocates a queue-ordered transient buffer. + The deallocation will not be made until the wait fence has been reached and + once the storage is available for reuse the signal fence will be signaled. + + After deallocation the contents of the buffer may still be accessible but + will have undefined contents as other operations reuse the memory. + }]; + + let arguments = (ins + HAL_Device:$device, + HAL_DeviceQueueAffinity:$queue_affinity, + HAL_Fence:$wait_fence, + HAL_Fence:$signal_fence, + HAL_Buffer:$buffer + ); + let results = (outs); + + let assemblyFormat = [{ + `<` $device `:` type($device) `>` + `affinity` `(` $queue_affinity `)` + `wait` `(` $wait_fence `)` + `signal` `(` $signal_fence `)` + `buffer` `(` $buffer `:` type($buffer) `)` + attr-dict-with-keyword + }]; +} + +def HAL_DeviceQueueExecuteOp : HAL_Op<"device.queue.execute"> { + let summary = [{enqueues command buffer execution}]; + let description = [{ + Executes one or more command buffers on a device queue. + The command buffers are executed in order as if they were recorded as one. + No commands will execute until the wait fence has been reached and the + signal fence will be signaled when all commands have completed. + }]; + + let arguments = (ins + HAL_Device:$device, + HAL_DeviceQueueAffinity:$queue_affinity, + HAL_Fence:$wait_fence, + HAL_Fence:$signal_fence, + Variadic:$command_buffers + ); + let results = (outs); + + let assemblyFormat = [{ + `<` $device `:` type($device) `>` + `affinity` `(` $queue_affinity `)` + `wait` `(` $wait_fence `)` + `signal` `(` $signal_fence `)` + (`commands` `(` `[` $command_buffers^ `]` `)`)? + attr-dict-with-keyword + }]; +} + +def HAL_DeviceQueueFlushOp : HAL_Op<"device.queue.flush"> { + let summary = [{flushes locally-pending submissions to the queue}]; + let description = [{ + Flushes any locally-pending submissions in the queue. + When submitting many queue operations this can be used to eagerly flush + earlier submissions while later ones are still being constructed. + This may be a no-op. 
+ }]; + + let arguments = (ins + HAL_Device:$device, + HAL_DeviceQueueAffinity:$queue_affinity + ); + let results = (outs); + + let assemblyFormat = [{ + `<` $device `:` type($device) `>` + `affinity` `(` $queue_affinity `)` + attr-dict-with-keyword + }]; +} + //===----------------------------------------------------------------------===// // !hal.executable / iree_hal_executable_t //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp index 4847d1ce2d0f..c8cc659cbcbb 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp @@ -207,6 +207,11 @@ Attribute DeviceTargetAttr::getMatchExpression() { return DeviceMatchIDAttr::get(*this); } +bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) { + auto configAttr = getConfiguration(); + return configAttr && configAttr.get(name); +} + SmallVector DeviceTargetAttr::getExecutableTargets() { SmallVector resultAttrs; auto configAttr = getConfiguration(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir index bf16460f13ce..94ba6d0f5f5d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir @@ -51,3 +51,84 @@ func.func @device_query(%device : !hal.device) -> (i1, i32) { %ok, %value = hal.device.query<%device : !hal.device> key("sys" :: "foo") : i1, i32 return %ok, %value : i1, i32 } + +// ----- + +// CHECK-LABEL: @device_queue_alloca +func.func @device_queue_alloca( + // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[SIZE:.+]]: index) + %size: index) -> !hal.buffer { + %c100_i64 = arith.constant 100 : i64 + // CHECK: = hal.device.queue.alloca<%[[DEVICE]] : !hal.device> + %buffer = hal.device.queue.alloca<%device : !hal.device> + // CHECK-SAME: affinity(%[[AFFINITY]]) + affinity(%affinity) + // CHECK-SAME: wait(%[[WAIT_FENCE]]) signal(%[[SIGNAL_FENCE]]) + wait(%wait_fence) signal(%signal_fence) + // CHECK-SAME: pool(%c100_i64) + pool(%c100_i64) + // CHECK-SAME: type({{.+}}) usage({{.+}}) + type(DeviceLocal) usage(Transfer) + // CHECK-SAME: : !hal.buffer{%[[SIZE]]} + : !hal.buffer{%size} + return %buffer : !hal.buffer +} + +// ----- + +// CHECK-LABEL: @device_queue_dealloca +func.func @device_queue_dealloca( + // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer) + %buffer: !hal.buffer) { + // CHECK: hal.device.queue.dealloca<%[[DEVICE]] : !hal.device> + hal.device.queue.dealloca<%device : !hal.device> + // CHECK-SAME: affinity(%[[AFFINITY]]) + affinity(%affinity) + // CHECK-SAME: wait(%[[WAIT_FENCE]]) signal(%[[SIGNAL_FENCE]]) + wait(%wait_fence) signal(%signal_fence) + // CHECK-SAME: buffer(%[[BUFFER]] : !hal.buffer) + buffer(%buffer : !hal.buffer) + return +} + +// ----- + +// CHECK-LABEL: @device_queue_execute +func.func @device_queue_execute( + // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, 
%[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[CMD0:.+]]: !hal.command_buffer, %[[CMD1:.+]]: !hal.command_buffer) + %cmd0: !hal.command_buffer, %cmd1: !hal.command_buffer) { + // CHECK: hal.device.queue.execute<%[[DEVICE]] : !hal.device> + hal.device.queue.execute<%device : !hal.device> + // CHECK-SAME: affinity(%[[AFFINITY]]) + affinity(%affinity) + // CHECK-SAME: wait(%[[WAIT_FENCE]]) signal(%[[SIGNAL_FENCE]]) + wait(%wait_fence) signal(%signal_fence) + // CHECK-SAME: commands([%[[CMD0]], %[[CMD1]]]) + commands([%cmd0, %cmd1]) + return +} + +// ----- + +// CHECK-LABEL: @device_queue_flush +func.func @device_queue_flush( + // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64) + %device: !hal.device, %affinity: i64) { + // CHECK: hal.device.queue.flush<%[[DEVICE]] : !hal.device> + hal.device.queue.flush<%device : !hal.device> + // CHECK-SAME: affinity(%[[AFFINITY]]) + affinity(%affinity) + return +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir index 8136e4a56b5d..1c8cabb180de 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir @@ -6,14 +6,3 @@ func.func @shared_device() -> !hal.device { %device = hal.ex.shared_device : !hal.device return %device : !hal.device } - -// ----- - -// CHECK-LABEL: @submit_and_wait -func.func @submit_and_wait() { - %0 = "test_hal.device"() : () -> !hal.device - %1 = "test_hal.command_buffer"() : () -> !hal.command_buffer - // CHECK: hal.ex.submit_and_wait %0, %1 - hal.ex.submit_and_wait %0, %1 - return -} diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_ops.mlir index 595d06cdcfe8..de8cedb200c4 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_ops.mlir @@ -1,5 +1,14 @@ // RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s +// CHECK-LABEL: @timeline_advance +func.func @timeline_advance() -> !hal.fence { + // CHECK: = hal.timeline.advance : !hal.fence + %fence = hal.timeline.advance : !hal.fence + return %fence : !hal.fence +} + +// ----- + // CHECK-LABEL: @fence_create func.func @fence_create(%arg0: !hal.semaphore, %arg1: i64, %arg2: i64) -> !hal.fence { // CHECK: = hal.fence.create diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp index 8c06c2763b24..e38ef9e556f7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp @@ -168,6 +168,10 @@ class CUDATargetBackend final : public TargetBackend { Builder b(context); SmallVector configItems; + // Indicates that the runtime HAL driver operates only in the legacy + // synchronous mode. 
+ configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); + configItems.emplace_back(b.getStringAttr("executable_targets"), getExecutableTargets(context)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp index 9eed32dddee4..abf87a375ff0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp @@ -167,6 +167,10 @@ class LLVMCPUTargetBackend final : public TargetBackend { Builder b(context); SmallVector configItems; + // Indicates that the runtime HAL driver operates only in the legacy + // synchronous mode. + configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); + configItems.emplace_back(b.getStringAttr("executable_targets"), getExecutableTargets(context)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp index 390eb8e2b917..5796680b47d0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp @@ -54,6 +54,10 @@ class MetalSPIRVTargetBackend : public TargetBackend { Builder b(context); SmallVector configItems; + // Indicates that the runtime HAL driver operates only in the legacy + // synchronous mode. + configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); + configItems.emplace_back(b.getStringAttr("executable_targets"), getExecutableTargets(context)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp index 83cfb709b920..c15a1a15f20b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp @@ -85,6 +85,11 @@ class ROCMTargetBackend final : public TargetBackend { MLIRContext *context) const override { Builder b(context); SmallVector configItems; + + // Indicates that the runtime HAL driver operates only in the legacy + // synchronous mode. + configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); + configItems.emplace_back(b.getStringAttr("executable_targets"), getExecutableTargets(context)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp index 12e87ae1cc1b..3402c09f7ccd 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp @@ -45,6 +45,10 @@ class VMVXTargetBackend final : public TargetBackend { Builder b(context); SmallVector configItems; + // Indicates that the runtime HAL driver operates only in the legacy + // synchronous mode. 
+ configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); + configItems.emplace_back(b.getStringAttr("executable_targets"), getExecutableTargets(context)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp index 68998539f529..afbcc6082e8a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp @@ -126,6 +126,10 @@ class VulkanSPIRVTargetBackend : public TargetBackend { Builder b(context); SmallVector configItems; + // Indicates that the runtime HAL driver operates only in the legacy + // synchronous mode. + configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); + configItems.emplace_back(b.getStringAttr("executable_targets"), getExecutableTargets(context)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp index 03ff3caea610..ee7bc0dc1d3d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp @@ -71,6 +71,10 @@ class WebGPUTargetBackend : public TargetBackend { Builder b(context); SmallVector configItems; + // Indicates that the runtime HAL driver operates only in the legacy + // synchronous mode. + configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); + configItems.emplace_back(b.getStringAttr("executable_targets"), getExecutableTargets(context)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD index cb3104106dde..68cf82238693 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD @@ -21,10 +21,12 @@ iree_compiler_cc_library( "DumpExecutableBenchmarks.cpp", "DumpExecutableSources.cpp", "ElideRedundantCommands.cpp", + "FixupLegacySync.cpp", "InlineDeviceSwitches.cpp", "LinkExecutables.cpp", "MaterializeInterfaces.cpp", "MaterializeResourceCaches.cpp", + "MaterializeTimelines.cpp", "MemoizeDeviceQueries.cpp", "Passes.cpp", "ResolveExportOrdinals.cpp", @@ -61,6 +63,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToControlFlow", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt index 9cc0012d5e6f..c59c4cc1270e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt @@ -22,10 +22,12 @@ iree_cc_library( "DumpExecutableBenchmarks.cpp" "DumpExecutableSources.cpp" "ElideRedundantCommands.cpp" + "FixupLegacySync.cpp" "InlineDeviceSwitches.cpp" "LinkExecutables.cpp" "MaterializeInterfaces.cpp" "MaterializeResourceCaches.cpp" + "MaterializeTimelines.cpp" "MemoizeDeviceQueries.cpp" "Passes.cpp" "ResolveExportOrdinals.cpp" @@ -42,6 +44,7 @@ iree_cc_library( MLIRIR MLIRPass MLIRSCFDialect + MLIRSCFToControlFlow MLIRSupport MLIRTransforms iree::compiler::Dialect::Flow::IR diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp 
b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index 9c7fa988c5bb..65c72e3b4e8d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp @@ -312,9 +312,27 @@ static void appendDispatchBenchmark(IREE::HAL::ExecutableOp executableOp, forBuilder.create(loc); }); - // Submit command buffer. funcBuilder.create(loc, commandBuffer); - funcBuilder.create(loc, device, commandBuffer); + + // We begin executing immediately and then wait on a fence. + // TODO(benvanik): add fences to ABI so the benchmark tool can pipeline. + Value waitFence = funcBuilder.create( + loc, funcBuilder.getType()); + Value signalFence = funcBuilder.create( + loc, funcBuilder.getType()); + + // Queue execution. + auto queueAffinity = funcBuilder.create(loc, -1, 64); + funcBuilder.create( + loc, device, queueAffinity, waitFence, signalFence, + ValueRange{commandBuffer}); + + // Block until it completes. + Value timeoutMillis = funcBuilder.create(loc, -1, 32); + auto fenceOp = funcBuilder.create( + loc, funcBuilder.getI32Type(), timeoutMillis, signalFence); + funcBuilder.create( + loc, fenceOp.getStatus(), "failed to wait on timepoint"); funcBuilder.create(loc); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/FixupLegacySync.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/FixupLegacySync.cpp new file mode 100644 index 000000000000..1f9e59c63c35 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/FixupLegacySync.cpp @@ -0,0 +1,199 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { + +// Marks a command buffer as being executable inline during recording. +// This is only possible because we generate our command buffer code without +// caching today and know that all are executable inline so long as we have +// blocking queue operations. As soon as we memoize command buffers this will be +// invalid. +static void makeAllowInlineExecution(IREE::HAL::CommandBufferCreateOp op) { + auto modes = op.getModes(); + if (bitEnumContains(modes, IREE::HAL::CommandBufferModeBitfield::OneShot)) { + op.setModesAttr(IREE::HAL::CommandBufferModeBitfieldAttr::get( + op.getContext(), + modes | IREE::HAL::CommandBufferModeBitfield::AllowInlineExecution)); + } +} + +// Scans backward/forward from |asyncOp| and converts it to blocking form by +// waiting on the wait fences and signal fences if needed. +// We allow any number of non-side-effecting ops to exist between the search +// point and where the waits will be as often times arith ops end up scattered +// around. 
+// +// Example: +// hal.fence.await until([%wait_fence]) // existing +// // no wait inserted on %wait_fence as present preceeding: +// hal.device.queue.execute wait(%wait_fence) signal(%signal_fence) +// // no wait inserted on %signal_fence as present following: +// hal.fence.await until([%signal_fence]) // existing +static void insertWaitIfNeeded(Operation *asyncOp, + MutableOperandRange waitFence, + Value signalFence) { + assert(waitFence.size() == 1 && "one wait fence expected"); + auto loc = asyncOp->getLoc(); + + // Returns true if waits can be reordered across |op|. + auto isSafeToReorder = [&](Operation &op) { + // For now we just ignore arith ops and constants. + // I hope we can delete this pass before we need more :) + return op.hasTrait() || + op.getDialect()->getNamespace() == "arith"; + }; + + // Returns an operation waiting on |fence| that is guaranteed to have + // executed prior to asyncOp. Returns null if no waits found. + auto beginIt = std::prev(asyncOp->getBlock()->begin()); + auto endIt = std::prev(asyncOp->getBlock()->end()); // ignore terminator + auto findPrecedingAwait = [&](Value fence) -> Operation * { + auto it = std::prev(Block::iterator(asyncOp)); + for (; it != beginIt; --it) { + if (auto awaitOp = dyn_cast(it)) { + if (llvm::is_contained(awaitOp.getFences(), fence)) { + // Wait is for the fence, found! + return &*it; + } else { + // Keep scanning - generally waiting on one fence is enough. + continue; + } + } else if (!isSafeToReorder(*it)) { + break; // hit a point we can't scan past + } + } + return nullptr; + }; + + // Returns an operation waiting on |fence| that is guaranteed to be + // executed after asyncOp. Returns null if no waits found. + auto findSucceedingAwait = [&](Value fence) -> Operation * { + auto it = std::next(Block::iterator(asyncOp)); + for (; it != endIt; ++it) { + if (auto awaitOp = dyn_cast(it)) { + if (llvm::is_contained(awaitOp.getFences(), fence)) { + // Wait is for the fence, found! + return &*it; + } else { + // Keep scanning - generally waiting on one fence is enough. + continue; + } + } else if (!isSafeToReorder(*it)) { + break; // hit a point we can't scan past + } + } + return nullptr; + }; + + OpBuilder builder(asyncOp); + Value timeoutMillis; + auto makeInfiniteTimeout = [&]() { + if (timeoutMillis) return timeoutMillis; + timeoutMillis = builder.create(loc, -1, 32); + return timeoutMillis; + }; + + // Scan backward to see if the wait fences have been signaled already. + // Since we walk the regions forward we will likely have a wait from the + // producer already. + auto *precedingAwait = findPrecedingAwait(waitFence[0]); + if (!precedingAwait) { + builder.create( + loc, builder.getI32Type(), makeInfiniteTimeout(), waitFence[0]); + } + if (!isa_and_nonnull(waitFence[0].getDefiningOp())) { + // Neuter wait because it's either covered (we found a preceding await) or + // we just inserted one. + Value nullFence = builder.create( + loc, builder.getType()); + waitFence.assign(nullFence); + } + + // Scan forward to see if the signal fences are waited on already. + auto *succeedingAwait = findSucceedingAwait(signalFence); + if (!succeedingAwait) { + builder.setInsertionPointAfter(asyncOp); + builder.create(loc, builder.getI32Type(), + makeInfiniteTimeout(), signalFence); + } +} + +// NOTE: this pass only exists for backwards compatibility with legacy HAL +// drivers. It will be removed once all have migrated to the modern async APIs. 
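+// As a rough illustration (SSA names are hypothetical), a submission like
+//   hal.device.queue.execute<%device : !hal.device> affinity(%affinity)
+//       wait(%wait_fence) signal(%signal_fence) commands([%cmd])
+// becomes, when any target requires legacy_sync:
+//   %ok0 = hal.fence.await until([%wait_fence]) timeout_millis(%c-1_i32) : i32
+//   %null = util.null : !hal.fence
+//   hal.device.queue.execute<%device : !hal.device> affinity(%affinity)
+//       wait(%null) signal(%signal_fence) commands([%cmd])
+//   %ok1 = hal.fence.await until([%signal_fence]) timeout_millis(%c-1_i32) : i32
+// with one-shot command buffers additionally marked AllowInlineExecution.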
+struct FixupLegacySyncPass + : public PassWrapper> { + StringRef getArgument() const override { + return "iree-hal-fixup-legacy-sync"; + } + + StringRef getDescription() const override { + return "Applies fixups to the program for when using legacy HAL devices " + "that only support synchronous execution"; + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + // See if any devices are marked as requiring the legacy_sync behavior. + // If any single device does we must uniformly apply the fixups. + bool anyRequireFixup = false; + auto deviceTargetAttrs = IREE::HAL::DeviceTargetAttr::lookup(moduleOp); + for (auto deviceTargetAttr : deviceTargetAttrs) { + if (deviceTargetAttr.hasConfigurationAttr("legacy_sync")) { + anyRequireFixup = true; + break; + } + } + if (!anyRequireFixup) return; + + // This could use an interface but it'd be better to remove the need for + // this pass instead. + for (auto funcOp : moduleOp.getOps()) { + funcOp.walk([&](Operation *op) { + TypeSwitch(op) + .Case([&](IREE::HAL::CommandBufferCreateOp op) { + makeAllowInlineExecution(op); + }) + .Case([&](IREE::HAL::DeviceQueueAllocaOp op) { + insertWaitIfNeeded(op, op.getWaitFenceMutable(), + op.getSignalFence()); + }) + .Case([&](IREE::HAL::DeviceQueueDeallocaOp op) { + insertWaitIfNeeded(op, op.getWaitFenceMutable(), + op.getSignalFence()); + }) + .Case([&](IREE::HAL::DeviceQueueExecuteOp op) { + insertWaitIfNeeded(op, op.getWaitFenceMutable(), + op.getSignalFence()); + }); + }); + } + } +}; + +std::unique_ptr> createFixupLegacySyncPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTimelines.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTimelines.cpp new file mode 100644 index 000000000000..96edf34b4e16 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTimelines.cpp @@ -0,0 +1,147 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { +namespace { + +//===----------------------------------------------------------------------===// +// hal.timeline analysis +//===----------------------------------------------------------------------===// + +// This pass is provisional and only works because we have a single device and +// don't do multi-queue scheduling. When we want to do that we'll need to attach +// device information to each `hal.timeline.advance` or have it take a device +// SSA value. We may also want a top-level timeline type we insert before +// lowering streams to hal - possibly even in the stream dialect as a final +// stage. 
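+// As a rough sketch, each `hal.timeline.advance` in the program is rewritten
+// (using the globals created by defineGlobalTimeline below) to approximately:
+//   %semaphore = util.global.load @_timeline_semaphore : !hal.semaphore
+//   %current = util.global.load @_timeline_value : i64
+//   %next = arith.addi %current, %c1_i64 : i64
+//   util.global.store %next, @_timeline_value : i64
+//   %fence = hal.fence.create at<%semaphore : !hal.semaphore>(%next) -> !hal.fence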
+ +struct Timeline { + IREE::Util::GlobalOp semaphore; + IREE::Util::GlobalOp value; +}; + +static Timeline defineGlobalTimeline(mlir::ModuleOp moduleOp) { + auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); + + // When we support multiple devices and queues we'd want to name the globals + // based on them and use their canonical location information (maybe all + // places that touch the timeline). + Timeline timeline; + std::string namePrefix = "_timeline"; + auto loc = moduleBuilder.getUnknownLoc(); + + // Internal timelines start at zero. + auto initialValueAttr = moduleBuilder.getI64IntegerAttr(0); + + timeline.semaphore = moduleBuilder.create( + loc, namePrefix + "_semaphore", /*isMutable=*/false, + moduleBuilder.getType()); + timeline.semaphore.setPrivate(); + auto initializerOp = moduleBuilder.create(loc); + auto initializerBuilder = + OpBuilder::atBlockBegin(initializerOp.addEntryBlock()); + Value device = initializerBuilder.create(loc); + Value initialValue = + initializerBuilder.create(loc, initialValueAttr); + auto semaphore = initializerBuilder.create( + loc, initializerBuilder.getType(), device, + initialValue); + initializerBuilder.create(loc, semaphore, + timeline.semaphore); + initializerBuilder.create(loc); + + timeline.value = moduleBuilder.create( + loc, namePrefix + "_value", /*isMutable=*/true, + moduleBuilder.getI64Type(), initialValueAttr); + timeline.value.setPrivate(); + + return timeline; +} + +static void rewriteTimelineOps(Timeline timeline, mlir::ModuleOp rootOp) { + for (auto funcOp : rootOp.getOps()) { + funcOp.walk([&](IREE::HAL::TimelineAdvanceOp advanceOp) { + auto builder = OpBuilder(advanceOp); + Value semaphore = builder.create( + advanceOp.getLoc(), timeline.semaphore); + Value currentValue = builder.create( + advanceOp.getLoc(), timeline.value); + Value one = + builder.create(advanceOp.getLoc(), 1, 64); + Value nextValue = + builder.create(advanceOp.getLoc(), currentValue, one); + builder.create(advanceOp.getLoc(), nextValue, + timeline.value); + Value fence = builder.create( + advanceOp.getLoc(), builder.getType(), + ValueRange{semaphore}, ValueRange{nextValue}); + advanceOp.replaceAllUsesWith(fence); + advanceOp.erase(); + }); + } +} + +//===----------------------------------------------------------------------===// +// -iree-hal-materialize-timelines +//===----------------------------------------------------------------------===// + +class MaterializeTimelinesPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeTimelinesPass) + + MaterializeTimelinesPass() = default; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + StringRef getArgument() const override { + return "iree-hal-materialize-timelines"; + } + + StringRef getDescription() const override { + return "Materializes timelines for device queues."; + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + auto timeline = defineGlobalTimeline(moduleOp); + rewriteTimelineOps(timeline, moduleOp); + } +}; + +} // namespace + +std::unique_ptr> createMaterializeTimelinesPass() { + return std::make_unique(); +} + +static PassRegistration pass([] { + return std::make_unique(); +}); + +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 455c5c1dbe5e..64fdb35c407a 100644 --- 
a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Utils/PassUtils.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/Passes.h" @@ -147,6 +148,14 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, // Convert supported input dialects (std, stream, etc) into the HAL dialect. passManager.addPass(createConvertToHALPass()); + + // Materialize timelines for device queues. + passManager.addPass(createMaterializeTimelinesPass()); + + // If any devices require the legacy synchronous execution behavior then + // make all async operations blocking. + passManager.addPass(createFixupLegacySyncPass()); + addCleanupPatterns(passManager); //---------------------------------------------------------------------------- @@ -198,7 +207,11 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, // Fixup workgroup count calculations that may have used the affine dialect. // Kind of random here but can happen if the benchmarking code does things. - passManager.addPass(createLowerAffinePass()); + passManager.addPass(mlir::createLowerAffinePass()); + + // TODO(benvanik): remove the need for this; some cleanup passes such as + // SimplifyGlobalAccesses are currently broken with scf present. + FunctionLikeNest(passManager).addPass(mlir::createConvertSCFToCFPass); // Combine the initializers we emitted during resource cache materialization. passManager.addPass(IREE::Util::createCombineInitializersPass()); @@ -218,7 +231,7 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, // NOTE: symbol DCE will destroy executable target contents, so only run it // if we serialized things. - passManager.addPass(createSymbolDCEPass()); + passManager.addPass(mlir::createSymbolDCEPass()); } } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h index c88097333aca..6805b22163b0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h @@ -53,6 +53,9 @@ void registerHALTransformPassPipeline(); // Converts input flow/std/etc dialects to the IREE HAL dialect. std::unique_ptr> createConvertToHALPass(); +// Materializes timelines for device queues. +std::unique_ptr> createMaterializeTimelinesPass(); + //===----------------------------------------------------------------------===// // Device management //===----------------------------------------------------------------------===// @@ -67,6 +70,11 @@ createVerifyTargetEnvironmentPass(); std::unique_ptr> createAssignTargetDevicesPass( ArrayRef targets); +// Applies fixups to the program for when using legacy HAL devices that only +// support synchronous execution. Once all devices support async this will be +// removed. +std::unique_ptr> createFixupLegacySyncPass(); + // Outlines hal.device.switch conditions into functions and inlines conditions. 
std::unique_ptr> createInlineDeviceSwitchesPass(); @@ -162,10 +170,12 @@ inline void registerHALPasses() { createDumpExecutableSourcesPass(""); createElideRedundantCommandsPass(); createInlineDeviceSwitchesPass(); + createFixupLegacySyncPass(); createLinkExecutablesPass(); createLinkTargetExecutablesPass(""); createMaterializeInterfacesPass(); createMaterializeResourceCachesPass(targetOptions); + createMaterializeTimelinesPass(); createMemoizeDeviceQueriesPass(); createResolveExportOrdinalsPass(); createSerializeExecutablesPass(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD index 37cff8c44114..3f57d74e7684 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD @@ -23,9 +23,11 @@ iree_lit_test_suite( "dump_executable_benchmarks.mlir", "dump_executable_sources.mlir", "elide_redundant_commands.mlir", + "fixup_legacy_sync.mlir", "inline_device_switches.mlir", "materialize_interfaces.mlir", "materialize_resource_caches.mlir", + "materialize_timelines.mlir", "memoize_device_queries.mlir", "resolve_export_ordinals.mlir", "verify_target_environment.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt index 510fa227e1b3..6b8e6a2dfce7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt @@ -20,9 +20,11 @@ iree_lit_test_suite( "dump_executable_benchmarks.mlir" "dump_executable_sources.mlir" "elide_redundant_commands.mlir" + "fixup_legacy_sync.mlir" "inline_device_switches.mlir" "materialize_interfaces.mlir" "materialize_resource_caches.mlir" + "materialize_timelines.mlir" "memoize_device_queries.mlir" "resolve_export_ordinals.mlir" "verify_target_environment.mlir" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir index 207e5a479319..1aa0fae9faf1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir @@ -115,7 +115,14 @@ module attributes {hal.device.targets = [#device_target_cpu]} { // CHECK: hal.command_buffer.finalize<%[[CMD]] : !hal.command_buffer> } => !stream.timepoint - // CHECK: hal.ex.submit_and_wait %[[DEVICE]], %[[CMD]] + // CHECK: %[[WAIT_FENCE:.+]] = util.null : !hal.fence + // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance + // CHECK: hal.device.queue.execute<%[[DEVICE]] + // CHECK-SAME: wait(%[[WAIT_FENCE]]) + // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) + // CHECK-SAME: commands([%[[CMD]]]) + + // CHECK: hal.fence.await until([%[[SIGNAL_FENCE]]]) %result_ready = stream.timepoint.await %timepoint => %result_resource : !stream.resource{%c16} // CHECK: %[[RESULT_VIEW:.+]] = hal.buffer_view.create diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir index a862620835f3..d80cdb3cbb19 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir @@ -90,7 +90,7 @@ module attributes {hal.device.targets = 
[#device_target_cpu]} { // Submit and wait for dispatches to complete: // CHECK: hal.command_buffer.finalize<%[[CMD]] : !hal.command_buffer> - // CHECK: hal.ex.submit_and_wait %{{.+}}, %[[CMD]] + // CHECK: hal.fence.await // =========================================================================== // @dispatch1 benchmark logic (note two deduplicated dispatches): diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir new file mode 100644 index 000000000000..2a9091ff59b3 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir @@ -0,0 +1,95 @@ +// RUN: iree-opt --split-input-file --iree-hal-fixup-legacy-sync %s | FileCheck %s + +// Tests that command buffers that are reusable don't execute inline. +// Reusable + inline is not a valid combination. + +module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} { +// CHECK-LABEL: @command_buffer_reusable +func.func @command_buffer_reusable(%arg0: !hal.device) { + // CHECK: hal.command_buffer.create device(%arg0 : !hal.device) mode("None") + %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode("None") categories("Transfer|Dispatch") : !hal.command_buffer + return +} +} // module + +// ----- + +// Tests that one-shot command buffers are allowed to execute inline. + +module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} { +// CHECK-LABEL: @command_buffer_oneshot +func.func @command_buffer_oneshot(%arg0: !hal.device) { + // CHECK: hal.command_buffer.create device(%arg0 : !hal.device) mode("OneShot|AllowInlineExecution") + %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode(OneShot) categories("Transfer|Dispatch") : !hal.command_buffer + return +} +} // module + +// ----- + +// Tests for a no-op if there are no devices requiring legacy mode. + +module attributes {hal.device.targets = [ + #hal.device.target<"vmvx", {}>, + #hal.device.target<"vulkan", {}> +]} { +// CHECK-LABEL: @legacy_mode_not_required +func.func @legacy_mode_not_required(%arg0: !hal.device) { + // CHECK: hal.command_buffer.create device(%arg0 : !hal.device) mode(OneShot) + %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode(OneShot) categories("Transfer|Dispatch") : !hal.command_buffer + return +} +} // module + +// ----- + +// Tests that queued operations get the appropriate waits before/after. + +module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} { +// CHECK-LABEL: @blocking_execute +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence) +func.func @blocking_execute(%device: !hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) { + %affinity = arith.constant 0 : i64 + // CHECK-DAG: %[[NULL:.+]] = util.null : !hal.fence + // CHECK-DAG: hal.fence.await until([%[[WAIT]]]) + // CHECK-NEXT: hal.device.queue.execute<%[[DEVICE]] : !hal.device> + // CHECK-SAME: wait(%[[NULL]]) signal(%[[SIGNAL]]) + // CHECK-SAME: commands([%[[CMD]]]) + // CHECK-NEXT: hal.fence.await until([%[[SIGNAL]]]) + hal.device.queue.execute<%device : !hal.device> + affinity(%affinity) + wait(%wait) signal(%signal) + commands([%cmd]) + return +} +} // module + +// ----- + +// Tests that waits are not inserted if they already exist. 
+ +module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} { +// CHECK-LABEL: @blocking_execute +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence) +func.func @blocking_execute(%device: !hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) { + // CHECK-NEXT: %[[TIMEOUT:.+]] = arith.constant 100 + %timeout = arith.constant 100 : i32 + // CHECK-NEXT: hal.fence.await until([%[[WAIT]]]) timeout_millis(%[[TIMEOUT]]) + hal.fence.await until([%wait]) timeout_millis(%timeout) : i32 + // This should not block the search: + // CHECK-NEXT: arith.constant 0 + %affinity = arith.constant 0 : i64 + // CHECK-NEXT: %[[NULL:.+]] = util.null : !hal.fence + // CHECK-NEXT: hal.device.queue.execute<%[[DEVICE]] : !hal.device> + // CHECK-SAME: wait(%[[NULL]]) signal(%[[SIGNAL]]) + // CHECK-SAME: commands([%[[CMD]]]) + hal.device.queue.execute<%device : !hal.device> + affinity(%affinity) + wait(%wait) signal(%signal) + commands([%cmd]) + // CHECK-NEXT: hal.fence.await until([%[[SIGNAL]]]) timeout_millis(%[[TIMEOUT]]) + hal.fence.await until([%signal]) timeout_millis(%timeout) : i32 + // CHECK-NEXT: return + return +} +} // module diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_timelines.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_timelines.mlir new file mode 100644 index 000000000000..b5a03bd1023c --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_timelines.mlir @@ -0,0 +1,44 @@ +// RUN: iree-opt --split-input-file --iree-hal-materialize-timelines %s | FileCheck %s + +// CHECK: util.global private @_timeline_semaphore : !hal.semaphore +// CHECK: util.initializer { +// CHECK: %[[DEVICE:.+]] = hal.ex.shared_device +// CHECK: %[[SEMAPHORE:.+]] = hal.semaphore.create +// CHECK-SAME: device(%[[DEVICE]] : !hal.device) +// CHECK-SAME: initial(%c0_i64) +// CHECK-NEXT: util.global.store %[[SEMAPHORE]], @_timeline_semaphore +// CHECK: } + +// CHECK: util.global private mutable @_timeline_value = 0 : i64 + +// CHECK-LABEL: @fn1 +func.func @fn1() -> !hal.fence { + // CHECK: %[[SEMAPHORE:.+]] = util.global.load @_timeline_semaphore + // CHECK: %[[CURRENT_VALUE:.+]] = util.global.load @_timeline_value + // CHECK: %[[NEXT_VALUE:.+]] = arith.addi %[[CURRENT_VALUE]], %c1 + // CHECK: util.global.store %[[NEXT_VALUE]], @_timeline_value + // CHECK: %[[FENCE0:.+]] = hal.fence.create at<%[[SEMAPHORE]] : !hal.semaphore>(%[[NEXT_VALUE]]) + %0 = hal.timeline.advance : !hal.fence + // CHECK: return %[[FENCE0]] + return %0 : !hal.fence +} + +// CHECK-LABEL: @fn2 +func.func @fn2(%arg0: i1, %arg1: !hal.fence) -> !hal.fence { + // CHECK: %[[FENCE:.+]] = scf.if + %0 = scf.if %arg0 -> (!hal.fence) { + // CHECK: scf.yield %arg1 + scf.yield %arg1 : !hal.fence + } else { + // CHECK: %[[SEMAPHORE:.+]] = util.global.load @_timeline_semaphore + // CHECK: %[[CURRENT_VALUE:.+]] = util.global.load @_timeline_value + // CHECK: %[[NEXT_VALUE:.+]] = arith.addi %[[CURRENT_VALUE]], %c1 + // CHECK: util.global.store %[[NEXT_VALUE]], @_timeline_value + // CHECK: %[[NEW_FENCE:.+]] = hal.fence.create at<%[[SEMAPHORE]] : !hal.semaphore>(%[[NEXT_VALUE]]) + %1 = hal.timeline.advance : !hal.fence + // CHECK: scf.yield %[[NEW_FENCE]] + scf.yield %1 : !hal.fence + } + // CHECK: return %[[FENCE]] + return %0 : !hal.fence +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir 
b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir index 9e235a24fd80..497c1a065779 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir @@ -308,6 +308,51 @@ vm.import @device.query.i64( ) -> (i32, i64) attributes {nosideeffects} +// Returns a queue-ordered transient buffer that will be available for use when +// the signal fence is reached. The allocation will not be made until the +// wait fence has been reached. +vm.import @device.queue.alloca( + %device : !vm.ref, + %queue_affinity : i64, + %wait_fence : !vm.ref, + %signal_fence : !vm.ref, + %pool : i32, + %memory_types : i32, + %buffer_usage : i32, + %allocation_size : i64 +) -> !vm.ref + +// Deallocates a queue-ordered transient buffer. +// The deallocation will not be made until the wait fence has been reached and +// once the storage is available for reuse the signal fence will be signaled. +vm.import @device.queue.dealloca( + %device : !vm.ref, + %queue_affinity : i64, + %wait_fence : !vm.ref, + %signal_fence : !vm.ref, + %buffer : !vm.ref +) + +// Executes one or more command buffers on a device queue. +// The command buffers are executed in order as if they were recorded as one. +// No commands will execute until the wait fence has been reached and the signal +// fence will be signaled when all commands have completed. +vm.import @device.queue.execute( + %device : !vm.ref, + %queue_affinity : i64, + %wait_fence : !vm.ref, + %signal_fence : !vm.ref, + %command_buffers : !vm.ref... +) + +// Flushes any locally-pending submissions in the queue. +// When submitting many queue operations this can be used to eagerly flush +// earlier submissions while later ones are still being constructed. +vm.import @device.queue.flush( + %device : !vm.ref, + %queue_affinity : i64 +) + //===----------------------------------------------------------------------===// // iree_hal_executable_t //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp index 8381449ee63b..a6f97edd4ab5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp @@ -129,7 +129,7 @@ class AbstractResourceUsage return convertBitsToResourceUsage(this->getAssumed()); } - const std::string getAsStr() const override { + const std::string getAsStr(AsmState &asmState) const override { std::string str; auto append = [&](const char *part) { if (!str.empty()) str += '|'; diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td index a9b87469d9d9..9809d250185a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td @@ -156,17 +156,25 @@ def Stream_TimelineOp : OpInterface<"TimelineOpInterface"> { }]; let methods = [ - // TODO(benvanik): implement interface methods: - // getWaitTimepoints - // setWaitTimepoints - // getWaitResources - // getSignalTimepoint - // setSignalTimepoint - // getSignalResources - // + maybe mutable resource accessors? 
(MutableOperandRange) - // This would let us rework code relying on AsyncExecuteOp/CmdExecuteOp to - // work with both, and wait elision canonicalization patterns to be shared - // across the async resource ops and execution ops. + InterfaceMethod< + /*desc=*/[{ + Returns zero or more timepoints consumed by this timeline operation + indicating the asynchronous operations that must complete before it can + perform its operation. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getAwaitTimepoints", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/[{ + Returns the timepoint produced by this timeline operation indicating the + asynchronous completion of the operation. + }], + /*retTy=*/"Value", + /*methodName=*/"getResultTimepoint", + /*args=*/(ins) + >, ]; } diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp index c0187600269b..d08f7173e849 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp @@ -268,6 +268,93 @@ struct TieRegionResults : public OpRewritePattern { } }; +// Adds await dependencies on |newTimepoints| to the op with an optional +// |existingTimepoint| by possibly producing a new timepoint to await. +// This may just pass through the provided timepoint or create a join based on +// the existing await behavior of the op and the new values. +static Value joinAwaitTimepoints(Location loc, Value existingTimepoint, + ArrayRef newTimepoints, + OpBuilder &builder) { + if (newTimepoints.empty()) { + // No new timepoints - preserve existing. + return existingTimepoint; + } else if (newTimepoints.size() == 1 && !existingTimepoint) { + // Adding a single new timepoint. + return newTimepoints.front(); + } + + // Materialize a join of the new timepoints + the existing (if present). + SmallVector joinTimepoints; + if (existingTimepoint) { + joinTimepoints.push_back(existingTimepoint); + } + llvm::append_range(joinTimepoints, newTimepoints); + return builder.create( + loc, builder.getType(), joinTimepoints); +} + +// Elides waits that are known to be immediately resolved. +// +// Example: +// %0 = stream.timepoint.immediate +// %1 = stream.resource.alloca await(%0) ... +// -> +// %1 = stream.resource.alloca ... +template +struct ElideImmediateTimepointWait : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Op op, + PatternRewriter &rewriter) const override { + bool isImmediate = + op.getAwaitTimepoint() && isa_and_nonnull( + op.getAwaitTimepoint().getDefiningOp()); + if (!isImmediate) return failure(); + rewriter.updateRootInPlace( + op, [&]() { op.getAwaitTimepointMutable().clear(); }); + return success(); + } +}; + +// Chains operand resources produced by an await to dependent execution regions. +// This elides host waits and allows for device-side wait resolution. 
+// +// Example: +// %0 = stream.cmd.execute with(%resource) +// %1 = stream.timepoint.await %0 => %resource +// %2 = stream.cmd.execute with(%resource) +// -> +// %0 = stream.cmd.execute with(%resource) +// %2 = stream.cmd.execute await(%0) => with(%resource) +template +struct ChainDependentAwaits : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Op op, + PatternRewriter &rewriter) const override { + SmallVector newTimepoints; + SmallVector> replacements; + for (auto operand : llvm::enumerate(op.getResourceOperands())) { + if (auto awaitOp = + operand.value().template getDefiningOp()) { + newTimepoints.push_back(awaitOp.getAwaitTimepoint()); + replacements.push_back(std::make_pair( + operand.index(), awaitOp.getTiedResultOperand(operand.value()))); + } + } + if (replacements.empty()) return failure(); + rewriter.updateRootInPlace(op, [&]() { + auto newTimepoint = joinAwaitTimepoints( + op.getLoc(), op.getAwaitTimepoint(), newTimepoints, rewriter); + op.getAwaitTimepointMutable().assign(newTimepoint); + for (auto replacement : replacements) { + op.getResourceOperandsMutable() + .slice(replacement.first, 1) + .assign(replacement.second); + } + }); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -287,6 +374,7 @@ void ResourceAllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { // TODO(benvanik): sink to first user. // TODO(benvanik): elide if only user is dealloc. + results.insert>(context); } //===----------------------------------------------------------------------===// @@ -296,6 +384,7 @@ void ResourceAllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, void ResourceDeallocaOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { // TODO(benvanik): move up to producer of timepoint. + results.insert>(context); } //===----------------------------------------------------------------------===// @@ -1548,95 +1637,8 @@ void AsyncDispatchOp::getCanonicalizationPatterns(RewritePatternSet &results, // stream.async.execute //===----------------------------------------------------------------------===// -// Adds await dependencies on |newTimepoints| to the op with an optional -// |existingTimepoint| by possibly producing a new timepoint to await. -// This may just pass through the provided timepoint or create a join based on -// the existing await behavior of the op and the new values. -static Value joinAwaitTimepoints(Location loc, Value existingTimepoint, - ArrayRef newTimepoints, - OpBuilder &builder) { - if (newTimepoints.empty()) { - // No new timepoints - preserve existing. - return existingTimepoint; - } else if (newTimepoints.size() == 1 && !existingTimepoint) { - // Adding a single new timepoint. - return newTimepoints.front(); - } - - // Materialize a join of the new timepoints + the existing (if present). - SmallVector joinTimepoints; - if (existingTimepoint) { - joinTimepoints.push_back(existingTimepoint); - } - llvm::append_range(joinTimepoints, newTimepoints); - return builder.create( - loc, builder.getType(), joinTimepoints); -} - namespace { -// Elides waits that are known to be immediately resolved. -// -// Example: -// %0 = stream.timepoint.immediate -// %1 = stream.async.execute await(%0) => with(...) -// -> -// %1 = stream.async.execute with(...) 
-struct ElideImmediateAsyncExecuteWaits - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AsyncExecuteOp op, - PatternRewriter &rewriter) const override { - bool isImmediate = - op.getAwaitTimepoint() && isa_and_nonnull( - op.getAwaitTimepoint().getDefiningOp()); - if (!isImmediate) return failure(); - rewriter.updateRootInPlace( - op, [&]() { op.getAwaitTimepointMutable().clear(); }); - return success(); - } -}; - -// If any operands are sourced from subviews clone those subviews into the -// region and rewrite the operands to point at the original resource. This -// allows us to progressively fold the subviews into the ops consuming them. -// -// Example: -// %0 = stream.resource.subview %src[%offset] ... -// %1 = stream.async.execute with(%0 as %arg0) -// -> -// %1 = stream.async.execute with(%src as %arg0) { -// %2 = stream.resource.subview %arg0[%offset] ... -// } -struct ChainAsyncExecuteWaits : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AsyncExecuteOp op, - PatternRewriter &rewriter) const override { - SmallVector newTimepoints; - SmallVector> replacements; - for (auto operand : llvm::enumerate(op.getResourceOperands())) { - if (auto awaitOp = operand.value().getDefiningOp()) { - newTimepoints.push_back(awaitOp.getAwaitTimepoint()); - replacements.push_back(std::make_pair( - operand.index(), awaitOp.getTiedResultOperand(operand.value()))); - } - } - if (replacements.empty()) return failure(); - rewriter.updateRootInPlace(op, [&]() { - auto newTimepoint = joinAwaitTimepoints( - op.getLoc(), op.getAwaitTimepoint(), newTimepoints, rewriter); - op.getAwaitTimepointMutable().assign(newTimepoint); - - for (auto replacement : replacements) { - op.getResourceOperandsMutable() - .slice(replacement.first, 1) - .assign(replacement.second); - } - }); - return success(); - } -}; - // If any operands are sourced from subviews clone those subviews into the // region and rewrite the operands to point at the original resource. This // allows us to progressively fold the subviews into the ops consuming them. @@ -1725,8 +1727,8 @@ struct ElideNoOpAsyncExecuteOp : public OpRewritePattern { void AsyncExecuteOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.insert(context); - results.insert(context); + results.insert>(context); + results.insert>(context); results.insert(context); results.insert(context); results.insert>( @@ -2033,65 +2035,6 @@ void CmdDispatchOp::getCanonicalizationPatterns(RewritePatternSet &results, namespace { -// Elides waits that are known to be immediately resolved. -// -// Example: -// %0 = stream.timepoint.immediate -// %1 = stream.cmd.execute await(%0) => with(...) -// -> -// %1 = stream.cmd.execute with(...) -struct ElideImmediateCmdExecuteWaits : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(CmdExecuteOp op, - PatternRewriter &rewriter) const override { - bool isImmediate = - op.getAwaitTimepoint() && isa_and_nonnull( - op.getAwaitTimepoint().getDefiningOp()); - if (!isImmediate) return failure(); - rewriter.updateRootInPlace( - op, [&]() { op.getAwaitTimepointMutable().clear(); }); - return success(); - } -}; - -// Chains operand resources produced by an await to dependent execution regions. -// This elides host waits and allows for device-side wait resolution. 
-// -// Example: -// %0 = stream.cmd.execute with(%resource) -// %1 = stream.timepoint.await %0 => %resource -// %2 = stream.cmd.execute with(%resource) -// -> -// %0 = stream.cmd.execute with(%resource) -// %2 = stream.cmd.execute await(%0) => with(%resource) -struct ChainCmdExecuteWaits : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(CmdExecuteOp op, - PatternRewriter &rewriter) const override { - SmallVector newTimepoints; - SmallVector> replacements; - for (auto operand : llvm::enumerate(op.getResourceOperands())) { - if (auto awaitOp = operand.value().getDefiningOp()) { - newTimepoints.push_back(awaitOp.getAwaitTimepoint()); - replacements.push_back(std::make_pair( - operand.index(), awaitOp.getTiedResultOperand(operand.value()))); - } - } - if (replacements.empty()) return failure(); - rewriter.updateRootInPlace(op, [&]() { - auto newTimepoint = joinAwaitTimepoints( - op.getLoc(), op.getAwaitTimepoint(), newTimepoints, rewriter); - op.getAwaitTimepointMutable().assign(newTimepoint); - for (auto replacement : replacements) { - op.getResourceOperandsMutable() - .slice(replacement.first, 1) - .assign(replacement.second); - } - }); - return success(); - } -}; - // If any operands are sourced from subviews clone those subviews into the // region and rewrite the operands to point at the original resource. This // allows us to progressively fold the subviews into the ops consuming them. @@ -2174,8 +2117,8 @@ struct ElideNoOpCmdExecuteOp : public OpRewritePattern { void CmdExecuteOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.insert(context); - results.insert(context); + results.insert>(context); + results.insert>(context); results.insert(context); results.insert(context); results.insert>(context); @@ -2366,7 +2309,7 @@ LogicalResult TimepointAwaitOp::fold(ArrayRef foldOperands, namespace { -struct ElideImmediateAwaits : public OpRewritePattern { +struct ElideImmediateHostAwaits : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TimepointAwaitOp op, PatternRewriter &rewriter) const override { @@ -2601,7 +2544,7 @@ struct FoldDuplicateAwaitResources : public OpRewritePattern { void TimepointAwaitOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { // TODO(benvanik): elide waits if timepoint must be satisfied in use-def. 
- results.insert(context); + results.insert(context); results.insert(context); results.insert(context); results.insert(context); diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index b367098dbcd6..59b1aa659b9f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -122,6 +122,9 @@ def Stream_ResourceAllocaOp : Stream_PureOp<"resource.alloca", [ let extraClassDeclaration = [{ Value getOperandSize(unsigned idx) { return {}; } Value getResultSize(unsigned idx) { return getStorageSize(); } + SmallVector getAwaitTimepoints() { + if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {}; + } }]; let hasCanonicalizer = 1; @@ -170,6 +173,9 @@ def Stream_ResourceDeallocaOp : Stream_Op<"resource.dealloca", [ let extraClassDeclaration = [{ Value getOperandSize(unsigned idx) { return getOperandSize(); } Value getResultSize(unsigned idx) { return {}; } + SmallVector getAwaitTimepoints() { + if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {}; + } }]; let hasCanonicalizer = 1; @@ -531,6 +537,7 @@ def Stream_ResourceConstantsOp : Stream_PureOp<"resource.constants", [ let extraClassDeclaration = [{ Value getOperandSize(unsigned idx) { return {}; } Value getResultSize(unsigned idx) { return getResultSizes()[idx]; } + SmallVector getAwaitTimepoints() { return {}; } }]; } @@ -1975,6 +1982,9 @@ def Stream_AsyncExecuteOp : Stream_Op<"async.execute", [ Value getResultSize(unsigned idx) { return findValueSizeInList(idx, getResults(), getResultSizes()); } + SmallVector getAwaitTimepoints() { + if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {}; + } }]; let hasVerifier = 1; @@ -2417,6 +2427,9 @@ def Stream_CmdExecuteOp : Stream_Op<"cmd.execute", [ Value getResultSize(unsigned idx) { return {}; } + SmallVector getAwaitTimepoints() { + if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {}; + } }]; let hasVerifier = 1; @@ -2579,10 +2592,13 @@ def Stream_TimepointImmediateOp : Stream_PureOp<"timepoint.immediate", [ `=` `` `>` type($result_timepoint) }]; + let extraClassDeclaration = [{ + SmallVector getAwaitTimepoints() { return {}; } + }]; + let hasFolder = 1; } - def Stream_TimepointImportOp : Stream_PureOp<"timepoint.import", [ Stream_AffinityOp, ]> { @@ -2728,6 +2744,10 @@ def Stream_TimepointAwaitOp : Stream_PureOp<"timepoint.await", [ Value getResultSize(unsigned idx) { return getResourceOperandSizes()[idx]; } + SmallVector getAwaitTimepoints() { + if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {}; + } + Value getResultTimepoint() { return {}; } }]; let hasVerifier = 1; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp index 5fc03a14da87..b1a140fb0a02 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp @@ -86,7 +86,7 @@ class GlobalPVS : public DFX::StateWrapper< } static const char ID; - const std::string getAsStr() const override { + const std::string getAsStr(AsmState &asmState) const override { return getPVSAsStr(getState()); } @@ -119,7 +119,7 @@ class ValuePVS : public DFX::StateWrappergetAssumed() & NOT_BY_REFERENCE) == NOT_BY_REFERENCE; } - const std::string getAsStr() const override { + const 
std::string getAsStr(AsmState &asmState) const override { std::string str; auto append = [&](const char *part) { if (!str.empty()) str += '|'; @@ -217,8 +217,9 @@ class ArgumentSemantics if (auto arg = operand.get().dyn_cast()) { auto &argumentSemantics = solver.getElementFor( *this, Position::forValue(operand.get()), DFX::Resolution::REQUIRED); - LLVM_DEBUG(llvm::dbgs() << " pred is arg; combining state: " - << argumentSemantics.getAsStr() << "\n"); + LLVM_DEBUG(llvm::dbgs() + << " pred is arg; combining state: " + << argumentSemantics.getAsStr(solver.getAsmState()) << "\n"); getState() ^= argumentSemantics.getState(); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideTimepoints.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideTimepoints.cpp new file mode 100644 index 000000000000..7d276670fe3d --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideTimepoints.cpp @@ -0,0 +1,936 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h" +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/Analysis/DFX/Element.h" +#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h" +#include "iree/compiler/Dialect/Util/Analysis/DFX/State.h" +#include "iree/compiler/Dialect/Util/Analysis/Explorer.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Utils/PassUtils.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-stream-elide-timepoints" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Stream { +namespace { + +//===----------------------------------------------------------------------===// +// Resource usage query/application patterns +//===----------------------------------------------------------------------===// + +// Returns true if |value| is defined as a #stream.timepoint.immediate. +static bool isDefinedImmediate(Value value) { + return isa_and_nonnull( + value.getDefiningOp()); +} + +// Tracks whether a util.global of !stream.timepoint is immediately resolved. +// Boolean state will be set to false if any stores are non-immediate. 
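Illustrative IR for the distinction this element tracks (the global and function names are hypothetical): a timepoint global stays immediate only while every store to it is a known-immediate timepoint, and a store from an unanalyzable source pins it to non-immediate.

util.global private mutable @t : !stream.timepoint
func.func @keeps_immediate() {
  %imm = stream.timepoint.immediate => !stream.timepoint
  util.global.store %imm, @t : !stream.timepoint   // @t may remain immediate
  return
}
func.func @breaks_immediate(%tp: !stream.timepoint) {
  util.global.store %tp, @t : !stream.timepoint    // unknown source: not immediate
  return
}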
+class IsGlobalImmediate + : public DFX::StateWrapper< + DFX::BooleanState, DFX::TypedOperationElement> { + public: + using BaseType = + DFX::StateWrapper>; + + static IsGlobalImmediate &createForPosition(const Position &pos, + DFX::Solver &solver) { + return *(new (solver.getAllocator()) IsGlobalImmediate(pos)); + } + + bool isImmediate() const { return isAssumed(); } + + const std::string getName() const override { return "IsGlobalImmediate"; } + const void *getID() const override { return &ID; } + static bool classof(const DFX::AbstractElement *element) { + return (element->getID() == &ID); + } + static const char ID; + + const std::string getAsStr(AsmState &asmState) const override { + return std::string("is_immediate: ") + std::to_string(isAssumed()); + } + + private: + explicit IsGlobalImmediate(const Position &pos) : BaseType(pos) {} + + void initializeOperation(IREE::Util::GlobalOp globalOp, + DFX::Solver &solver) override { + // Immutable constant globals are all immediate. Initialized globals may + // end up not being immediate and we'll need to analyze. + if (!globalOp.getIsMutable() && globalOp.getInitialValue().has_value()) { + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] immutable immediate global: "; + globalOp.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + setKnown(true); + indicateOptimisticFixpoint(); + return; + } + + // Globals must have been analyzed in order to be tracked. + // Indirectly-accessed globals are not currently supported. + auto *globalInfo = solver.getExplorer().getGlobalInfo(globalOp); + if (!globalInfo || globalInfo->isIndirect) { + LLVM_DEBUG({ + llvm::dbgs() + << "[ElideTimepoints] unanalyzed/indirect global ignored: "; + globalOp.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + setKnown(false); + indicatePessimisticFixpoint(); + return; + } + + // Assume true until proven otherwise. + setAssumed(true); + } + + ChangeStatus updateOperation(IREE::Util::GlobalOp globalOp, + DFX::Solver &solver) override; + + friend class DFX::Solver; +}; +const char IsGlobalImmediate::ID = 0; + +// Tracks whether a !stream.timepoint is immediately resolved. +// Boolean state will be set to false if any sources are non-immediate. +class IsImmediate + : public DFX::StateWrapper { + public: + using BaseType = DFX::StateWrapper; + + static IsImmediate &createForPosition(const Position &pos, + DFX::Solver &solver) { + return *(new (solver.getAllocator()) IsImmediate(pos)); + } + + bool isImmediate() const { return isAssumed(); } + + const std::string getName() const override { return "IsImmediate"; } + const void *getID() const override { return &ID; } + static bool classof(const DFX::AbstractElement *element) { + return (element->getID() == &ID); + } + static const char ID; + + const std::string getAsStr(AsmState &asmState) const override { + return std::string("is_immediate: ") + std::to_string(isAssumed()); + } + + private: + explicit IsImmediate(const Position &pos) : BaseType(pos) {} + + void initializeValue(Value value, DFX::Solver &solver) override { + // Immediate timepoints (constant resolved) are always available and cover + // everything. We check for this as a special case to short-circuit the + // solver. 
+ if (isDefinedImmediate(value)) { + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] defined immediate: "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + setKnown(true); + indicateOptimisticFixpoint(); + return; + } + + // Assume true until proven otherwise. + setAssumed(true); + } + + ChangeStatus updateValue(Value value, DFX::Solver &solver) override { + StateType newState = getState(); + + auto traversalResult = TraversalResult::COMPLETE; + + // Scan IR to see if we can identify whether this definitely comes from an + // immediate op. This will reach across block and call edges and may fan out + // into many incoming ops - all of them must be immediate for this op to be + // considered immediate. + traversalResult |= + solver.getExplorer().walkDefiningOps(value, [&](OpResult result) { + updateFromDefiningOp(newState, value, result, solver); + return WalkResult::advance(); + }); + + if (traversalResult == TraversalResult::INCOMPLETE) { + newState.indicatePessimisticFixpoint(); + } + + return DFX::clampStateAndIndicateChange(getState(), newState); + } + + // Updates the usage based on the op defining the value. + void updateFromDefiningOp(StateType &newState, Value value, OpResult result, + DFX::Solver &solver) { + TypeSwitch(result.getOwner()) + .Case([&](IREE::Util::GlobalLoadOp op) { + auto *globalInfo = + solver.getExplorer().queryGlobalInfoFrom(op.getGlobal(), op); + if (!globalInfo || globalInfo->isIndirect) { + LLVM_DEBUG( + { + llvm::dbgs() + << "[ElideTimepoints] indirect usage global backing "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "; marking undef\n"; + }); + newState.indicatePessimisticFixpoint(); + return; + } + auto isImmediate = solver.getElementFor( + *this, Position::forOperation(globalInfo->op), + DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] global load "; + isImmediate.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= isImmediate.getState(); + }) + .Case([&](IREE::Stream::TimepointImmediateOp op) { + // Defined by an immediate op; definitely immediate. + newState.setAssumed(true); + }) + .Case([&](IREE::Stream::TimepointJoinOp op) { + // Only immediate if all inputs to the join are immediate. + for (auto operand : op.getAwaitTimepoints()) { + auto isImmediate = solver.getElementFor( + *this, Position::forValue(operand), DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] join operand "; + isImmediate.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= isImmediate.getState(); + } + }) + .Case([&](IREE::Stream::TimelineOpInterface op) { + // Defined by a timeline operation that ensures it's never immediate. 
+ LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] non-immediate timeline op: "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState.indicatePessimisticFixpoint(); + }) + .Case([&](arith::SelectOp op) { + auto isTrueImmediate = solver.getElementFor( + *this, Position::forValue(op.getTrueValue()), + DFX::Resolution::REQUIRED); + auto isFalseImmediate = solver.getElementFor( + *this, Position::forValue(op.getFalseValue()), + DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] select join "; + isTrueImmediate.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " OR "; + isFalseImmediate.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= isTrueImmediate.getState(); + newState ^= isFalseImmediate.getState(); + }) + // Allowed because traversal will take care of things: + .Case([&](mlir::CallOpInterface) {}) + .Case([&](mlir::BranchOpInterface) {}) + .Case([&](scf::IfOp) {}) + .Case([&](scf::ForOp) {}) + .Default([&](Operation *op) { + // Unknown op defines the value - we can't make any assumptions. + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] unknown usage of "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " by " << op->getName() << "\n"; + }); + newState.indicatePessimisticFixpoint(); + }); + } + + friend class DFX::Solver; +}; +const char IsImmediate::ID = 0; + +ChangeStatus IsGlobalImmediate::updateOperation(IREE::Util::GlobalOp globalOp, + DFX::Solver &solver) { + IsGlobalImmediate::StateType newState = getState(); + + auto *globalInfo = solver.getExplorer().getGlobalInfo(globalOp); + assert(globalInfo && "analysis required"); + + // Walk all stores and clamp to their status. + for (auto storeOp : globalInfo->getStores()) { + auto isImmediate = solver.getElementFor( + *this, Position::forValue(storeOp.getValue()), + DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] global store: "; + storeOp.getValue().printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "; "; + isImmediate.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= isImmediate; + } + + return DFX::clampStateAndIndicateChange(getState(), newState); +} + +class TimepointCoverage + : public DFX::StateWrapper, + DFX::ValueElement> { + public: + using BaseType = + DFX::StateWrapper, DFX::ValueElement>; + + static TimepointCoverage &createForPosition(const Position &pos, + DFX::Solver &solver) { + return *(new (solver.getAllocator()) TimepointCoverage(pos)); + } + + const std::string getName() const override { return "TimepointCoverage"; } + const void *getID() const override { return &ID; } + static bool classof(const DFX::AbstractElement *element) { + return (element->getID() == &ID); + } + static const char ID; + + // Returns true if the given |value| is known to be covered by this value + // indicating that any time this value is reached |value| must also have been. 
+ bool covers(Value value) const { return getAssumedSet().contains(value); } + + const std::string getAsStr(AsmState &asmState) const override { + std::string str; + llvm::raw_string_ostream sstream(str); + sstream << "covered: "; + if (isValidState()) { + sstream << "["; + if (isUndefContained()) { + sstream << "undef, "; + } + llvm::interleaveComma(getAssumedSet(), sstream, [&](Value value) { + value.printAsOperand(sstream, asmState); + sstream << "(" << (void *)value.getImpl() << ")"; + }); + sstream << "]"; + } else { + sstream << "(invalid)"; + } + sstream.flush(); + return str; + } + + private: + explicit TimepointCoverage(const Position &pos) : BaseType(pos) {} + + void initializeValue(Value value, DFX::Solver &solver) override { + // Immediate timepoints (constant resolved) are always available and cover + // everything. We check for this as a special case to short-circuit the + // solver. + if (isDefinedImmediate(value)) { + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] defined immediate: "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + unionAssumed(value); + indicateOptimisticFixpoint(); + return; + } + } + + ChangeStatus updateValue(Value value, DFX::Solver &solver) override { + StateType newState; + + // Intersect coverage of all incoming block edge operands. + // This will also step outside the entry block and into callee functions. + // The intersection prevents back-edges from polluting block arguments. + auto gatherBlockOperands = [&](BlockArgument blockArg) { + StateType uniformState; + bool firstEdge = true; + if (solver.getExplorer().walkIncomingBlockArgument( + blockArg, [&](Block *sourceBlock, Value operand) { + auto operandCoverage = solver.getElementFor( + *this, Position::forValue(operand), + DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() + << "[ElideTimepoints] intersect incoming branch operand "; + operandCoverage.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + if (firstEdge) { + uniformState = operandCoverage.getState(); + firstEdge = false; + } else { + uniformState.intersectAssumed(operandCoverage.getState()); + } + return WalkResult::advance(); + }) == TraversalResult::INCOMPLETE) { + LLVM_DEBUG(llvm::dbgs() << "[ElideTimepoints] incomplete branch arg " + "traversal; assuming unknown\n"); + uniformState.unionAssumedWithUndef(); + } + newState ^= uniformState; + newState.unionAssumed(blockArg); + }; + + // Intersect coverage of all callee/child region return operands. + // The intersection prevents multiple return sites from interfering. 
+ auto gatherRegionReturns = [&](Operation *regionOp, unsigned resultIndex) { + StateType uniformState; + bool firstEdge = true; + if (solver.getExplorer().walkReturnOperands( + regionOp, [&](OperandRange operands) { + auto operand = operands[resultIndex]; + auto operandCoverage = solver.getElementFor( + *this, Position::forValue(operand), + DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() + << "[ElideTimepoints] intersect incoming return operand "; + operandCoverage.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + if (firstEdge) { + uniformState = operandCoverage.getState(); + firstEdge = false; + } else { + uniformState.intersectAssumed(operandCoverage.getState()); + } + return WalkResult::advance(); + }) == TraversalResult::INCOMPLETE) { + LLVM_DEBUG(llvm::dbgs() << "[ElideTimepoints] incomplete region " + "traversal; assuming unknown\n"); + uniformState.unionAssumedWithUndef(); + } + newState ^= uniformState; + }; + + auto *definingOp = value.getDefiningOp(); + if (auto blockArg = value.dyn_cast()) { + // Block arguments need an intersection of all incoming branch/call edges. + gatherBlockOperands(blockArg); + return DFX::clampStateAndIndicateChange(getState(), newState); + } + + TypeSwitch(definingOp) + .Case([&](IREE::Stream::TimelineOpInterface timelineOp) { + // Value defined from a timeline op and we can mark all awaits of + // the op as covered by the result. + for (auto operand : timelineOp.getAwaitTimepoints()) { + auto operandCoverage = solver.getElementFor( + *this, Position::forValue(operand), DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] dependent timeline operand "; + operandCoverage.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState.unionAssumed(operand); + newState &= operandCoverage; + } + // Timepoints cover themselves; this is redundant but simplifies the + // set logic later on. + if (auto resultTimepoint = timelineOp.getResultTimepoint()) { + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] produced timeline result "; + resultTimepoint.printAsOperand(llvm::dbgs(), + solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + } + }) + .Case([&](mlir::CallOpInterface callOp) { + // Step into callees and get a coverage intersection of all return + // sites. + auto callableOp = + callOp.resolveCallable(&solver.getExplorer().getSymbolTables()); + unsigned resultIndex = value.cast().getResultNumber(); + gatherRegionReturns(callableOp, resultIndex); + }) + .Case([&](RegionBranchOpInterface regionOp) { + // Step into regions and get a coverage intersection of all return + // sites. 
+ unsigned resultIndex = value.cast().getResultNumber(); + gatherRegionReturns(regionOp, resultIndex); + }) + .Case([&](arith::SelectOp op) { + auto trueCoverage = solver.getElementFor( + *this, Position::forValue(op.getTrueValue()), + DFX::Resolution::REQUIRED); + auto falseCoverage = solver.getElementFor( + *this, Position::forValue(op.getFalseValue()), + DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] select join "; + trueCoverage.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " AND "; + falseCoverage.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState &= trueCoverage; + newState &= falseCoverage; + }); + + return DFX::clampStateAndIndicateChange(getState(), newState); + } + + friend class DFX::Solver; +}; +const char TimepointCoverage::ID = 0; + +class TimepointCoverageAnalysis { + public: + explicit TimepointCoverageAnalysis(Operation *rootOp) + : explorer(rootOp, TraversalAction::SHALLOW), + solver(explorer, allocator) { + explorer.setOpAction(TraversalAction::RECURSE); + explorer.setOpAction(TraversalAction::RECURSE); + explorer.setOpAction(TraversalAction::RECURSE); + explorer.setDialectAction( + TraversalAction::RECURSE); + // Ignore the contents of executables (linalg goo, etc) and execution + // regions (they don't impact timepoints). + explorer.setOpAction(TraversalAction::IGNORE); + explorer.setOpAction( + TraversalAction::SHALLOW); + explorer.setOpAction(TraversalAction::SHALLOW); + explorer.initialize(); + + assert(rootOp->getNumRegions() == 1 && "expected module-like root op"); + topLevelOps = llvm::to_vector<4>( + rootOp->getRegions().front().getOps()); + } + + AsmState &getAsmState() { return solver.getAsmState(); } + Explorer &getExplorer() { return explorer; } + + // Runs analysis and populates the state cache. + // May fail if analysis cannot be completed due to unsupported or unknown IR. + LogicalResult run() { + explorer.forEachGlobal([&](const auto *globalInfo) { + solver.getOrCreateElementFor( + Position::forOperation(globalInfo->op)); + for (auto loadOp : globalInfo->getLoads()) { + solver.getOrCreateElementFor( + Position::forValue(loadOp.getResult())); + } + }); + std::function seedRegion; + seedRegion = [&](Region ®ion) { + for (auto &block : region) { + // Seed all block arguments. + for (auto arg : block.getArguments()) { + if (arg.getType().isa()) { + solver.getOrCreateElementFor(Position::forValue(arg)); + } + } + + // Seed the timepoints created from any timeline ops. + for (auto op : block.getOps()) { + for (auto operand : op.getAwaitTimepoints()) { + solver.getOrCreateElementFor( + Position::forValue(operand)); + solver.getOrCreateElementFor( + Position::forValue(operand)); + } + if (auto resultTimepoint = op.getResultTimepoint()) { + solver.getOrCreateElementFor( + Position::forValue(resultTimepoint)); + solver.getOrCreateElementFor( + Position::forValue(resultTimepoint)); + } + } + + // Seed all terminator operands. + if (auto *terminatorOp = block.getTerminator()) { + for (auto operand : terminatorOp->getOperands()) { + if (operand.getType().isa()) { + solver.getOrCreateElementFor( + Position::forValue(operand)); + solver.getOrCreateElementFor( + Position::forValue(operand)); + } + } + } + } + + // Walk into nested ops. 
+      region.walk([&](RegionBranchOpInterface nestedOp) {
+        for (auto &nestedRegion : nestedOp->getRegions()) {
+          seedRegion(nestedRegion);
+        }
+      });
+    };
+    for (auto callableOp : getTopLevelOps()) {
+      auto *region = callableOp.getCallableRegion();
+      if (!region || region->empty()) continue;
+      seedRegion(*region);
+    }
+
+    // Run solver to completion.
+    auto result = solver.run();
+    LLVM_DEBUG(solver.print(llvm::dbgs()));
+    return result;
+  }
+
+  // Returns a list of all top-level callable ops in the root op.
+  ArrayRef getTopLevelOps() const {
+    return topLevelOps;
+  }
+
+  // Returns true if |value| is known to be immediately resolved.
+  bool isImmediate(Value value) {
+    if (isDefinedImmediate(value)) return true;
+    auto &isImmediate =
+        solver.getOrCreateElementFor(Position::forValue(value));
+    return isImmediate.isValidState() && isImmediate.isKnown();
+  }
+
+  // Unions all transitively reached timepoints by the time |value| is reached.
+  bool unionTransitivelyReachedTimepoints(Value value, SetVector &set) {
+    auto coverage = solver.getOrCreateElementFor(
+        Position::forValue(value));
+    if (!coverage.isValidState() || coverage.isUndefContained()) return false;
+    for (auto reached : coverage.getAssumedSet()) {
+      set.insert(reached);
+    }
+    return true;
+  }
+
+ private:
+  Explorer explorer;
+  llvm::BumpPtrAllocator allocator;
+  DFX::Solver solver;
+  SmallVector topLevelOps;
+};
+
+// Prunes |possibleTimepoints| into a set of required timepoints.
+// Any timepoint not in the resulting set is covered by at least one timepoint
+// that is in the set and can be elided.
+static SetVector buildRequiredCoverageSet(
+    SmallVector possibleTimepoints,
+    TimepointCoverageAnalysis &analysis) {
+  // Build a map that effectively tracks an incoming edge counter for each
+  // timepoint. Values with no incoming edges are required.
+  DenseMap coverageMap;
+  for (auto possibleTimepoint : possibleTimepoints) {
+    // Query all transitively reached timepoints from this potentially required
+    // timepoint. If analysis failed we skip it and ensure the timepoint is
+    // pulled in unless something else covers it.
+    SetVector reachedTimepoints;
+    bool isValid = analysis.unionTransitivelyReachedTimepoints(
+        possibleTimepoint, reachedTimepoints);
+    if (isValid) {
+      for (auto reachedTimepoint : reachedTimepoints) {
+        // TODO(benvanik): avoid self-references so we don't need this check.
+        if (reachedTimepoint == possibleTimepoint) continue;
+        ++coverageMap[reachedTimepoint];
+      }
+    }
+  }
+  // Any possibly required timepoint that has no coverage is a root (no refs)
+  // and is required.
+  SetVector requiredTimepoints;
+  for (auto possibleTimepoint : possibleTimepoints) {
+    auto it = coverageMap.find(possibleTimepoint);
+    if (it == coverageMap.end() || it->second <= 0) {
+      LLVM_DEBUG({
+        llvm::dbgs() << " ++ requiring uncovered ";
+        possibleTimepoint.printAsOperand(llvm::dbgs(), analysis.getAsmState());
+        llvm::dbgs() << " (root)\n";
+      });
+      requiredTimepoints.insert(possibleTimepoint);
+    } else {
+      LLVM_DEBUG({
+        llvm::dbgs() << " -- omitting covered ";
+        possibleTimepoint.printAsOperand(llvm::dbgs(), analysis.getAsmState());
+        llvm::dbgs() << "\n";
+      });
+    }
+  }
+  return requiredTimepoints;
+}
+
+// Tries to elide timepoints nested within |region| when safe.
+// Returns true if any ops were elided.
+static bool tryElideTimepointsInRegion(Region &region,
+                                       TimepointCoverageAnalysis &analysis) {
+  bool didChange = false;
+
+  // We batch up all results we're going to change to prevent SSA value
+  // breakages in the debug print out. This maps old->new values.
+ DenseMap pendingReplacements; + + // Inserts an immediate timepoint or reuses an existing replacement (if + // any). + auto makeImmediate = [&](Value elidedTimepoint, OpBuilder builder) -> Value { + auto existingReplacement = pendingReplacements.find(elidedTimepoint); + if (existingReplacement != pendingReplacements.end()) { + return existingReplacement->second; + } + return builder.create( + elidedTimepoint.getLoc()); + }; + + // Elides |elidedTimepoint| by replacing all its uses by |op| with an + // immediate timepoint value. + auto elideTimepointOperand = [&](Operation *op, Value elidedTimepoint) { + if (isDefinedImmediate(elidedTimepoint)) return; // already immediate + auto immediateTimepoint = makeImmediate(elidedTimepoint, OpBuilder(op)); + elidedTimepoint.replaceUsesWithIf( + immediateTimepoint, + [&](OpOperand &operand) { return operand.getOwner() == op; }); + didChange = true; + }; + + // Elides all timepoint operands of |op| that are immediately resolved. + auto elideTimepointOperands = [&](Operation *op) { + for (auto operand : llvm::make_early_inc_range(op->getOperands())) { + if (!operand.getType().isa()) continue; + if (isDefinedImmediate(operand)) continue; + if (analysis.isImmediate(operand)) { + LLVM_DEBUG({ + llvm::dbgs() << " >>> eliding known-immediate operand "; + operand.printAsOperand(llvm::dbgs(), analysis.getAsmState()); + llvm::dbgs() << " consumed by " << op->getName() << "\n"; + }); + elideTimepointOperand(op, operand); + } + } + }; + + // Elides |elidedTimepoint| by replacing all its uses with an immediate + // timepoint value. The original value will end up with zero uses. + auto elideTimepointResult = [&](Operation *op, Value elidedTimepoint) { + if (elidedTimepoint.use_empty()) return; // no-op + if (isDefinedImmediate(elidedTimepoint)) return; // already immediate + OpBuilder afterBuilder(op); + afterBuilder.setInsertionPointAfterValue(elidedTimepoint); + Value immediateTimepoint = + afterBuilder.create( + elidedTimepoint.getLoc()); + // Defer actually swapping until later. + pendingReplacements.insert( + std::make_pair(elidedTimepoint, immediateTimepoint)); + didChange = true; + }; + + // Elides all timepoint results of |op| that are immediately resolved. + auto elideTimepointResults = [&](Operation *op) { + // Reverse so that we insert in return order: + // %0, %1 = ... + // %imm0 = immediate + // %imm1 = immediate + for (auto result : llvm::reverse(op->getResults())) { + if (!result.getType().isa()) continue; + if (isDefinedImmediate(result)) continue; + if (analysis.isImmediate(result)) { + LLVM_DEBUG({ + llvm::dbgs() << " >>> eliding known-immediate result "; + result.printAsOperand(llvm::dbgs(), analysis.getAsmState()); + llvm::dbgs() << " produced by " << op->getName() << " (result " + << result.getResultNumber() << ")\n"; + }); + elideTimepointResult(op, result); + } + } + }; + + // Processes timeline |op| by eliding its await and result timepoints if + // possible. 
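+  // Await operands the analysis proves immediate are replaced inline; any
+  // remaining operands are pruned down to the required coverage set computed
+  // by buildRequiredCoverageSet above.
+  //
+  // Illustrative sketch of the effect on a join (names are hypothetical):
+  //   %exec1 = stream.cmd.execute await(%exec0) => with() {} => !stream.timepoint
+  //   %join = stream.timepoint.join max(%exec0, %exec1) => !stream.timepoint
+  // %exec0 is transitively covered by %exec1, so its use in the join is
+  // replaced with a stream.timepoint.immediate.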
+ auto processTimelineOp = [&](IREE::Stream::TimelineOpInterface op) { + auto resultTimepoint = op.getResultTimepoint(); + auto awaitTimepoints = op.getAwaitTimepoints(); + if (awaitTimepoints.empty()) return; + + LLVM_DEBUG({ + llvm::dbgs() << "[ElideTimepoints] pruning " << op->getName() + << " await("; + llvm::interleaveComma(awaitTimepoints, llvm::dbgs(), [&](Value value) { + value.printAsOperand(llvm::dbgs(), analysis.getAsmState()); + }); + llvm::dbgs() << ")"; + if (resultTimepoint) { + llvm::dbgs() << " producing "; + resultTimepoint.printAsOperand(llvm::dbgs(), analysis.getAsmState()); + } + llvm::dbgs() << "\n"; + }); + + // If the result of the op is immediate then we can elide the resulting + // timepoint. + if (resultTimepoint && analysis.isImmediate(resultTimepoint)) { + LLVM_DEBUG({ + llvm::dbgs() << " >>> eliding entire known-immediate result "; + resultTimepoint.printAsOperand(llvm::dbgs(), analysis.getAsmState()); + llvm::dbgs() << " produced by " << op->getName() << "\n"; + }); + elideTimepointResult(op, resultTimepoint); + } + + // Prune all immediately reached timepoints. + // This may let us avoid doing the full pruning pass by getting us down to + // 0 or 1 timepoints. + SmallVector possibleTimepoints; + for (auto awaitTimepoint : awaitTimepoints) { + if (analysis.isImmediate(awaitTimepoint)) { + // Timepoint is definitely immediate and can be pruned. + LLVM_DEBUG({ + llvm::dbgs() << " >>> eliding use of known-immediate "; + awaitTimepoint.printAsOperand(llvm::dbgs(), analysis.getAsmState()); + llvm::dbgs() << " in " << op->getName() << "\n"; + }); + elideTimepointOperand(op, awaitTimepoint); + } else { + // May be immediate but not certain; preserve. + possibleTimepoints.push_back(awaitTimepoint); + } + } + + // If there's only one timepoint we don't have to worry with coverage. + if (possibleTimepoints.size() <= 1) return; + + // Perform the analysis on the possible timepoints to find which are covered + // by others and elide all of those known-covered. + auto requiredTimepoints = + buildRequiredCoverageSet(possibleTimepoints, analysis); + for (auto possibleTimepoint : possibleTimepoints) { + if (!requiredTimepoints.contains(possibleTimepoint)) { + // Timepoint is covered (or immediate) and can be pruned. + LLVM_DEBUG({ + llvm::dbgs() << " >>> eliding use of covered "; + possibleTimepoint.printAsOperand(llvm::dbgs(), + analysis.getAsmState()); + llvm::dbgs() << "(" << (void *)possibleTimepoint.getImpl() << ")\n"; + }); + elideTimepointOperand(op, possibleTimepoint); + } + } + }; + + // Walk all blocks and elide timepoints. + // We walk pre-order to make the debug output easier to read. + region.walk([&](Operation *op) { + // TODO(benvanik): handle more ops from scf or other dialects. + TypeSwitch(op) + .Case([&](IREE::Stream::TimelineOpInterface op) { + // Most of the interesting stream.* stuff happens here. + processTimelineOp(op); + }) + .Case( + [&](Operation *op) { elideTimepointResults(op); }) + .Case([&](Operation *op) { + elideTimepointOperands(op); + elideTimepointResults(op); + }) + .Case( + [&](Operation *op) { elideTimepointOperands(op); }) + .Case( + [&](Operation *op) { elideTimepointOperands(op); }); + }); + + // Process elided results; we do this afterward to keep the debug output + // cleaner by not adding <>. 
+ for (auto replacement : pendingReplacements) { + replacement.first.replaceAllUsesWith(replacement.second); + } + + return didChange; +} + +//===----------------------------------------------------------------------===// +// -iree-stream-elide-timepoints +//===----------------------------------------------------------------------===// + +// Elides waits on timepoints that are known to be reached by a dependent +// timepoint. We err on the side of additional timepoints if we can't guarantee +// that a particular wait is covered. +// +// Example: +// %timepoint0 = ... +// %timepoint1 = ... await(%timepoint0) +// %timepoint2 = stream.timepoint.join max(%timepoint0, %timepoint1) +// -> +// %timepoint0 = ... +// %timepoint1 = ... await(%timepoint0) +// %timepoint2 = stream.timepoint.join max(%timepoint1) +// -> (canonicalization) -> +// %timepoint0 = ... +// %timepoint1 = ... await(%timepoint0) +// %timepoint2 = %timepoint1 +class ElideTimepointsPass : public ElideTimepointsBase { + public: + ElideTimepointsPass() = default; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + if (moduleOp.getBody()->empty()) return; + + // Perform whole-program analysis to find for each timepoint what other + // timepoints are known to be reached. + TimepointCoverageAnalysis analysis(moduleOp); + if (failed(analysis.run())) { + moduleOp.emitError() << "failed to solve for timepoint coverage"; + return signalPassFailure(); + } + + bool didChange = false; + + // Apply analysis by replacing known-covered timepoint usage with immediate + // values. If we change something we'll indicate that so that the parent + // fixed-point iteration continues. + for (auto callableOp : analysis.getTopLevelOps()) { + auto *region = callableOp.getCallableRegion(); + if (!region || region->empty()) continue; + didChange = tryElideTimepointsInRegion(*region, analysis) || didChange; + } + + if (didChange) signalFixedPointModified(moduleOp); + } +}; + +} // namespace + +std::unique_ptr> createElideTimepointsPass() { + return std::make_unique(); +} + +} // namespace Stream +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index ab42cd77ed3b..2582782342a1 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -10,6 +10,7 @@ #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Utils/PassUtils.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Pass/PassOptions.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/Passes.h" @@ -26,13 +27,13 @@ using FunctionLikeNest = MultiOpNest; //===----------------------------------------------------------------------===// static void addCleanupPatterns(OpPassManager &passManager) { - // Standard MLIR cleanup. - passManager.addPass(mlir::createCanonicalizerPass()); - passManager.addPass(mlir::createCSEPass()); - - // Simplify util.global accesses; this can help with data flow tracking as - // redundant store-loads are removed. FunctionLikeNest(passManager) + // Standard MLIR cleanup. 
+ .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass) + + // Simplify util.global accesses; this can help with data flow tracking as + // redundant store-loads are removed. .addPass(IREE::Util::createSimplifyGlobalAccessesPass); // Cleanup and canonicalization of util.global (and other util ops). @@ -150,8 +151,6 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager, passManager.addPass(IREE::Stream::createPropagateTimepointsPass()); addCleanupPatterns(passManager); - // TODO(benvanik): remove covered timepoints in awaits (dominance). - // Everything must now be in stream.async.* form. passManager.addPass(IREE::Stream::createVerifyLoweringToAsyncPass()); } @@ -209,6 +208,38 @@ void buildStreamOptimizationPassPipeline( // cause duplication. Run CSE to collapse. addCleanupPatterns(passManager); + // If any scf ops crept in we get rid of them here. We should be able to + // support them all the way through the stream dialect but some passes are not + // currently set up to handle them (such as elide timepoints). + FunctionLikeNest(passManager).addPass(mlir::createConvertSCFToCFPass); + + //---------------------------------------------------------------------------- + // Whole-program scheduling optimization + //---------------------------------------------------------------------------- + + { + // We run these under a fixed-point iteration such that we can perform + // inter-procedural, intra-procedural, and canonicalization as separably + // verifiable/reusable passes alongside the custom stream ones. IPO will + // fold duplicate arguments/results and inline constants to allow the local + // optimizations to work more effectively. + OpPassManager ipoPipeline(mlir::ModuleOp::getOperationName()); + + // IPO and other cleanups. + addCleanupPatterns(ipoPipeline); + + // TODO(#9747): elide timepoints that are know-reached due to host + // synchronization via stream.timepoint.await. + + // Elide timepoints in dependency chains where one is known to have been + // reached by the time another is (A -> B -> A|C). + ipoPipeline.addPass(IREE::Stream::createElideTimepointsPass()); + + // Run fixed-point iteration on the IPO pipeline. 
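+    // ElideTimepointsPass calls signalFixedPointModified whenever it changes
+    // IR, so the nested pipeline keeps re-running until the program stops
+    // changing.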
+ passManager.addPass( + IREE::Util::createFixedPointIteratorPass(std::move(ipoPipeline))); + } + //---------------------------------------------------------------------------- // Binding optimization //---------------------------------------------------------------------------- diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h index 4a22df137f5b..cb8f02154f15 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h @@ -124,6 +124,7 @@ std::unique_ptr> createScheduleConcurrencyPass(); std::unique_ptr> createPropagateTimepointsPass(); +std::unique_ptr> createElideTimepointsPass(); //===----------------------------------------------------------------------===// // Allocation and command issuing diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td index 96a81345e3ad..51c910075852 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td @@ -113,6 +113,14 @@ def PropagateTimepoints : }]; } +def ElideTimepoints : + Pass<"iree-stream-elide-timepoints", "mlir::ModuleOp"> { + let summary = "Elides timepoints that are known to be covered by dependent timepoints."; + let constructor = [{ + mlir::iree_compiler::IREE::Stream::createElideTimepointsPass() + }]; +} + //===----------------------------------------------------------------------===// // Allocation and command issuing //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD index b46e40ad4419..5bfe48c90d20 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD @@ -21,6 +21,8 @@ iree_lit_test_suite( "convert_to_stream.mlir", "dump_statistics.mlir", "elide_async_copies.mlir", + "elide_timepoints_coverage.mlir", + "elide_timepoints_immediate.mlir", "encode_device_tensors.mlir", "encode_host_tensors.mlir", "fold_globals.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt index 0c41984b8054..212ea0460a07 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt @@ -18,6 +18,8 @@ iree_lit_test_suite( "convert_to_stream.mlir" "dump_statistics.mlir" "elide_async_copies.mlir" + "elide_timepoints_coverage.mlir" + "elide_timepoints_immediate.mlir" "encode_device_tensors.mlir" "encode_host_tensors.mlir" "fold_globals.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/elide_timepoints_coverage.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/elide_timepoints_coverage.mlir new file mode 100644 index 000000000000..48280db5a716 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/elide_timepoints_coverage.mlir @@ -0,0 +1,288 @@ +// RUN: iree-opt --split-input-file --iree-stream-elide-timepoints %s | FileCheck %s + +// Tests that we don't (currently) do anything with global forwarding. +// Generic util analysis passes operating on globals can do things like folding. 
+// We just want to make sure here that we are preserving the global behavior. + +util.global private mutable @global0 : !stream.timepoint +util.global private mutable @global1 : !stream.timepoint + +util.initializer { + %t0 = stream.cmd.execute with() {} => !stream.timepoint + util.global.store %t0, @global0 : !stream.timepoint + %t1 = stream.cmd.execute await(%t0) => with() {} => !stream.timepoint + util.global.store %t1, @global1 : !stream.timepoint + util.initializer.return +} + +// CHECK-LABEL: @initializedGlobals +func.func private @initializedGlobals() -> !stream.timepoint { + // CHECK: %[[GLOBAL0:.+]] = util.global.load @global0 + %global0 = util.global.load @global0 : !stream.timepoint + // CHECK: %[[GLOBAL1:.+]] = util.global.load @global1 + %global1 = util.global.load @global1 : !stream.timepoint + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[GLOBAL0]], %[[GLOBAL1]]) + %join = stream.timepoint.join max(%global0, %global1) => !stream.timepoint + // CHECK: return %[[JOIN]] + return %join : !stream.timepoint +} + +// ----- + +// Tests that meaningful timeline ops are never marked immediate. + +// CHECK-LABEL: @nonImmediate +func.func private @nonImmediate() -> !stream.timepoint { + // CHECK: %[[EXECUTE:.+]] = stream.cmd.execute + %0 = stream.cmd.execute with() {} => !stream.timepoint + // CHECK: return %[[EXECUTE]] + return %0 : !stream.timepoint +} + +// ----- + +// Tests that coverage propagates through timeline ops. Here %exec0 is covered +// by both %exec1a and %exec1b and does not need to be joined. + +// CHECK-LABEL: @joinChained +func.func @joinChained() -> !stream.timepoint { + // CHECK: %[[EXEC0:.+]] = stream.cmd.execute with + %exec0 = stream.cmd.execute with() {} => !stream.timepoint + // CHECK: %[[EXEC1A:.+]] = stream.cmd.execute await(%[[EXEC0]]) + %exec1a = stream.cmd.execute await(%exec0) => with() {} => !stream.timepoint + // CHECK: %[[EXEC1B:.+]] = stream.cmd.execute await(%[[EXEC0]]) + %exec1b = stream.cmd.execute await(%exec0) => with() {} => !stream.timepoint + // CHECK: %[[EXEC0_IMM:.+]] = stream.timepoint.immediate + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[EXEC0_IMM]], %[[EXEC1A]], %[[EXEC1B]]) + %join = stream.timepoint.join max(%exec0, %exec1a, %exec1b) => !stream.timepoint + // CHECK: return %[[JOIN]] + return %join : !stream.timepoint +} + +// ----- + +// Tests that coverage propagates through a select: %exec0 is covered by both +// the true and false conditions and does not need to be joined. + +// CHECK-LABEL: @selectCovered +// CHECK-SAME: (%[[COND:.+]]: i1) +func.func @selectCovered(%cond: i1) -> !stream.timepoint { + // CHECK: %[[EXEC0:.+]] = stream.cmd.execute + %exec0 = stream.cmd.execute with() {} => !stream.timepoint + // CHECK: %[[EXEC1A:.+]] = stream.cmd.execute await(%[[EXEC0]]) + %exec1a = stream.cmd.execute await(%exec0) => with() {} => !stream.timepoint + // CHECK: %[[EXEC1B:.+]] = stream.cmd.execute await(%[[EXEC0]]) + %exec1b = stream.cmd.execute await(%exec0) => with() {} => !stream.timepoint + // CHECK: %[[SELECT:.+]] = arith.select %[[COND]], %[[EXEC1A]], %[[EXEC1B]] + %select = arith.select %cond, %exec1a, %exec1b : !stream.timepoint + // CHECK: %[[EXEC0_IMM:.+]] = stream.timepoint.immediate + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[EXEC0_IMM]], %[[SELECT]]) + %join = stream.timepoint.join max(%exec0, %select) => !stream.timepoint + // CHECK: return %[[JOIN]] + return %join : !stream.timepoint +} + +// ----- + +// Tests that a timepoint passed along a call edge is propagated. 
+// %t0/%t1 are covered by the call result %call that joins the two together. + +// CHECK-LABEL: func @caller +// CHECK-SAME: (%[[T0:.+]]: !stream.timepoint, %[[T1:.+]]: !stream.timepoint) +func.func @caller(%t0: !stream.timepoint, %t1: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[CALL:.+]] = call @callee(%[[T0]], %[[T1]]) + %call = call @callee(%t0, %t1) : (!stream.timepoint, !stream.timepoint) -> !stream.timepoint + // CHECK-DAG: %[[T0_COVERED:.+]] = stream.timepoint.immediate + // CHECK-DAG: %[[T1_COVERED:.+]] = stream.timepoint.immediate + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[T0_COVERED]], %[[T1_COVERED]], %[[CALL]]) + %join = stream.timepoint.join max(%t0, %t1, %call) => !stream.timepoint + // CHECK: return %[[JOIN]] + return %join : !stream.timepoint +} +// CHECK-LABEL: func private @callee +func.func private @callee(%t0a: !stream.timepoint, %t0b: !stream.timepoint) -> !stream.timepoint { + // CHECK-NOT: stream.timepoint.immediate + // CHECK: %[[JOIN_CALLEE:.+]] = stream.timepoint.join max + %t1 = stream.timepoint.join max(%t0a, %t0b) => !stream.timepoint + // CHECK: return %[[JOIN_CALLEE]] + return %t1 : !stream.timepoint +} + +// ----- + +// Tests that duplicate call args/results are handled correctly. +// Ideally we're running in as part of a fixed-point iteration with IPO that +// removes the dupes and lets us focus on simpler cases. For now we don't do +// anything clever with folding the call results even though we know they're +// the same and instead just handle coverage (hitting either call results is +// the same as hitting the original arg). + +// CHECK-LABEL: func @callerDupes +func.func @callerDupes(%unknown: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[CALL:.+]]:2 = call @calleeDupes + %call:2 = call @calleeDupes(%unknown, %unknown) : (!stream.timepoint, !stream.timepoint) -> (!stream.timepoint, !stream.timepoint) + // CHECK-NEXT: %[[UNKNOWN_IMM:.+]] = stream.timepoint.immediate + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[UNKNOWN_IMM]], %[[CALL]]#0, %[[CALL]]#1) + %join = stream.timepoint.join max(%unknown, %call#0, %call#1) => !stream.timepoint + // CHECK: return %[[JOIN]] + return %join : !stream.timepoint +} +func.func private @calleeDupes(%arg0: !stream.timepoint, %arg1: !stream.timepoint) -> (!stream.timepoint, !stream.timepoint) { + return %arg0, %arg1 : !stream.timepoint, !stream.timepoint +} + +// ----- + +// Tests that calls with non-uniform args still track partial coverage. +// Here the result of @nonUniformCallee always covers %t0 but not %t1 and we're +// able to elide %t0 in the final join. + +// TODO(benvanik): we should also be able to trim the calls/t1 and only use +// %t01 but that needs some work to know that call0 == t0 and call1 == t01. 
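+// Until then only coverage is tracked: both call results cover %t0 so it is
+// elided from the final join, while %t1 (only covered by %call1) is kept.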
+ +// CHECK-LABEL: func @nonUniformCaller +// CHECK-SAME: (%[[T0:.+]]: !stream.timepoint, %[[T1:.+]]: !stream.timepoint) +func.func @nonUniformCaller(%t0: !stream.timepoint, %t1: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[CALL0:.+]] = call @nonUniformCallee(%[[T0]]) + %call0 = call @nonUniformCallee(%t0) : (!stream.timepoint) -> !stream.timepoint + // CHECK: %[[T01:.+]] = stream.timepoint.join max(%[[T0]], %[[T1]]) + %t01 = stream.timepoint.join max(%t0, %t1) => !stream.timepoint + // CHECK: %[[CALL1:.+]] = call @nonUniformCallee(%[[T01]]) + %call1 = call @nonUniformCallee(%t01) : (!stream.timepoint) -> !stream.timepoint + // CHECK: %[[T0_IMM:.+]] = stream.timepoint.immediate + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[T0_IMM]], %[[CALL0]], %[[T1]], %[[CALL1]]) + %join = stream.timepoint.join max(%t0, %call0, %t1, %call1) => !stream.timepoint + // CHECK: return %[[JOIN]] + return %join : !stream.timepoint +} +// CHECK: func private @nonUniformCallee +func.func private @nonUniformCallee(%arg0: !stream.timepoint) -> !stream.timepoint { + return %arg0 : !stream.timepoint +} + +// ----- + +// Tests that timepoints are tracked through branches args. +// In this simple case %bb1_t0 always covers %t0. + +// CHECK-LABEL: func @branch +// CHECK-SAME: (%[[T0:.+]]: !stream.timepoint) +func.func @branch(%t0: !stream.timepoint) -> !stream.timepoint { + // CHECK: cf.br ^bb1 + cf.br ^bb1(%t0 : !stream.timepoint) +// CHECK-NEXT: ^bb1(%[[BB1_T0:.+]]: !stream.timepoint) +^bb1(%bb1_t0: !stream.timepoint): + // CHECK: %[[T0_IMM:.+]] = stream.timepoint.immediate + // CHECK-NEXT: %[[JOIN:.+]] = stream.timepoint.join max(%[[T0_IMM]], %[[BB1_T0]]) + %join = stream.timepoint.join max(%t0, %bb1_t0) => !stream.timepoint + // CHECK: return %[[JOIN]] + return %join : !stream.timepoint +} + +// ----- + +// Tests that forward edges with convergent timepoints track coverage. +// Here both true and false paths cover %t0 and it can be elided at the join. + +// CHECK-LABEL: func @branchConvergentForwardEdge +// CHECK-SAME: (%[[COND:.+]]: i1, %[[T0:.+]]: !stream.timepoint) +func.func @branchConvergentForwardEdge(%cond: i1, %t0: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[T1A:.+]] = stream.cmd.execute await(%[[T0]]) + %t1a = stream.cmd.execute await(%t0) => with() {} => !stream.timepoint + // CHECK: %[[T1B:.+]] = stream.cmd.execute await(%[[T0]]) + %t1b = stream.cmd.execute await(%t0) => with() {} => !stream.timepoint + // CHECK: cf.cond_br %[[COND]] + // CHECK-SAME: ^bb1(%[[T1A]] : !stream.timepoint), + // CHECK-SAME: ^bb1(%[[T1B]] : !stream.timepoint) + cf.cond_br %cond, ^bb1(%t1a : !stream.timepoint), ^bb1(%t1b : !stream.timepoint) +// CHECK-NEXT: ^bb1(%[[BB1_ARG:.+]]: !stream.timepoint) +^bb1(%bb1_arg: !stream.timepoint): + // CHECK: %[[T0_IMM:.+]] = stream.timepoint.immediate + // CHECK-NEXT: %[[JOIN:.+]] = stream.timepoint.join max(%[[T0_IMM]], %[[BB1_ARG]]) + %join = stream.timepoint.join max(%t0, %bb1_arg) => !stream.timepoint + // CHECK: return %[[JOIN]] + return %join : !stream.timepoint +} + +// ----- + +// Tests that forward edges with divergent timepoint coverage get propagated. +// %t0 is covered on both paths but %t1 is only covered when %cond == true. 
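+// The block argument's coverage is the intersection of the two incoming
+// edges, so only %t0 can be elided from the join while %t1 must be kept.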
+ +// CHECK-LABEL: func @branchDivergentForwardEdge +// CHECK-SAME: (%[[COND:.+]]: i1, %[[T0:.+]]: !stream.timepoint, %[[T1:.+]]: !stream.timepoint) +func.func @branchDivergentForwardEdge(%cond: i1, %t0: !stream.timepoint, %t1: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[T01:.+]] = stream.timepoint.join max(%[[T0]], %[[T1]]) + %t01 = stream.timepoint.join max(%t0, %t1) => !stream.timepoint + // CHECK-NEXT: cf.cond_br + // CHECK-SAME: ^bb1(%[[T0]] : !stream.timepoint), + // CHECK-SAME: ^bb1(%[[T01]] : !stream.timepoint) + cf.cond_br %cond, ^bb1(%t0 : !stream.timepoint), ^bb1(%t01 : !stream.timepoint) +// CHECK-NEXT: ^bb1(%[[BB1_ARG:.+]]: !stream.timepoint) +^bb1(%bb1_arg: !stream.timepoint): + // CHECK: %[[T0_IMM:.+]] = stream.timepoint.immediate + // CHECK-NEXT: %[[JOIN:.+]] = stream.timepoint.join max(%[[T0_IMM]], %[[T1]], %[[BB1_ARG]]) + %join = stream.timepoint.join max(%t0, %t1, %bb1_arg) => !stream.timepoint + // CHECK: return %[[JOIN]] + return %join : !stream.timepoint +} + +// ----- + +// Tests that back edges with divergent timepoints don't get propagated. + +// TODO(benvanik): some way of knowing %t0 is always covered; for now we aren't +// smart enough to track that through and likely need some +// must-be-executed-context-like machinery in order to do so. We just want to +// make sure we're preserving the timepoints here for correctness. + +// CHECK-LABEL: func @branchDivergentBackEdge +// CHECK-SAME: (%[[COND:.+]]: i1, %[[T0:.+]]: !stream.timepoint) +func.func @branchDivergentBackEdge(%cond: i1, %t0: !stream.timepoint) -> !stream.timepoint { + // CHECK: cf.br ^bb1 + cf.br ^bb1(%cond, %t0 : i1, !stream.timepoint) +// CHECK-NEXT: ^bb1(%[[BB1_COND:.+]]: i1, %[[BB1_T0:.+]]: !stream.timepoint) +^bb1(%bb1_cond: i1, %bb1_t0: !stream.timepoint): + // CHECK-NOT: stream.timepoint.immediate + // CHECK-NEXT: %[[BB1_T1:.+]] = stream.cmd.execute await(%[[BB1_T0]]) + %bb1_t1 = stream.cmd.execute await(%bb1_t0) => with() {} => !stream.timepoint + // CHECK: %[[FALSE:.+]] = arith.constant false + %cond_false = arith.constant false + // CHECK-NEXT: cf.cond_br + // CHECK-SAME: ^bb1(%[[FALSE]], %[[BB1_T1]] : i1, !stream.timepoint) + // CHECK-SAME: ^bb2(%[[BB1_T1]] : !stream.timepoint) + cf.cond_br %bb1_cond, ^bb1(%cond_false, %bb1_t1 : i1, !stream.timepoint), ^bb2(%bb1_t1 : !stream.timepoint) +// CHECK-NEXT: ^bb2(%[[BB2_T1:.+]]: !stream.timepoint) +^bb2(%bb2_t1: !stream.timepoint): + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[T0]], %[[BB2_T1]]) + %join = stream.timepoint.join max(%t0, %bb2_t1) => !stream.timepoint + // CHECK-NEXT: return %[[JOIN]] + return %join : !stream.timepoint +} + +// ----- + +// Tests that scf.if regions with convergent yields are handled. +// Here %t0 is covered regardless of the %cond and can be elided. 
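+// Both yields transitively cover %t0: the true branch yields it directly and
+// the false branch yields a join that includes it.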
+ +// CHECK-LABEL: func @scfIfConvergent +// CHECK-SAME: (%[[COND:.+]]: i1, %[[T0:.+]]: !stream.timepoint, %[[T1:.+]]: !stream.timepoint) +func.func @scfIfConvergent(%cond: i1, %t0: !stream.timepoint, %t1: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[IF:.+]] = scf.if + %if = scf.if %cond -> !stream.timepoint { + // CHECK: yield %[[T0]] + scf.yield %t0 : !stream.timepoint + } else { + // CHECK: %[[T01:.+]] = stream.timepoint.join max(%[[T0]], %[[T1]]) + %t01 = stream.timepoint.join max(%t0, %t1) => !stream.timepoint + // CHECK: yield %[[T01]] + scf.yield %t01 : !stream.timepoint + } + // CHECK: %[[T0_IMM:.+]] = stream.timepoint.immediate + // CHECK-NEXT: %[[JOIN:.+]] = stream.timepoint.join max(%[[T0_IMM]], %[[T1]], %[[IF]]) + %join = stream.timepoint.join max(%t0, %t1, %if) => !stream.timepoint + // CHECK: return %[[JOIN]] + return %join : !stream.timepoint +} + +// TODO(benvanik): support scf.for diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/elide_timepoints_immediate.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/elide_timepoints_immediate.mlir new file mode 100644 index 000000000000..147c1c65a07f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/elide_timepoints_immediate.mlir @@ -0,0 +1,378 @@ +// RUN: iree-opt --split-input-file --iree-stream-elide-timepoints %s | FileCheck %s + +// Tests that joins with multiple immediate timepoints are marked as immediate. + +// CHECK-LABEL: @immediateJoin +func.func private @immediateJoin() -> !stream.timepoint { + %imm0 = stream.timepoint.immediate => !stream.timepoint + %imm1 = stream.timepoint.immediate => !stream.timepoint + // CHECK: stream.timepoint.join + // CHECK-NEXT: %[[JOIN_IMM:.+]] = stream.timepoint.immediate + %0 = stream.timepoint.join max(%imm0, %imm1) => !stream.timepoint + // CHECK: return %[[JOIN_IMM]] + return %0 : !stream.timepoint +} + +// ----- + +// Tests that joins with one or more non-immediate timepoints are not elided. + +// CHECK-LABEL: @nonImmediateJoin +// CHECK-SAME: (%[[NON_IMM:.+]]: !stream.timepoint) +func.func @nonImmediateJoin(%arg0: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[IMM:.+]] = stream.timepoint.immediate + %imm = stream.timepoint.immediate => !stream.timepoint + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[NON_IMM]], %[[IMM]]) + %0 = stream.timepoint.join max(%arg0, %imm) => !stream.timepoint + // CHECK: return %[[JOIN]] + return %0 : !stream.timepoint +} + +// ----- + +// Tests that a select between immediate values is marked immediate. + +// CHECK-LABEL: @selectSame +func.func @selectSame(%cond: i1) -> !stream.timepoint { + %imm0 = stream.timepoint.immediate => !stream.timepoint + %imm1 = stream.timepoint.immediate => !stream.timepoint + // CHECK: arith.select + // CHECK-NEXT: %[[SELECT_IMM:.+]] = stream.timepoint.immediate + %select = arith.select %cond, %imm0, %imm1 : !stream.timepoint + // CHECK: return %[[SELECT_IMM]] + return %select : !stream.timepoint +} + +// ----- + +// Tests that a select with one or more unknown value is not marked immediate. + +// CHECK-LABEL: @selectDifferent +func.func @selectDifferent(%cond: i1, %unknown: !stream.timepoint) -> !stream.timepoint { + %imm = stream.timepoint.immediate => !stream.timepoint + // CHECK: %[[SELECT:.+]] = arith.select + %select = arith.select %cond, %imm, %unknown : !stream.timepoint + // CHECK: return %[[SELECT]] + return %select : !stream.timepoint +} + +// ----- + +// Tests global immediate timepoints are marked immediate when loaded. 
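+// The global is initialized to an immediate timepoint and never stored to in
+// this test, so the analysis can replace its loads.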
+ +util.global private mutable @global = #stream.timepoint : !stream.timepoint + +// CHECK-LABEL: @immediateGlobal +func.func private @immediateGlobal() -> !stream.timepoint { + // CHECK: %[[IMM:.+]] = stream.timepoint.immediate + %global = util.global.load @global : !stream.timepoint + // CHECK: return %[[IMM]] + return %global : !stream.timepoint +} + +// ----- + +// Tests that uniform global store->load forwarding handles immediates. + +util.global private mutable @global : !stream.timepoint + +// CHECK-LABEL: @uniformGlobal +func.func private @uniformGlobal() -> !stream.timepoint { + %imm = stream.timepoint.immediate => !stream.timepoint + util.global.store %imm, @global : !stream.timepoint + // CHECK: util.global.load + %global = util.global.load @global : !stream.timepoint + // CHECK: %[[IMM:.+]] = stream.timepoint.immediate + // CHECK: return %[[IMM]] + return %global : !stream.timepoint +} +func.func private @globalSetter() { + %imm = stream.timepoint.immediate => !stream.timepoint + util.global.store %imm, @global : !stream.timepoint + return +} + +// ----- + +// Tests that divergent global stores do not propagate. + +util.global private mutable @global = #stream.timepoint : !stream.timepoint + +// CHECK-LABEL: @nonUniformGlobal +func.func private @nonUniformGlobal() -> !stream.timepoint { + // CHECK-NOT: stream.timepoint.immediate + // CHECK: %[[GLOBAL:.+]] = util.global.load @global + %global = util.global.load @global : !stream.timepoint + // CHECK: return %[[GLOBAL]] + return %global : !stream.timepoint +} +func.func @globalSetter(%arg0: !stream.timepoint) { + util.global.store %arg0, @global : !stream.timepoint + return +} + +// ----- + +// Tests that meaningful timeline ops are never marked immediate. + +// CHECK-LABEL: @nonImmediate +func.func private @nonImmediate() -> !stream.timepoint { + %imm = stream.timepoint.immediate => !stream.timepoint + // CHECK: %[[EXECUTE:.+]] = stream.cmd.execute + %0 = stream.cmd.execute await(%imm) => with() {} => !stream.timepoint + // CHECK: return %[[EXECUTE]] + return %0 : !stream.timepoint +} + +// ----- + +// Tests that an immediate timepoint passed along a call edge is propagated. + +// CHECK-LABEL: func @caller +func.func @caller() -> !stream.timepoint { + // CHECK: %[[T0_IMM:.+]] = stream.timepoint.immediate + %t0 = stream.timepoint.immediate => !stream.timepoint + // CHECK: %[[T1:.+]] = call @callee(%[[T0_IMM]], %[[T0_IMM]]) + // CHECK-NEXT: %[[T1_IMM:.+]] = stream.timepoint.immediate + %t1 = call @callee(%t0, %t0) : (!stream.timepoint, !stream.timepoint) -> !stream.timepoint + // CHECK: %[[T2:.+]] = stream.timepoint.join max(%[[T0_IMM]], %[[T1_IMM]]) + // CHECK-NEXT: %[[T2_IMM:.+]] = stream.timepoint.immediate + %t2 = stream.timepoint.join max(%t0, %t1) => !stream.timepoint + // CHECK: return %[[T2_IMM]] + return %t2 : !stream.timepoint +} +// CHECK-LABEL: func private @callee +func.func private @callee(%t0a: !stream.timepoint, %t0b: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[T0A_IMM:.+]] = stream.timepoint.immediate + // CHECK: %[[T0B_IMM:.+]] = stream.timepoint.immediate + // CHECK-NEXT: %[[T1:.+]] = stream.timepoint.join max(%[[T0A_IMM]], %[[T0B_IMM]]) + %t1 = stream.timepoint.join max(%t0a, %t0b) => !stream.timepoint + // CHECK-NEXT: %[[T1_IMM:.+]] = stream.timepoint.immediate + // CHECK-NEXT: return %[[T1_IMM]] + return %t1 : !stream.timepoint +} + +// ----- + +// Tests that duplicate call args/results are handled correctly. 
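+// Both callee results forward the same immediate argument, so each call
+// result (and the join of the two) folds to an immediate.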
+ +// CHECK-LABEL: func @callerDupes +func.func @callerDupes() -> !stream.timepoint { + // CHECK: %[[IMM:.+]] = stream.timepoint.immediate + %imm = stream.timepoint.immediate => !stream.timepoint + // CHECK: %[[CALL:.+]]:2 = call @calleeDupes + // CHECK-NEXT: %[[CALL_IMM0:.+]] = stream.timepoint.immediate + // CHECK-NEXT: %[[CALL_IMM1:.+]] = stream.timepoint.immediate + %call:2 = call @calleeDupes(%imm, %imm) : (!stream.timepoint, !stream.timepoint) -> (!stream.timepoint, !stream.timepoint) + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[CALL_IMM0]], %[[CALL_IMM1]]) + // CHECK-NEXT: %[[JOIN_IMM:.+]] = stream.timepoint.immediate + %join = stream.timepoint.join max(%call#0, %call#1) => !stream.timepoint + // CHECK: return %[[JOIN_IMM]] + return %join : !stream.timepoint +} +func.func private @calleeDupes(%arg0: !stream.timepoint, %arg1: !stream.timepoint) -> (!stream.timepoint, !stream.timepoint) { + return %arg0, %arg1 : !stream.timepoint, !stream.timepoint +} + +// ----- + +// Tests that convergent caller timepoints are handled correctly. + +// CHECK-LABEL: func @uniformCaller +func.func @uniformCaller() -> !stream.timepoint { + // CHECK: %[[IMM:.+]] = stream.timepoint.immediate + %imm = stream.timepoint.immediate => !stream.timepoint + // CHECK-NEXT: call @uniformCallee(%[[IMM]]) + // CHECK-NEXT: %[[CALL_IMM0:.+]] = stream.timepoint.immediate + %call0 = call @uniformCallee(%imm) : (!stream.timepoint) -> !stream.timepoint + // CHECK-NEXT: call @uniformCallee(%[[IMM]]) + // CHECK-NEXT: %[[CALL_IMM1:.+]] = stream.timepoint.immediate + %call1 = call @uniformCallee(%imm) : (!stream.timepoint) -> !stream.timepoint + // CHECK-NEXT: %[[CALLER_JOIN:.+]] = stream.timepoint.join max(%[[CALL_IMM0]], %[[CALL_IMM1]]) + // CHECK-NEXT: %[[CALLER_JOIN_IMM:.+]] = stream.timepoint.immediate + %join = stream.timepoint.join max(%call0, %call1) => !stream.timepoint + // CHECK: return %[[CALLER_JOIN_IMM]] + return %join : !stream.timepoint +} +// CHECK: func private @uniformCallee +func.func private @uniformCallee(%arg0: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[ARG0_IMM:.+]] = stream.timepoint.immediate + // CHECK: %[[CALLEE_JOIN:.+]] = stream.timepoint.join max(%[[ARG0_IMM]]) + // CHECK-NEXT: %[[CALLEE_JOIN_IMM:.+]] = stream.timepoint.immediate + %0 = stream.timepoint.join max(%arg0) => !stream.timepoint + // CHECK: return %[[CALLEE_JOIN_IMM]] + return %0 : !stream.timepoint +} + +// ----- + +// Tests that divergent caller timepoints are handled correctly. +// NOTE: if we ever implemented execution tracing we could discover that %call1 +// should be immediate - today, though, we aggregate over callers and any one +// that may pass a non-immediate poisons the analysis. 
+ +// CHECK-LABEL: func @nonUniformCaller +// CHECK-SAME: (%[[UNKNOWN:.+]]: !stream.timepoint) +func.func @nonUniformCaller(%unknown: !stream.timepoint) -> !stream.timepoint { + // CHECK-NOT: stream.timepoint.immediate + // CHECK: %[[CALL0:.+]] = call @nonUniformCallee(%[[UNKNOWN]]) + %call0 = call @nonUniformCallee(%unknown) : (!stream.timepoint) -> !stream.timepoint + // CHECK: %[[IMM:.+]] = stream.timepoint.immediate + %imm = stream.timepoint.immediate => !stream.timepoint + // CHECK: %[[CALL1:.+]] = call @nonUniformCallee(%[[IMM]]) + %call1 = call @nonUniformCallee(%imm) : (!stream.timepoint) -> !stream.timepoint + // CHECK: %[[CALLER_JOIN:.+]] = stream.timepoint.join max(%[[CALL0]], %[[CALL1]]) + %join = stream.timepoint.join max(%call0, %call1) => !stream.timepoint + // CHECK: return %[[CALLER_JOIN]] + return %join : !stream.timepoint +} +// CHECK-LABEL: func private @nonUniformCallee +// CHECK-SAME: (%[[CALLEE_ARG:.+]]: !stream.timepoint) +func.func private @nonUniformCallee(%arg0: !stream.timepoint) -> !stream.timepoint { + // CHECK-NOT: stream.timepoint.immediate + // CHECK: %[[CALLEE_JOIN:.+]] = stream.timepoint.join max(%[[CALLEE_ARG]]) + %0 = stream.timepoint.join max(%arg0) => !stream.timepoint + // CHECK: return %[[CALLEE_JOIN]] + return %0 : !stream.timepoint +} + +// ----- + +// Tests that an immediate timepoint passed along a block edge is propagated. + +// CHECK-LABEL: func @branch +func.func @branch() -> !stream.timepoint { + %t0 = stream.timepoint.immediate => !stream.timepoint + // CHECK: cf.br ^bb1 + cf.br ^bb1(%t0 : !stream.timepoint) +// CHECK-NEXT: ^bb1(%[[BB1_T0:.+]]: !stream.timepoint) +^bb1(%bb1_t0: !stream.timepoint): + // CHECK-NEXT: %[[BB1_T0_IMMEDIATE:.+]] = stream.timepoint.immediate + // CHECK-NEXT: %[[T1:.+]] = stream.timepoint.join max(%[[BB1_T0_IMMEDIATE]]) + %t1 = stream.timepoint.join max(%bb1_t0) => !stream.timepoint + // CHECK-NEXT: %[[JOIN_IMMEDIATE:.+]] = stream.timepoint.immediate + // CHECK-NEXT: return %[[JOIN_IMMEDIATE]] + return %t1 : !stream.timepoint +} + +// ----- + +// Tests that forward edges with convergently immediate timepoints get +// propagated. + +// CHECK-LABEL: func @branchConvergentForwardEdge +func.func @branchConvergentForwardEdge(%cond: i1) -> !stream.timepoint { + // CHECK: %[[IMM0:.+]] = stream.timepoint.immediate + %imm0 = stream.timepoint.immediate => !stream.timepoint + // CHECK: %[[IMM1:.+]] = stream.timepoint.immediate + %imm1 = stream.timepoint.immediate => !stream.timepoint + // CHECK-NEXT: cf.cond_br + // CHECK-SAME: ^bb1(%[[IMM0]] : !stream.timepoint), + // CHECK-SAME: ^bb1(%[[IMM1]] : !stream.timepoint) + cf.cond_br %cond, ^bb1(%imm0 : !stream.timepoint), ^bb1(%imm1 : !stream.timepoint) +// CHECK-NEXT: ^bb1(%[[BB1_ARG:.+]]: !stream.timepoint) +^bb1(%bb1_arg: !stream.timepoint): + // CHECK: %[[BB1_IMM:.+]] = stream.timepoint.immediate + // CHECK: return %[[BB1_IMM]] + return %bb1_arg : !stream.timepoint +} + +// ----- + +// Tests that forward edges with divergent timepoints don't get propagated. 
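+// ^bb1 can be reached with either %unknown or %imm, so the block argument is
+// not provably immediate and is left untouched.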
+ +// CHECK-LABEL: func @branchDivergentForwardEdge +// CHECK-SAME: (%[[COND:.+]]: i1, %[[UNKNOWN:.+]]: !stream.timepoint) +func.func @branchDivergentForwardEdge(%cond: i1, %unknown: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[IMM:.+]] = stream.timepoint.immediate + %imm = stream.timepoint.immediate => !stream.timepoint + // CHECK-NEXT: cf.cond_br %[[COND]] + // CHECK-SAME: ^bb1(%[[UNKNOWN]] : !stream.timepoint), + // CHECK-SAME: ^bb1(%[[IMM]] : !stream.timepoint) + cf.cond_br %cond, ^bb1(%unknown : !stream.timepoint), ^bb1(%imm : !stream.timepoint) +// CHECK-NEXT: ^bb1(%[[BB1_ARG:.+]]: !stream.timepoint) +^bb1(%bb1_arg: !stream.timepoint): + // CHECK: return %[[BB1_ARG]] + return %bb1_arg : !stream.timepoint +} + +// ----- + +// Tests that back edges with divergent timepoints don't get propagated. + +// CHECK-LABEL: func @branchDivergentBackEdge +func.func @branchDivergentBackEdge(%cond: i1) -> !stream.timepoint { + %t0 = stream.timepoint.immediate => !stream.timepoint + // CHECK: cf.br ^bb1 + cf.br ^bb1(%cond, %t0 : i1, !stream.timepoint) +// CHECK-NEXT: ^bb1(%[[BB1_COND:.+]]: i1, %[[BB1_T0:.+]]: !stream.timepoint) +^bb1(%bb1_cond: i1, %bb1_t0: !stream.timepoint): + // CHECK-NOT: stream.timepoint.immediate + // CHECK-NEXT: %[[BB1_T1:.+]] = stream.cmd.execute await(%[[BB1_T0]]) + %bb1_t1 = stream.cmd.execute await(%bb1_t0) => with() {} => !stream.timepoint + // CHECK: %[[FALSE:.+]] = arith.constant false + %cond_false = arith.constant false + // CHECK-NEXT: cf.cond_br + // CHECK-SAME: ^bb1(%[[FALSE]], %[[BB1_T1]] : i1, !stream.timepoint) + // CHECK-SAME: ^bb2(%[[BB1_T1]] : !stream.timepoint) + cf.cond_br %bb1_cond, ^bb1(%cond_false, %bb1_t1 : i1, !stream.timepoint), ^bb2(%bb1_t1 : !stream.timepoint) +// CHECK-NEXT: ^bb2(%[[BB2_T1:.+]]: !stream.timepoint) +^bb2(%bb2_t1: !stream.timepoint): + // CHECK-NEXT: return %[[BB2_T1]] + return %bb2_t1 : !stream.timepoint +} + +// ----- + +// Tests that scf.if regions with convergent yields are handled. + +// CHECK-LABEL: func @scfIfConvergent +// CHECK-SAME: (%[[COND:.+]]: i1) +func.func @scfIfConvergent(%cond: i1) -> !stream.timepoint { + // CHECK: %[[IF:.+]] = scf.if + %if = scf.if %cond -> !stream.timepoint { + // CHECK: %[[IMM0:.+]] = stream.timepoint.immediate + %imm0 = stream.timepoint.immediate => !stream.timepoint + // CHECK: yield %[[IMM0]] + scf.yield %imm0 : !stream.timepoint + } else { + // CHECK: %[[IMM1:.+]] = stream.timepoint.immediate + %imm1 = stream.timepoint.immediate => !stream.timepoint + // CHECK: yield %[[IMM1]] + scf.yield %imm1 : !stream.timepoint + } + // CHECK: %[[IF_IMM:.+]] = stream.timepoint.immediate + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[IF_IMM]]) + // CHECK-NEXT: %[[JOIN_IMM:.+]] = stream.timepoint.immediate + %join = stream.timepoint.join max(%if) => !stream.timepoint + // CHECK: return %[[JOIN_IMM]] + return %join : !stream.timepoint +} + +// ----- + +// Tests that scf.if regions with divergent yields are handled. 
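+// Only the then branch yields an immediate; the else branch depends on
+// %unknown, so the scf.if result and the outer join are preserved.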
+ +// CHECK-LABEL: func @scfIfDivergent +// CHECK-SAME: (%[[COND:.+]]: i1, %[[UNKNOWN:.+]]: !stream.timepoint) +func.func @scfIfDivergent(%cond: i1, %unknown: !stream.timepoint) -> !stream.timepoint { + // CHECK: %[[IMM:.+]] = stream.timepoint.immediate + %imm = stream.timepoint.immediate => !stream.timepoint + // CHECK: %[[IF:.+]] = scf.if + %0 = scf.if %cond -> !stream.timepoint { + // CHECK: yield %[[IMM]] + scf.yield %imm : !stream.timepoint + } else { + // CHECK: %[[JOIN1:.+]] = stream.timepoint.join max(%[[UNKNOWN]], %[[IMM]]) + %join1 = stream.timepoint.join max(%unknown, %imm) => !stream.timepoint + // CHECK: yield %[[JOIN1]] + scf.yield %join1 : !stream.timepoint + } + // CHECK-NOT: stream.timepoint.immediate + // CHECK: %[[JOIN_OUTER:.+]] = stream.timepoint.join max(%[[UNKNOWN]], %[[IF]]) + %join_outer = stream.timepoint.join max(%unknown, %0) => !stream.timepoint + // CHECK: return %[[JOIN_OUTER]] + return %join_outer : !stream.timepoint +} + +// TODO(benvanik): support scf.for diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/PotentialValues.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/PotentialValues.cpp index c7543449459d..dd1099d638b1 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/PotentialValues.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/PotentialValues.cpp @@ -57,7 +57,7 @@ ChangeStatus ConstantAttributePVS::updateValue(Value value, return DFX::clampStateAndIndicateChange(getState(), newState); } -const std::string ConstantAttributePVS::getAsStr() const { +const std::string ConstantAttributePVS::getAsStr(AsmState &asmState) const { std::string str; llvm::raw_string_ostream sstream(str); sstream << "pvs: "; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/PotentialValues.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/PotentialValues.h index 1a4f760117e9..e87d27ae785e 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/PotentialValues.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/PotentialValues.h @@ -44,7 +44,7 @@ class ConstantAttributePVS } static const char ID; - const std::string getAsStr() const override; + const std::string getAsStr(AsmState &asmState) const override; private: void initializeValue(Value value, DFX::Solver &solver) override; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp index 5ed240b7a86c..de65ac194a3d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp @@ -45,7 +45,7 @@ void FloatRangeStats::addDomainValue(double value) { } } -std::string FloatRangeStats::getAsStr() const { +std::string FloatRangeStats::getAsStr(AsmState &asmState) const { if (!valid) return std::string("<>"); std::string s("["); s += std::to_string(minValue); @@ -157,8 +157,8 @@ ChangeStatus FloatRangeValueElement::updateValue(Value value, stats.addDomainValue(elementValue.convertToDouble()); } newState.setAssumed(stats); - LLVM_DEBUG(dbgs() << "*** COMPUTED KNOWN RANGE: " << stats.getAsStr() - << "\n"); + LLVM_DEBUG(dbgs() << "*** COMPUTED KNOWN RANGE: " + << stats.getAsStr(solver.getAsmState()) << "\n"); newState.indicateOptimisticFixpoint(); } else { // Unknown. 
@@ -206,9 +206,12 @@ ChangeStatus FloatRangeValueElement::updateValue(Value value, DFX::Resolution::REQUIRED); newState.applyMinf(lhs.getAssumed(), rhs.getAssumed()); - LLVM_DEBUG(dbgs() << "VISITING minf: lhs = " << lhs.getAsStr() - << ", rhs = " << rhs.getAsStr() << " -> " - << newState.getAssumed().getAsStr() << "\n"); + LLVM_DEBUG(dbgs() + << "VISITING minf: lhs = " + << lhs.getAsStr(solver.getAsmState()) << ", rhs = " + << rhs.getAsStr(solver.getAsmState()) << " -> " + << newState.getAssumed().getAsStr(solver.getAsmState()) + << "\n"); return WalkResult::advance(); } else if (auto maxfOp = dyn_cast(definingOp)) { auto lhs = solver.getElementFor( @@ -219,9 +222,12 @@ ChangeStatus FloatRangeValueElement::updateValue(Value value, DFX::Resolution::REQUIRED); newState.applyMaxf(lhs.getAssumed(), rhs.getAssumed()); - LLVM_DEBUG(dbgs() << "VISITING maxf: lhs = " << lhs.getAsStr() - << ", rhs = " << rhs.getAsStr() << " -> " - << newState.getAssumed().getAsStr() << "\n"); + LLVM_DEBUG(dbgs() + << "VISITING maxf: lhs = " + << lhs.getAsStr(solver.getAsmState()) << ", rhs = " + << rhs.getAsStr(solver.getAsmState()) << " -> " + << newState.getAssumed().getAsStr(solver.getAsmState()) + << "\n"); return WalkResult::advance(); } else if (auto floorOp = dyn_cast(definingOp)) { auto operand = solver.getElementFor( @@ -229,8 +235,10 @@ ChangeStatus FloatRangeValueElement::updateValue(Value value, DFX::Resolution::REQUIRED); newState.applyFloor(operand.getAssumed()); LLVM_DEBUG(dbgs() - << "VISITING floor: " << operand.getAsStr() << " -> " - << newState.getAssumed().getAsStr() << "\n"); + << "VISITING floor: " + << operand.getAsStr(solver.getAsmState()) << " -> " + << newState.getAssumed().getAsStr(solver.getAsmState()) + << "\n"); return WalkResult::advance(); } @@ -247,10 +255,10 @@ ChangeStatus FloatRangeValueElement::updateValue(Value value, return DFX::clampStateAndIndicateChange(getState(), newState); } -const std::string FloatRangeValueElement::getAsStr() const { +const std::string FloatRangeValueElement::getAsStr(AsmState &asmState) const { auto range = getAssumed(); std::string s("fp-range: "); - s += range.getAsStr(); + s += range.getAsStr(asmState); return s; } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.h index d8c640e108be..bd32c3519db7 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.h @@ -94,7 +94,7 @@ struct FloatRangeStats { } } - std::string getAsStr() const; + std::string getAsStr(AsmState &asmState) const; }; // State that tracks floating point ranges and flags. 
@@ -165,7 +165,7 @@ class FloatRangeValueElement static bool classof(const DFX::AbstractElement *element) { return (element->getID() == &ID); } - const std::string getAsStr() const override; + const std::string getAsStr(AsmState &asmState) const override; private: void initializeValue(Value value, DFX::Solver &solver) override; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp index f5d7e13a0e31..13272c226978 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp @@ -57,7 +57,7 @@ void AbstractElement::print(llvm::raw_ostream &os, AsmState &asmState) const { os << "<>"; } - os << " with state " << getAsStr(); + os << " with state " << getAsStr(asmState); } void AbstractElement::printWithDeps(llvm::raw_ostream &os, diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h index 0782d278d7d1..281b1c0acc48 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h @@ -66,7 +66,7 @@ class AbstractElement : public Position, public DepGraphNode { // Returns the human-friendly summarized assumed state as string for // debugging. - virtual const std::string getAsStr() const = 0; + virtual const std::string getAsStr(AsmState &asmState) const = 0; void print(llvm::raw_ostream &os, AsmState &asmState) const override; virtual void printWithDeps(llvm::raw_ostream &os, AsmState &asmState) const; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h index d92bc604f758..6e1f8ad836cd 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h @@ -451,6 +451,9 @@ struct PotentialValuesState : AbstractState { // Unions assumed set with an undef value. void unionAssumedWithUndef() { unionWithUndef(); } + // Intersects assumed set with assumed set of the passed state |rhs|. + void intersectAssumed(const PotentialValuesState &rhs) { intersectWith(rhs); } + // "Clamps" this state with |rhs|. PotentialValuesState operator^=(const PotentialValuesState &rhs) { validState ^= rhs.validState; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp index 818a479c0e3e..6c02429d8672 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp @@ -389,17 +389,21 @@ TraversalResult Explorer::walkIncomingCalls( TraversalResult Explorer::walkReturnOps(Operation *parentOp, OperationWalkFn fn) { LLVM_DEBUG(llvm::dbgs() << "[[ Explorer::walkReturnOps ]]\n"); + if (getTraversalAction(parentOp) != TraversalAction::RECURSE) { + LLVM_DEBUG(llvm::dbgs() << " -- ignoring region op " + << parentOp->getName().getStringRef() << "\n"); + return TraversalResult::COMPLETE; + } TraversalResult result = TraversalResult::COMPLETE; if (auto regionOp = dyn_cast(parentOp)) { auto enumerateTerminatorOps = [&](Region ®ion) { for (auto &block : region) { - for (auto terminatorOp : - block.getOps()) { + if (auto *terminatorOp = block.getTerminator()) { // TODO(benvanik): ensure this terminator can return to parent? this // region op interface confuses me. 
LLVM_DEBUG({ llvm::dbgs() << " == emitting region branch terminator op "; - terminatorOp.print(llvm::dbgs(), asmState); + terminatorOp->print(llvm::dbgs(), asmState); llvm::dbgs() << "\n"; }); return fn(terminatorOp); @@ -507,6 +511,16 @@ TraversalResult Explorer::walkIncomingBranchOperands( return result; } +TraversalResult Explorer::walkIncomingBlockArgument( + BlockArgument blockArg, + std::function fn) { + return walkIncomingBranchOperands( + blockArg.getParentBlock(), + [&](Block *sourceBlock, OperandRange operands) { + return fn(sourceBlock, operands[blockArg.getArgNumber()]); + }); +} + TraversalResult Explorer::walkOutgoingBranchArguments( Block *sourceBlock, std::function @@ -573,10 +587,6 @@ TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn) { auto *targetBlock = arg.getParentBlock(); return walkIncomingBranchOperands( targetBlock, [&](Block *sourceBlock, OperandRange operands) { - if (sourceBlock == targetBlock) { - // Recursion; ignore (?). - return WalkResult::advance(); - } auto branchOperand = operands[arg.getArgNumber()]; LLVM_DEBUG({ llvm::dbgs() << " + queuing "; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h index faffb27080d2..c58beeb4a3b2 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h @@ -133,6 +133,26 @@ class Explorer { bool isIndirect = false; // All loads and stores of the global across the program. SmallVector uses; + + // Returns a range of all direct loads of the global. + auto getLoads() const { + assert(!isIndirect && "indirect loads not yet tracked"); + return llvm::map_range( + llvm::make_filter_range( + uses, + [](Operation *op) { return isa(op); }), + [](Operation *op) { return cast(op); }); + } + + // Returns a range of all direct stores to the global. + auto getStores() const { + assert(!isIndirect && "indirect stores not yet tracked"); + return llvm::map_range( + llvm::make_filter_range( + uses, + [](Operation *op) { return isa(op); }), + [](Operation *op) { return cast(op); }); + } }; // Gets analyzed global information for the given global operation. @@ -208,6 +228,11 @@ class Explorer { Block *targetBlock, std::function fn); + // Walks all predecessor blocks providing values for |blockArg|. + TraversalResult walkIncomingBlockArgument( + BlockArgument blockArg, + std::function fn); + // Walks all successor blocks of |sourceBlock| and provides their arguments. // Note that |sourceBlock| may be enumerated if there is recursion. 
TraversalResult walkOutgoingBranchArguments( diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp index 76274f1e4051..b12c4a8e791f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp @@ -75,6 +75,25 @@ void populateUtilConversionPatterns(MLIRContext *context, namespace { +struct ConvertInitializerOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::Util::InitializerOp initializerOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto &typeConverter = *getTypeConverter(); + rewriter.startRootUpdate(initializerOp); + if (failed(rewriter.convertRegionTypes(&initializerOp.getBody(), + typeConverter))) { + rewriter.cancelRootUpdate(initializerOp); + return rewriter.notifyMatchFailure(initializerOp, + "failed to convert region types"); + } + rewriter.finalizeRootUpdate(initializerOp); + return success(); + } +}; + struct ConvertFuncOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( @@ -226,6 +245,10 @@ void populateGenericStructuralConversionPatterns( // We need to rewrite certain types on operands/results so use the default // dynamic legality checker to force any ops using such types to run through // our patterns. + conversionTarget.addDynamicallyLegalOp( + [&](IREE::Util::InitializerOp op) { + return typeConverter.isLegal(&op.getBody()); + }); conversionTarget.addDynamicallyLegalOp( [&](mlir::func::FuncOp op) { return typeConverter.isSignatureLegal(op.getFunctionType()) && @@ -238,9 +261,10 @@ void populateGenericStructuralConversionPatterns( addGenericLegalOp(conversionTarget, typeConverter); addGenericLegalOp(conversionTarget, typeConverter); addGenericLegalOp(conversionTarget, typeConverter); - patterns.insert(typeConverter, context); + patterns.insert(typeConverter, + context); } } // namespace iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/structural_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/structural_ops.mlir index 23a593397e60..5afcba5b1ec2 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/structural_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/structural_ops.mlir @@ -3,6 +3,20 @@ // These patterns are not doing anything dialect-specific and instead just // allowing for the ops to update their types during dialect conversions. 
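// The initializer test added below exercises the new ConvertInitializerOp
// pattern: util.initializer has no function signature, so only its region is
// converted and the block argument plus branch operand change from memref to
// !util.buffer under this test's type converter.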
+// CHECK: util.initializer +util.initializer { + // CHECK: %[[VALUE:.+]] = func.call @extern + %value = func.call @extern() : () -> memref + // CHECK: cf.br ^bb1(%[[VALUE]] : !util.buffer) + cf.br ^bb1(%value : memref) +// CHECK: ^bb1(%[[ARG:.+]]: !util.buffer) +^bb1(%block_arg: memref): + util.initializer.return +} +func.func private @extern() -> memref + +// ----- + // CHECK-LABEL: @funcOp // CHECK-SAME: (%[[ARG0:.+]]: !util.buffer) -> !util.buffer func.func @funcOp(%arg0: memref) -> memref { diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp index 77fa2d7b2fef..5857bcaf5f91 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp @@ -906,6 +906,14 @@ LogicalResult GlobalLoadIndirectOp::verify() { return success(); } +void GlobalStoreOp::build(OpBuilder &builder, OperationState &state, + Value value, GlobalOp globalOp, + ArrayRef attrs) { + state.addOperands({value}); + state.addAttribute("global", SymbolRefAttr::get(globalOp)); + state.attributes.append(attrs.begin(), attrs.end()); +} + IREE::Util::GlobalOp GlobalStoreOp::getGlobalOp( SymbolTableCollection &symbolTable) { return symbolTable.lookupNearestSymbolFrom( diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td index d5c3e6dddfe7..13fcefa12b9d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td @@ -592,6 +592,14 @@ def Util_GlobalStoreOp : Util_Op<"global.store", [ $value `,` $global attr-dict `:` type($value) }]; + let builders = [ + OpBuilder<(ins + "Value":$value, + "IREE::Util::GlobalOp":$globalOp, + CArg<"ArrayRef", "{}">:$attributes + )>, + ]; + let extraClassDeclaration = [{ IREE::Util::GlobalOp getGlobalOp(SymbolTableCollection &symbolTable); }]; diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp index b7c241079af4..fbe91777d7cb 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp @@ -127,7 +127,10 @@ llvm::Optional detail::getTiedResultOperandIndex( auto valueAttrs = storageAttr.getValue(); if (valueAttrs.empty()) return llvm::None; auto tiedOp = cast(op); - resultIndex -= tiedOp.getTiedResultsIndexAndLength().first; + auto indexAndLength = tiedOp.getTiedResultsIndexAndLength(); + if (resultIndex < indexAndLength.first) return None; + resultIndex -= indexAndLength.first; + if (resultIndex >= indexAndLength.second) return None; int64_t value = valueAttrs[resultIndex].cast().getInt(); if (value == TiedOpInterface::kUntiedIndex) return llvm::None; unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubrange.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubrange.cpp index 979b3e816739..0868347e4c6d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubrange.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubrange.cpp @@ -206,14 +206,16 @@ static void expandSubranges(Operation *op, ExpandedGlobalMap &globalMap, // Recursively expands resources into (resource, size, offset, length) tuples // within the given |region|. All branches, ops, and nested regions will be // processed. 
-static void expandRegion(Region ®ion, ExpandedGlobalMap &globalMap, - IndexSet &indexSet, SubrangeMap subrangeMap) { +static void expandRegion(Region ®ion, bool canModifyEntryBlock, + ExpandedGlobalMap &globalMap, IndexSet &indexSet, + SubrangeMap subrangeMap) { if (region.empty()) return; // Update all block arguments. auto indexType = IndexType::get(region.getContext()); for (auto &block : region.getBlocks()) { if (!llvm::any_of(block.getArgumentTypes(), isResourceType)) continue; + if (block.isEntryBlock() && !canModifyEntryBlock) continue; // Insert and build a list of expanded (resource, size, offset) tuples. SmallVector expansions; @@ -263,10 +265,11 @@ static void expandRegion(Region ®ion, ExpandedGlobalMap &globalMap, } // Recursively expands all regions on the op. -static void expandRegions(Operation *op, ExpandedGlobalMap &globalMap, - IndexSet &indexSet, SubrangeMap subrangeMap) { +static void expandRegions(Operation *op, bool canModifyEntryBlock, + ExpandedGlobalMap &globalMap, IndexSet &indexSet, + SubrangeMap subrangeMap) { for (auto ®ion : op->getRegions()) { - expandRegion(region, globalMap, indexSet, subrangeMap); + expandRegion(region, canModifyEntryBlock, globalMap, indexSet, subrangeMap); } } @@ -381,7 +384,8 @@ static void expandGlobalStoreOp(IREE::Util::GlobalStoreOp op, static void expandInitializerOp(IREE::Util::InitializerOp op, ExpandedGlobalMap &globalMap, IndexSet &indexSet, SubrangeMap &subrangeMap) { - expandRegion(op.getRegion(), globalMap, indexSet, subrangeMap); + expandRegion(op.getRegion(), /*canModifyEntryBlock=*/false, globalMap, + indexSet, subrangeMap); } // Returns true if |op| is either public and visible to external modules or @@ -410,7 +414,8 @@ static bool isPublicOrExternal(CallableOpInterface callableOp) { static void expandFuncOp(mlir::func::FuncOp op, ExpandedGlobalMap &globalMap, IndexSet &indexSet, SubrangeMap &subrangeMap) { // Ignore public/external function signatures but still convert regions. - if (!isPublicOrExternal(op)) { + bool canModifyEntryBlock = !isPublicOrExternal(op); + if (canModifyEntryBlock) { auto oldType = op.getFunctionType(); auto inputTypes = expandTypes(oldType.getInputs()); auto resultTypes = expandTypes(oldType.getResults()); @@ -419,7 +424,8 @@ static void expandFuncOp(mlir::func::FuncOp op, ExpandedGlobalMap &globalMap, op.setType(newType); } } - expandRegion(op.getRegion(), globalMap, indexSet, subrangeMap); + expandRegion(op.getRegion(), canModifyEntryBlock, globalMap, indexSet, + subrangeMap); } // Splits resource operands and results into (resource, resourceSize, @@ -565,11 +571,14 @@ static void expandSubranges(Operation *op, ExpandedGlobalMap &globalMap, // We could add an interface to ops we want to do this to, though, to at least // allow dialects to plug in. For now we just need SCF so this is hardcoded. if (auto ifOp = dyn_cast(op)) { - return expandRegions(ifOp, globalMap, indexSet, subrangeMap); + return expandRegions(ifOp, /*canModifyEntryBlock=*/false, globalMap, + indexSet, subrangeMap); } else if (auto forOp = dyn_cast(op)) { - return expandRegions(forOp, globalMap, indexSet, subrangeMap); + return expandRegions(forOp, /*canModifyEntryBlock=*/false, globalMap, + indexSet, subrangeMap); } else if (auto whileOp = dyn_cast(op)) { - return expandRegions(whileOp, globalMap, indexSet, subrangeMap); + return expandRegions(whileOp, /*canModifyEntryBlock=*/false, globalMap, + indexSet, subrangeMap); } // TODO(benvanik): also handle scf.yield: today we don't propagate across // return values. 
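The canModifyEntryBlock plumbing above boils down to one question: may this region's entry block (and therefore the enclosing function signature) change type without breaking an observer outside the pass? A rough sketch of that decision for func.func follows, assuming MLIR's SymbolOpInterface and FunctionOpInterface accessors; the pass's actual isPublicOrExternal helper also inspects symbol uses, so treat this as an illustration rather than the implementation:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/SymbolTable.h"

// Conservative approximation: only private functions with bodies may have
// their entry block arguments (and thus their signature) expanded in place.
static bool canExpandEntryBlock(mlir::func::FuncOp funcOp) {
  // Public symbols are part of the module ABI and may be called externally.
  auto symbolOp = llvm::cast<mlir::SymbolOpInterface>(funcOp.getOperation());
  if (symbolOp.isPublic()) return false;
  // External declarations have no body to keep in sync with a new signature.
  if (funcOp.isExternal()) return false;
  return true;
}

The scf.if/scf.for/scf.while regions are passed /*canModifyEntryBlock=*/false unconditionally for a similar reason: their entry block arguments are dictated by the semantics of the enclosing op rather than by this pass.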
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp index b5bfa98fd6c8..b2f610614ece 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp @@ -79,7 +79,8 @@ static void hoistImmutableLoads(Region ®ion, static bool doesOpBlockMotion(Operation *op) { return isa(op) || - op->hasTrait(); + op->hasTrait() || + op->hasTrait(); } static void moveOpUpInBlock(Block &block, Operation *op) { @@ -90,7 +91,7 @@ static void moveOpUpInBlock(Block &block, Operation *op) { } static void moveOpDownInBlock(Block &block, Operation *op) { - while (op->getNextNode() != block.getTerminator()) { + while (op->getNextNode()) { if (doesOpBlockMotion(op->getNextNode())) break; op->moveAfter(op->getNextNode()); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp index 297c28f423b6..f9a00094c1a4 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestFloatRangeAnalysis.cpp @@ -58,8 +58,10 @@ class TestFloatRangeAnalysisPass // Update. for (auto &it : queryOps) { - it.first->setAttr("analysis", - StringAttr::get(&getContext(), it.second->getAsStr())); + it.first->setAttr( + "analysis", + StringAttr::get(&getContext(), + it.second->getAsStr(solver.getAsmState()))); } } }; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/propagate_subranges.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/propagate_subranges.mlir index 052c7b0a4b41..c6be73b3cc2c 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/propagate_subranges.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/propagate_subranges.mlir @@ -85,6 +85,18 @@ func.func private @funcResults(%resource0: !util.buffer, %resource1: !util.buffe return %resource0, %resource1 : !util.buffer, !util.buffer } + +// ----- + +// Tests that exported functions don't have their signature changed. + +// CHECK-LABEL: @publicFuncSignature +// CHECK-SAME: (%[[RESOURCE:.+]]: !util.buffer) -> !util.buffer +func.func @publicFuncSignature(%resource: !util.buffer) -> !util.buffer { + // CHECK-NEXT: return %[[RESOURCE]] : !util.buffer + return %resource : !util.buffer +} + // ----- // Tests that function calls have their args and results expanded into diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h index cbe0eab8749e..3c5eeb1a574f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h @@ -32,7 +32,7 @@ struct EncodedBytecodeFunction { class BytecodeEncoder : public VMFuncEncoder { public: // Matches IREE_VM_BYTECODE_VERSION_MAJOR. - static constexpr uint32_t kVersionMajor = 5; + static constexpr uint32_t kVersionMajor = 6; // Matches IREE_VM_BYTECODE_VERSION_MINOR. 
static constexpr uint32_t kVersionMinor = 0; static constexpr uint32_t kVersion = (kVersionMajor << 16) | kVersionMinor; diff --git a/runtime/src/iree/hal/fence.h b/runtime/src/iree/hal/fence.h index 6cedf95bd649..fdb8b168ce97 100644 --- a/runtime/src/iree/hal/fence.h +++ b/runtime/src/iree/hal/fence.h @@ -16,6 +16,10 @@ extern "C" { #endif // __cplusplus +//===----------------------------------------------------------------------===// +// iree_hal_fence_t +//===----------------------------------------------------------------------===// + // A list of semaphores and their corresponding payloads. // When signaling each semaphore will be set to the new payload value provided. // When waiting each semaphore must reach or exceed the payload value. @@ -124,6 +128,10 @@ IREE_API_EXPORT iree_status_t iree_hal_fence_wait(iree_hal_fence_t* fence, IREE_API_EXPORT iree_wait_source_t iree_hal_fence_await(iree_hal_fence_t* fence); +//===----------------------------------------------------------------------===// +// iree_hal_fence_t implementation details +//===----------------------------------------------------------------------===// + IREE_API_EXPORT void iree_hal_fence_destroy(iree_hal_fence_t* fence); #ifdef __cplusplus diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl index 901e80876516..988841f8e0ee 100644 --- a/runtime/src/iree/modules/hal/exports.inl +++ b/runtime/src/iree/modules/hal/exports.inl @@ -62,9 +62,12 @@ EXPORT_FN("descriptor_set_layout.create", iree_hal_module_descriptor_set_layout_ EXPORT_FN("device.allocator", iree_hal_module_device_allocator, r, r) EXPORT_FN("device.query.i64", iree_hal_module_device_query_i64, rrr, iI) +EXPORT_FN("device.queue.alloca", iree_hal_module_device_queue_alloca, rIrriiiI, r) +EXPORT_FN("device.queue.dealloca", iree_hal_module_device_queue_dealloca, rIrrr, v) +EXPORT_FN("device.queue.execute", iree_hal_module_device_queue_execute, rIrrCrD, v) +EXPORT_FN("device.queue.flush", iree_hal_module_device_queue_flush, rI, v) EXPORT_FN("ex.shared_device", iree_hal_module_ex_shared_device, v, r) -EXPORT_FN("ex.submit_and_wait", iree_hal_module_ex_submit_and_wait, rr, v) EXPORT_FN("executable.create", iree_hal_module_executable_create, rrrrCrD, r) diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index 3701aa843b69..df240efc2429 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c @@ -37,13 +37,27 @@ typedef struct iree_hal_module_t { typedef struct iree_hal_module_state_t { iree_allocator_t host_allocator; + + // Flags controlling HAL module behavior passed in from the hosting + // application. All instantiations of a module share the same flags. iree_hal_module_flags_t flags; + + // HACK: today we only support a single device per context - in the future + // this should be a set of available devices that the module is able to pick + // from - the module will then hang on to them and use them as native globals + // instead of storing anything in module state here. iree_hal_device_t* shared_device; + + // TODO(benvanik): add iree_loop_t to module constructor. + // Status of the nested loop we run for executable creation today. We should + // instead be taking a loop upon creation and scheduling work against that. iree_status_t loop_status; - iree_hal_executable_cache_t* executable_cache; - iree_hal_semaphore_t* submit_semaphore; - uint64_t submit_value; + // Shared executable cache for all executables created in the context. 
+ // We could have multiple to allow for modules to create distinct sets of + // executables like ones for training vs inference in the same model, or just + // always use this. + iree_hal_executable_cache_t* executable_cache; } iree_hal_module_state_t; static void IREE_API_PTR iree_hal_module_destroy(void* base_module) { @@ -73,11 +87,6 @@ iree_hal_module_alloc_state(void* self, iree_allocator_t host_allocator, state->shared_device, iree_string_view_empty(), iree_loop_inline(&state->loop_status), &state->executable_cache)); - state->submit_value = 0ull; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_semaphore_create(state->shared_device, state->submit_value, - &state->submit_semaphore)); - *out_module_state = (iree_vm_module_state_t*)state; IREE_TRACE_ZONE_END(z0); return iree_ok_status(); @@ -88,7 +97,6 @@ iree_hal_module_free_state(void* self, iree_vm_module_state_t* module_state) { IREE_TRACE_ZONE_BEGIN(z0); iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state; - iree_hal_semaphore_release(state->submit_semaphore); iree_hal_executable_cache_release(state->executable_cache); iree_status_ignore(state->loop_status); iree_hal_device_release(state->shared_device); @@ -122,40 +130,6 @@ IREE_VM_ABI_EXPORT(iree_hal_module_ex_shared_device, // return iree_ok_status(); } -IREE_VM_ABI_EXPORT(iree_hal_module_ex_submit_and_wait, // - iree_hal_module_state_t, // - rr, v) { - iree_hal_device_t* device = NULL; - IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); - iree_hal_command_buffer_t* command_buffer = NULL; - IREE_RETURN_IF_ERROR( - iree_hal_command_buffer_check_deref(args->r1, &command_buffer)); - - // Batch with our single command buffer. - iree_hal_submission_batch_t batch; - memset(&batch, 0, sizeof(batch)); - - iree_hal_command_buffer_t* command_buffer_ptrs[] = {command_buffer}; - batch.command_buffer_count = IREE_ARRAYSIZE(command_buffer_ptrs); - batch.command_buffers = command_buffer_ptrs; - - uint64_t next_semaphore_value = ++state->submit_value; - iree_hal_semaphore_t* signal_semaphore_ptrs[] = {state->submit_semaphore}; - uint64_t signal_semaphore_values[] = {next_semaphore_value}; - batch.signal_semaphores.count = IREE_ARRAYSIZE(signal_semaphore_ptrs); - batch.signal_semaphores.semaphores = signal_semaphore_ptrs; - batch.signal_semaphores.payload_values = signal_semaphore_values; - - iree_status_t status = iree_hal_device_queue_submit( - device, IREE_HAL_COMMAND_CATEGORY_ANY, 0, 1, &batch); - if (iree_status_is_ok(status)) { - status = iree_hal_semaphore_wait( - state->submit_semaphore, next_semaphore_value, iree_infinite_timeout()); - } - - return status; -} - //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// @@ -868,6 +842,112 @@ IREE_VM_ABI_EXPORT(iree_hal_module_device_query_i64, // return iree_ok_status(); } +IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_alloca, // + iree_hal_module_state_t, // + rIrriiiI, r) { + iree_hal_device_t* device = NULL; + IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i1; + iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2); + iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3); + uint32_t pool = args->i4; + iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i5; + iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i6; + 
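// (Argument layout follows the "rIrriiiI" calling convention string in
// exports.inl: r0 device ref, i1 64-bit queue affinity, r2/r3 wait and signal
// fence refs, i4 pool, i5 memory type bits, i6 buffer usage bits, and i7
// below is the 64-bit allocation size.)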
iree_device_size_t allocation_size = iree_hal_cast_device_size(args->i7); + + // TODO(benvanik): HAL APIs for queue-ordered allocations. + // For now we just perform a blocking wait to synchronize with the queue, + // allocate the buffer as normal, and then pass it back committed. + (void)queue_affinity; + IREE_RETURN_IF_ERROR( + iree_hal_fence_wait(wait_fence, iree_infinite_timeout())); + + // TODO(benvanik): enforce queue-ordered allocation restrictions on memory + // type and usage. + (void)pool; + + const iree_hal_buffer_params_t params = { + .type = memory_types, + .usage = buffer_usage, + }; + iree_hal_buffer_t* buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( + iree_hal_device_allocator(device), params, allocation_size, + iree_const_byte_span_empty(), &buffer)); + + IREE_RETURN_IF_ERROR(iree_hal_fence_signal(signal_fence)); + + rets->r0 = iree_hal_buffer_move_ref(buffer); + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_dealloca, // + iree_hal_module_state_t, // + rIrrr, v) { + iree_hal_device_t* device = NULL; + IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i1; + iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2); + iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3); + iree_hal_buffer_t* buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r4, &buffer)); + + // TODO(benvanik): HAL APIs for queue-ordered allocations. + // For now we just perform a blocking wait to synchronize with the queue and + // then ignore the buffer for GC to cleanup. + (void)queue_affinity; + IREE_RETURN_IF_ERROR( + iree_hal_fence_wait(wait_fence, iree_infinite_timeout())); + IREE_RETURN_IF_ERROR(iree_hal_fence_signal(signal_fence)); + + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_execute, // + iree_hal_module_state_t, // + rIrrCrD, v) { + iree_hal_device_t* device = NULL; + IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i1; + iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2); + iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3); + iree_host_size_t command_buffer_count = 0; + iree_hal_command_buffer_t** command_buffers = NULL; + IREE_VM_ABI_VLA_STACK_DEREF(args, a4_count, a4, iree_hal_command_buffer, 32, + &command_buffer_count, &command_buffers); + + iree_hal_submission_batch_t batch = { + .wait_semaphores = iree_hal_fence_semaphore_list(wait_fence), + .signal_semaphores = iree_hal_fence_semaphore_list(signal_fence), + .command_buffer_count = command_buffer_count, + .command_buffers = command_buffers, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_submit( + device, IREE_HAL_COMMAND_CATEGORY_ANY, queue_affinity, 1, &batch)); + + return iree_ok_status(); +} + +IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_flush, // + iree_hal_module_state_t, // + rI, v) { + iree_hal_device_t* device = NULL; + IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i1; + + // TODO(benvanik): queue flush API. + // This will be most useful for backends that perform internal batching and + // require the explicit flush. For now we don't have this exposed. 
+ (void)device; + (void)queue_affinity; + + return iree_ok_status(); +} + //===--------------------------------------------------------------------===// // iree_hal_executable_t //===--------------------------------------------------------------------===// @@ -1001,8 +1081,8 @@ IREE_VM_ABI_EXPORT(iree_hal_module_fence_join, // CrD, r) { iree_host_size_t fence_count = 0; iree_hal_fence_t** fences = NULL; - IREE_VM_ABI_VLA_STACK_DEREF(args, a0_count, a0, iree_hal_fence, 32, - &fence_count, &fences); + IREE_VM_ABI_VLA_STACK_DEREF_OR_NULL(args, a0_count, a0, iree_hal_fence, 32, + &fence_count, &fences); iree_hal_fence_t* fence = NULL; IREE_RETURN_IF_ERROR( @@ -1161,8 +1241,8 @@ IREE_VM_ABI_EXPORT(iree_hal_module_fence_await, // uint32_t timeout_millis = (uint32_t)args->i0; iree_host_size_t fence_count = 0; iree_hal_fence_t** fences = NULL; - IREE_VM_ABI_VLA_STACK_DEREF(args, a1_count, a1, iree_hal_fence, 32, - &fence_count, &fences); + IREE_VM_ABI_VLA_STACK_DEREF_OR_NULL(args, a1_count, a1, iree_hal_fence, 32, + &fence_count, &fences); IREE_TRACE_ZONE_BEGIN(z0); zone_id = z0; @@ -1189,11 +1269,11 @@ IREE_VM_ABI_EXPORT(iree_hal_module_fence_await, // if (!iree_status_is_ok(wait_status)) break; } } else { + current_frame->pc = IREE_HAL_MODULE_FENCE_AWAIT_PC_RESUME; IREE_RETURN_AND_END_ZONE_IF_ERROR( zone_id, iree_hal_module_fence_await_begin(stack, fence_count, fences, timeout, zone_id, &wait_status)); - current_frame->pc = IREE_HAL_MODULE_FENCE_AWAIT_PC_RESUME; if (iree_status_is_deferred(wait_status)) { zone_id = 0; // ownership transferred to wait frame } @@ -1455,7 +1535,8 @@ IREE_API_EXPORT iree_status_t iree_hal_module_create( iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module); module->host_allocator = host_allocator; - module->flags = flags; + // TODO(benvanik): fix vm yield with result storage. + module->flags = flags | IREE_HAL_MODULE_FLAG_SYNCHRONOUS; module->shared_device = device; iree_hal_device_retain(module->shared_device); diff --git a/runtime/src/iree/vm/bytecode_dispatch.c b/runtime/src/iree/vm/bytecode_dispatch.c index 12ac078d1608..db38d9e90c48 100644 --- a/runtime/src/iree/vm/bytecode_dispatch.c +++ b/runtime/src/iree/vm/bytecode_dispatch.c @@ -487,6 +487,11 @@ static iree_status_t iree_vm_bytecode_issue_import_call( iree_status_t call_status = call.function.module->begin_call(call.function.module->self, stack, call); if (iree_status_is_deferred(call_status)) { + if (!iree_byte_span_is_empty(call.results)) { + iree_status_ignore(call_status); + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "yield in imports with results not supported"); + } return call_status; // deferred for future resume } else if (IREE_UNLIKELY(!iree_status_is_ok(call_status))) { // TODO(benvanik): set execution result to failure/capture stack. diff --git a/runtime/src/iree/vm/bytecode_module_impl.h b/runtime/src/iree/vm/bytecode_module_impl.h index d51595a897a7..aea5d9a329cf 100644 --- a/runtime/src/iree/vm/bytecode_module_impl.h +++ b/runtime/src/iree/vm/bytecode_module_impl.h @@ -33,7 +33,7 @@ extern "C" { // Major bytecode version; mismatches on this will fail in either direction. // This allows coarse versioning of completely incompatible versions. // Matches BytecodeEncoder::kVersionMajor in the compiler. -#define IREE_VM_BYTECODE_VERSION_MAJOR 5 +#define IREE_VM_BYTECODE_VERSION_MAJOR 6 // Minor bytecode version; lower versions are allowed to enable newer runtimes // to load older serialized files when there are backwards-compatible changes. 
// Higher versions are disallowed as they occur when new ops are added that diff --git a/runtime/src/iree/vm/ref.h b/runtime/src/iree/vm/ref.h index 5ee63436c317..e21addbb50a1 100644 --- a/runtime/src/iree/vm/ref.h +++ b/runtime/src/iree/vm/ref.h @@ -265,6 +265,8 @@ struct ref_type_descriptor { IREE_API_EXPORT T* name##_deref(const iree_vm_ref_t ref); \ IREE_API_EXPORT iree_status_t name##_check_deref(const iree_vm_ref_t ref, \ T** out_ptr); \ + IREE_API_EXPORT iree_status_t name##_check_deref_or_null( \ + const iree_vm_ref_t ref, T** out_ptr); \ IREE_API_EXPORT const iree_vm_ref_type_descriptor_t* \ name##_get_descriptor(); \ static inline bool name##_isa(const iree_vm_ref_t ref) { \ @@ -286,9 +288,8 @@ struct ref_type_descriptor { return ref; \ } \ IREE_API_EXPORT T* name##_deref(const iree_vm_ref_t ref) { \ - iree_status_t status = iree_vm_ref_check(ref, name##_descriptor.type); \ - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { \ - IREE_IGNORE_ERROR(status); \ + if (IREE_UNLIKELY(ref.type != name##_descriptor.type) || \ + IREE_UNLIKELY(ref.type == IREE_VM_REF_TYPE_NULL)) { \ return NULL; \ } \ return (T*)ref.ptr; \ @@ -299,6 +300,16 @@ struct ref_type_descriptor { *out_ptr = (T*)ref.ptr; \ return iree_ok_status(); \ } \ + IREE_API_EXPORT iree_status_t name##_check_deref_or_null( \ + const iree_vm_ref_t ref, T** out_ptr) { \ + if (ref.type != IREE_VM_REF_TYPE_NULL) { \ + IREE_RETURN_IF_ERROR(iree_vm_ref_check(ref, name##_descriptor.type)); \ + *out_ptr = (T*)ref.ptr; \ + } else { \ + *out_ptr = NULL; \ + } \ + return iree_ok_status(); \ + } \ IREE_API_EXPORT const iree_vm_ref_type_descriptor_t* \ name##_get_descriptor() { \ return &name##_descriptor; \ diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c index 41e77947d1f0..9605efa0e13b 100644 --- a/runtime/src/iree/vm/shims.c +++ b/runtime/src/iree/vm/shims.c @@ -56,6 +56,9 @@ IREE_VM_ABI_DEFINE_SHIM(rrirCID, v); IREE_VM_ABI_DEFINE_SHIM(rrirI, v); IREE_VM_ABI_DEFINE_SHIM(rrIrII, v); IREE_VM_ABI_DEFINE_SHIM(rrrIii, v); +IREE_VM_ABI_DEFINE_SHIM(rIrriiiI, r); +IREE_VM_ABI_DEFINE_SHIM(rIrrr, v); +IREE_VM_ABI_DEFINE_SHIM(rIrrCrD, v); IREE_VM_ABI_DEFINE_SHIM(CrID, r); IREE_VM_ABI_DEFINE_SHIM(CrD, r); IREE_VM_ABI_DEFINE_SHIM(iCrD, i); diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h index acd30eccff23..6e27ff96d2f3 100644 --- a/runtime/src/iree/vm/shims.h +++ b/runtime/src/iree/vm/shims.h @@ -138,6 +138,21 @@ typedef iree_status_t(IREE_API_PTR* iree_vm_native_function_target2_t)( ref_type##_check_deref((args)->vla_field[i].r0, &(*(out_ptrs))[i])); \ } +#define IREE_VM_ABI_VLA_STACK_DEREF_OR_NULL( \ + args, vla_count, vla_field, ref_type, max_count, out_count, out_ptrs) \ + *(out_count) = (args)->vla_count; \ + if (IREE_UNLIKELY((args)->vla_count > (max_count))) { \ + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, \ + "count %u of " #ref_type " > %u", \ + (args)->vla_count, (uint32_t)(max_count)); \ + } \ + *(out_ptrs) = \ + (ref_type##_t**)iree_alloca((args)->vla_count * sizeof(ref_type##_t*)); \ + for (iree_host_size_t i = 0; i < (args)->vla_count; ++i) { \ + IREE_RETURN_IF_ERROR(ref_type##_check_deref_or_null( \ + (args)->vla_field[i].r0, &(*(out_ptrs))[i])); \ + } + #define IREE_VM_ABI_VLA_HEAP_DEREF(args, vla_count, vla_field, ref_type, \ host_allocator, out_count, out_ptrs) \ *(out_count) = (args)->vla_count; \ @@ -340,6 +355,34 @@ IREE_VM_ABI_FIXED_STRUCT(rrrIii, { int32_t i5; }); +IREE_VM_ABI_FIXED_STRUCT(rIrriiiI, { + iree_vm_ref_t r0; + int64_t i1; + iree_vm_ref_t r2; + iree_vm_ref_t r3; + int32_t i4; +
int32_t i5; + int32_t i6; + int64_t i7; +}); + +IREE_VM_ABI_FIXED_STRUCT(rIrrr, { + iree_vm_ref_t r0; + int64_t i1; + iree_vm_ref_t r2; + iree_vm_ref_t r3; + iree_vm_ref_t r4; +}); + +IREE_VM_ABI_VLA_STRUCT(rIrrCrD, a4_count, a4, { + iree_vm_ref_t r0; + int64_t i1; + iree_vm_ref_t r2; + iree_vm_ref_t r3; + iree_vm_size_t a4_count; + iree_vm_abi_r_t a4[0]; +}); + IREE_VM_ABI_VLA_STRUCT(rCiD, a1_count, a1, { iree_vm_ref_t r0; iree_vm_size_t a1_count; @@ -521,6 +564,9 @@ IREE_VM_ABI_DECLARE_SHIM(rrirCID, v); IREE_VM_ABI_DECLARE_SHIM(rrirI, v); IREE_VM_ABI_DECLARE_SHIM(rrIrII, v); IREE_VM_ABI_DECLARE_SHIM(rrrIii, v); +IREE_VM_ABI_DECLARE_SHIM(rIrriiiI, r); +IREE_VM_ABI_DECLARE_SHIM(rIrrr, v); +IREE_VM_ABI_DECLARE_SHIM(rIrrCrD, v); IREE_VM_ABI_DECLARE_SHIM(CrID, r); IREE_VM_ABI_DECLARE_SHIM(CrD, r); IREE_VM_ABI_DECLARE_SHIM(iCrD, i);
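One behavioral note on the new _OR_NULL dereference path used by fence.join and fence.await: null refs in the variadic segment now arrive as NULL pointers instead of failing the whole call, so exports can treat them as absent fences. The sketch below is a hypothetical export written only to show the expansion site of IREE_VM_ABI_VLA_STACK_DEREF_OR_NULL; it is not part of this change, and registering it would also require a (CrD, i) shim that is not defined here:

IREE_VM_ABI_EXPORT(iree_hal_module_example_fence_count,  //
                   iree_hal_module_state_t,              //
                   CrD, i) {
  iree_host_size_t fence_count = 0;
  iree_hal_fence_t** fences = NULL;
  // Null refs in the a0 segment become NULL entries instead of failing the
  // call as the plain IREE_VM_ABI_VLA_STACK_DEREF would.
  IREE_VM_ABI_VLA_STACK_DEREF_OR_NULL(args, a0_count, a0, iree_hal_fence, 32,
                                      &fence_count, &fences);
  int32_t live_count = 0;
  for (iree_host_size_t i = 0; i < fence_count; ++i) {
    if (fences[i]) ++live_count;
  }
  rets->i0 = live_count;
  return iree_ok_status();
}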