-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ObjectFifo] Create a new pass to split L2 buffers
-- This commit introduces a new pass `--iree-amdaie-split-buffers` to split L2 buffers for dealing with Matmul+Elementwise. -- It addresses sub-action 2 as well from #644 Signed-off-by: Abhishek Varma <abhvarma@amd.com>
- Loading branch information
1 parent
a6bdc0c
commit 19ea811
Showing
8 changed files
with
272 additions
and
1 deletion.
There are no files selected for viewing
156 changes: 156 additions & 0 deletions
156
compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIESplitBuffers.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
#include "aie/Dialect/AIE/IR/AIEDialect.h" | ||
#include "iree-amd-aie/IR/AMDAIEOps.h" | ||
#include "iree-amd-aie/Transforms/AMDAIEDmaUtils.h" | ||
#include "iree-amd-aie/Transforms/Passes.h" | ||
#include "iree-amd-aie/Transforms/Transforms.h" | ||
#include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/SCF/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/IR/Iterators.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" | ||
|
||
#define DEBUG_TYPE "iree-amdaie-split-buffers" | ||
|
||
namespace mlir::iree_compiler::AMDAIE { | ||
|
||
namespace { | ||
|
||
class AMDAIESplitBuffersPass | ||
: public impl::AMDAIESplitBuffersBase<AMDAIESplitBuffersPass> { | ||
public: | ||
using AMDAIESplitBuffersBase::AMDAIESplitBuffersBase; | ||
|
||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<AMDAIEDialect>(); | ||
} | ||
void runOnOperation() override; | ||
}; | ||
|
||
void AMDAIESplitBuffersPass::runOnOperation() { | ||
ModuleOp moduleOp = getOperation(); | ||
IRRewriter rewriter(moduleOp.getContext()); | ||
|
||
DenseMap<memref::AllocOp, memref::AllocOp> memrefToNew; | ||
SmallVector<AMDAIE::LogicalObjectFifoConsume> consumeOps; | ||
moduleOp.walk([&](AMDAIE::CoreOp coreOp) { | ||
AMDAIE::LogicalObjectFifoConsume candidateConsumeOp = nullptr; | ||
unsigned consumeOpCount = 0; | ||
coreOp.walk([&](AMDAIE::LogicalObjectFifoConsume consumeOp) { | ||
++consumeOpCount; | ||
if (consumeOpCount == 3) { | ||
candidateConsumeOp = consumeOp; | ||
return WalkResult::interrupt(); | ||
} | ||
return WalkResult::advance(); | ||
}); | ||
if (!candidateConsumeOp) return WalkResult::skip(); | ||
consumeOps.push_back(candidateConsumeOp); | ||
return WalkResult::advance(); | ||
}); | ||
|
||
DenseSet<Operation *> toBeErased; | ||
for (AMDAIE::LogicalObjectFifoConsume candidateConsumeOp : consumeOps) { | ||
LogicalObjectFifoFromMemrefOp sourceObjectFifo = | ||
candidateConsumeOp.getDmaCpyNdOp().getSourceObjectFifo(); | ||
auto sourceAllocOp = | ||
sourceObjectFifo.getMemref().getDefiningOp<memref::AllocOp>(); | ||
uint64_t sourceMemrefSpace = sourceObjectFifo.getMemorySpaceAsUInt(); | ||
if (!sourceAllocOp || sourceMemrefSpace != 1) continue; | ||
// Should do similar checks for target. Will do. This is WIP. | ||
LogicalObjectFifoFromMemrefOp targetObjectFifo = | ||
candidateConsumeOp.getDmaCpyNdOp().getTargetObjectFifo(); | ||
auto targetAllocOp = | ||
targetObjectFifo.getMemref().getDefiningOp<memref::AllocOp>(); | ||
|
||
// Now we'll create a narrowed buffer. | ||
rewriter.setInsertionPoint(sourceAllocOp); | ||
auto oldSourceMemRefType = cast<MemRefType>(sourceAllocOp.getType()); | ||
auto targetMemRefType = cast<MemRefType>(targetAllocOp.getType()); | ||
MemRefType newAllocType = MemRefType::get( | ||
targetMemRefType.getNumElements(), targetMemRefType.getElementType(), | ||
MemRefLayoutAttrInterface{}, oldSourceMemRefType.getMemorySpace()); | ||
auto newAllocOp = rewriter.create<memref::AllocOp>(rewriter.getUnknownLoc(), | ||
newAllocType); | ||
auto newDeallocOp = rewriter.create<memref::DeallocOp>( | ||
rewriter.getUnknownLoc(), newAllocOp); | ||
newDeallocOp->moveBefore(&newAllocOp->getBlock()->back()); | ||
|
||
// Although we have the DmaCpyNd user above, the | ||
// logicalobjectfifo.from_memref is used in other DmaCpyNds as well for | ||
// other core ops. | ||
AMDAIE::DmaCpyNdOp l3ToL2DmaOp; | ||
for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) { | ||
if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp); | ||
dmaOp.getTargetObjectFifo() == sourceObjectFifo) { | ||
l3ToL2DmaOp = dmaOp; | ||
toBeErased.insert(dmaOp); | ||
break; | ||
} | ||
} | ||
toBeErased.insert(sourceAllocOp); | ||
toBeErased.insert(sourceObjectFifo); | ||
|
||
AMDAIE::DmaCpyNdOp l2ToL1DmaOp = candidateConsumeOp.getDmaCpyNdOp(); | ||
auto type = cast<MemRefType>(newAllocOp.getType()); | ||
SmallVector<Value> empty; | ||
rewriter.setInsertionPoint(l2ToL1DmaOp.getSourceObjectFifo()); | ||
auto source = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>( | ||
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(type), | ||
newAllocOp.getResult(), sourceObjectFifo.getTiles()); | ||
|
||
rewriter.setInsertionPoint(l3ToL2DmaOp); | ||
rewriter.create<AMDAIE::DmaCpyNdOp>( | ||
l3ToL2DmaOp.getLoc(), source, l3ToL2DmaOp.getTargetMixedOffsets(), | ||
l3ToL2DmaOp.getTargetMixedSizes(), l3ToL2DmaOp.getTargetMixedStrides(), | ||
l3ToL2DmaOp.getSource(), l3ToL2DmaOp.getSourceMixedOffsets(), | ||
l3ToL2DmaOp.getSourceMixedSizes(), l3ToL2DmaOp.getSourceMixedStrides()); | ||
|
||
rewriter.setInsertionPoint(l2ToL1DmaOp); | ||
auto newL2ToL1DmaOp = rewriter.create<AMDAIE::DmaCpyNdOp>( | ||
l2ToL1DmaOp.getLoc(), l2ToL1DmaOp.getTarget(), | ||
l2ToL1DmaOp.getTargetMixedOffsets(), l2ToL1DmaOp.getTargetMixedSizes(), | ||
l2ToL1DmaOp.getTargetMixedStrides(), source, | ||
l2ToL1DmaOp.getSourceMixedOffsets(), l2ToL1DmaOp.getSourceMixedSizes(), | ||
l2ToL1DmaOp.getSourceMixedStrides()); | ||
rewriter.replaceOp(l2ToL1DmaOp, newL2ToL1DmaOp); | ||
// We have to discard non-zero offsets as subview has been replaced by a | ||
// dedicated allocated memref. | ||
SmallVector<int64_t> allocShape(type.getShape()); | ||
(void)discardAllNonZeroOffsets<CopyOpOperateOn::Source>( | ||
rewriter, | ||
cast<AMDAIE::DoublyStridedOpInterface>(newL2ToL1DmaOp.getOperation()), | ||
allocShape); | ||
|
||
// Remove old dealloc. | ||
memref::DeallocOp oldDeallocOp; | ||
for (Operation *userOp : sourceAllocOp->getUsers()) { | ||
if (auto deallocUser = dyn_cast<memref::DeallocOp>(userOp)) { | ||
oldDeallocOp = deallocUser; | ||
} | ||
} | ||
if (oldDeallocOp) { | ||
rewriter.eraseOp(oldDeallocOp); | ||
} | ||
} | ||
|
||
for (Operation *op : toBeErased) { | ||
op->dropAllUses(); | ||
rewriter.eraseOp(op); | ||
} | ||
} | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<Pass> createAMDAIESplitBuffersPass() { | ||
return std::make_unique<AMDAIESplitBuffersPass>(); | ||
} | ||
|
||
} // namespace mlir::iree_compiler::AMDAIE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
102 changes: 102 additions & 0 deletions
102
compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/split_buffers.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
// RUN: iree-opt --pass-pipeline="builtin.module(iree-amdaie-split-buffers,cse)" --split-input-file --verify-diagnostics %s | FileCheck %s | ||
|
||
// CHECK-LABEL: @split_l2_buffer | ||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index | ||
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index | ||
// CHECK-DAG: %[[L3_ALLOC:.*]] = memref.alloc() : memref<128x128xi32> | ||
// CHECK-DAG: %[[L2_ALLOC:.*]] = memref.alloc() : memref<1024xi32, 1 : i32> | ||
// CHECK-DAG: %[[L1_ALLOC:.*]] = memref.alloc() : memref<1x1x8x8x4x4xi32, 2 : i32> | ||
// CHECK: %[[TILE:.*]] = amdaie.tile(%[[C1]], %[[C3]]) | ||
// CHECK: %[[L2_OBJECTFIFO:.*]] = amdaie.logicalobjectfifo.from_memref %[[L2_ALLOC]], {%[[TILE]]} : | ||
// CHECK-SAME: memref<1024xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<1024xi32, 1 : i32>> | ||
// CHECK: %[[L3_OBJECTFIFO:.*]] = amdaie.logicalobjectfifo.from_memref %[[L3_ALLOC]], {%[[TILE]]} : | ||
// CHECK-SAME: memref<128x128xi32> -> !amdaie.logicalobjectfifo<memref<128x128xi32>> | ||
// CHECK: scf.forall | ||
// CHECK: %[[DMA_CPY_ND_L3_TO_L2:.*]] = amdaie.dma_cpy_nd(%[[L2_OBJECTFIFO]] | ||
// CHECK-SAME: %[[L3_OBJECTFIFO]] | ||
// CHECK: amdaie.logicalobjectfifo.from_memref | ||
// CHECK: amdaie.logicalobjectfifo.from_memref | ||
// CHECK: amdaie.dma_cpy_nd | ||
// CHECK: amdaie.dma_cpy_nd | ||
// CHECK: %[[L1_OBJECTFIFO:.*]] = amdaie.logicalobjectfifo.from_memref %[[L1_ALLOC]] | ||
// CHECK: %[[DMA_CPY_ND_L2_TO_L1:.*]] = amdaie.dma_cpy_nd(%[[L1_OBJECTFIFO]] | ||
// CHECK-SAME: %[[L2_OBJECTFIFO]] | ||
// CHECK: amdaie.core(%[[TILE]]) { | ||
// CHECK: amdaie.logicalobjectfifo.consume | ||
// CHECK: amdaie.logicalobjectfifo.consume | ||
// CHECK: linalg.generic | ||
// CHECK: amdaie.logicalobjectfifo.consume(%[[DMA_CPY_ND_L2_TO_L1]]) | ||
// CHECK: } | ||
// CHECK: memref.dealloc %[[L2_ALLOC]] : memref<1024xi32, 1 : i32> | ||
#map = affine_map<(d0) -> (d0 * 64)> | ||
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)> | ||
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)> | ||
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)> | ||
#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> | ||
module { | ||
func.func @split_l2_buffer(%arg0: !amdaie.logicalobjectfifo<memref<1x1x4x8x4x8xi32, 2 : i32>>, %arg1: !amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>>, %arg2: !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>>, %arg3: !amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>) { | ||
%c3 = arith.constant 3 : index | ||
%c16 = arith.constant 16 : index | ||
%c8 = arith.constant 8 : index | ||
%c4 = arith.constant 4 : index | ||
%c128 = arith.constant 128 : index | ||
%c2048 = arith.constant 2048 : index | ||
%c256 = arith.constant 256 : index | ||
%c1024 = arith.constant 1024 : index | ||
%c4096 = arith.constant 4096 : index | ||
%c32 = arith.constant 32 : index | ||
%c2 = arith.constant 2 : index | ||
%c1 = arith.constant 1 : index | ||
%c0 = arith.constant 0 : index | ||
%alloc = memref.alloc() : memref<2x1x32x32xi32, 1 : i32> | ||
%alloc_0 = memref.alloc() : memref<1x2x32x32xi32, 1 : i32> | ||
%alloc_1 = memref.alloc() : memref<2x2x32x32xi32, 1 : i32> | ||
%alloc_2 = memref.alloc() : memref<128x128xi32> | ||
%alloc_3 = memref.alloc() : memref<1x1x8x8x4x4xi32, 2 : i32> | ||
%tile = amdaie.tile(%c1, %c3) | ||
%0 = amdaie.logicalobjectfifo.from_memref %alloc_1, {%tile} : memref<2x2x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>> | ||
%1 = amdaie.logicalobjectfifo.from_memref %alloc_2, {%tile} : memref<128x128xi32> -> !amdaie.logicalobjectfifo<memref<128x128xi32>> | ||
scf.forall (%arg4, %arg5) in (2, 2) { | ||
%2 = affine.apply #map(%arg5) | ||
%3 = affine.apply #map(%arg4) | ||
%4 = amdaie.dma_cpy_nd(%0[%c0, %c0, %c0, %c0] [%c2, %c2, %c32, %c32] [%c2048, %c1024, %c32, %c1], %1[%c0, %c0, %3, %2] [%c2, %c2, %c32, %c32] [%c4096, %c32, %c128, %c1]) : (!amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>, !amdaie.logicalobjectfifo<memref<128x128xi32>>) | ||
%tile_4 = amdaie.tile(%c1, %c3) | ||
%5 = amdaie.logicalobjectfifo.from_memref %alloc, {%tile} : memref<2x1x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<2x1x32x32xi32, 1 : i32>> | ||
%6 = amdaie.logicalobjectfifo.from_memref %alloc_0, {%tile} : memref<1x2x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<1x2x32x32xi32, 1 : i32>> | ||
%7 = amdaie.dma_cpy_nd(%arg0[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c1024, %c1024, %c256, %c32, %c8, %c1], %5[%c1, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c1024, %c1024, %c8, %c128, %c32, %c1]) : (!amdaie.logicalobjectfifo<memref<1x1x4x8x4x8xi32, 2 : i32>>, !amdaie.logicalobjectfifo<memref<2x1x32x32xi32, 1 : i32>>) | ||
%8 = amdaie.dma_cpy_nd(%arg1[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c1024, %c1024, %c128, %c32, %c4, %c1], %6[%c0, %c1, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c1024, %c4, %c256, %c32, %c1]) : (!amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>>, !amdaie.logicalobjectfifo<memref<1x2x32x32xi32, 1 : i32>>) | ||
%9 = amdaie.logicalobjectfifo.from_memref %alloc_3, {%tile} : memref<1x1x8x8x4x4xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> | ||
%10 = amdaie.dma_cpy_nd(%9[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c8, %c4, %c4] [%c1024, %c1024, %c128, %c16, %c4, %c1], %0[%c1, %c1, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c8, %c4, %c4] [%c2048, %c1024, %c4, %c128, %c32, %c1]) : (!amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>>, !amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>) | ||
%11 = amdaie.dma_cpy_nd(%arg3[%c1, %c1, %c0, %c0] [%c1, %c1, %c32, %c32] [%c2048, %c1024, %c32, %c1], %arg2[%c0, %c0, %c0, %c0] [%c8, %c4, %c8, %c4] [%c16, %c4, %c128, %c1]) : (!amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>, !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>>) | ||
%12 = amdaie.core(%tile_4) { | ||
amdaie.logicalobjectfifo.consume(%7) | ||
amdaie.logicalobjectfifo.consume(%8) | ||
%13 = amdaie.logicalobjectfifo.access(%arg0, Read) : !amdaie.logicalobjectfifo<memref<1x1x4x8x4x8xi32, 2 : i32>> -> memref<1x1x4x8x4x8xi32, 2 : i32> | ||
%14 = amdaie.logicalobjectfifo.access(%arg1, Read) : !amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>> -> memref<1x1x8x4x8x4xi32, 2 : i32> | ||
%15 = amdaie.logicalobjectfifo.access(%arg2, None) : !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> -> memref<1x1x8x8x4x4xi32, 2 : i32> | ||
linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%13, %14 : memref<1x1x4x8x4x8xi32, 2 : i32>, memref<1x1x8x4x8x4xi32, 2 : i32>) outs(%15 : memref<1x1x8x8x4x4xi32, 2 : i32>) { | ||
^bb0(%in: i32, %in_5: i32, %out: i32): | ||
%18 = arith.muli %in, %in_5 : i32 | ||
%19 = arith.addi %out, %18 : i32 | ||
linalg.yield %19 : i32 | ||
} | ||
amdaie.logicalobjectfifo.consume(%10) | ||
%16 = amdaie.logicalobjectfifo.access(%arg2, Read) : !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> -> memref<1x1x8x8x4x4xi32, 2 : i32> | ||
%17 = amdaie.logicalobjectfifo.access(%arg2, Write) : !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> -> memref<1x1x8x8x4x4xi32, 2 : i32> | ||
linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%15, %16 : memref<1x1x8x8x4x4xi32, 2 : i32>, memref<1x1x8x8x4x4xi32, 2 : i32>) outs(%17 : memref<1x1x8x8x4x4xi32, 2 : i32>) { | ||
^bb0(%in: i32, %in_5: i32, %out: i32): | ||
%18 = arith.addi %in, %in_5 : i32 | ||
linalg.yield %18 : i32 | ||
} | ||
amdaie.logicalobjectfifo.produce(%11) | ||
amdaie.end | ||
} | ||
} {mapping = [#gpu.block<y>, #gpu.block<x>]} | ||
memref.dealloc %alloc : memref<2x1x32x32xi32, 1 : i32> | ||
memref.dealloc %alloc_3 : memref<1x1x8x8x4x4xi32, 2 : i32> | ||
memref.dealloc %alloc_0 : memref<1x2x32x32xi32, 1 : i32> | ||
memref.dealloc %alloc_1 : memref<2x2x32x32xi32, 1 : i32> | ||
memref.dealloc %alloc_2 : memref<128x128xi32> | ||
return | ||
} | ||
} |