Skip to content

Commit

Permalink
[ObjectFifo] Create a new pass to split L2 buffers
Browse files Browse the repository at this point in the history
-- This commit introduces a new pass `--iree-amdaie-split-buffers`
   to split L2 buffers for dealing with Matmul+Elementwise.
-- It addresses sub-action 2 as well from #644

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
  • Loading branch information
Abhishek-Varma committed Aug 12, 2024
1 parent 883ee07 commit b2d08c7
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/AMDAIEDmaUtils.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Transforms.h"
#include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"

#define DEBUG_TYPE "iree-amdaie-split-buffers"

namespace mlir::iree_compiler::AMDAIE {

namespace {

class AMDAIESplitBuffersPass
: public impl::AMDAIESplitBuffersBase<AMDAIESplitBuffersPass> {
public:
using AMDAIESplitBuffersBase::AMDAIESplitBuffersBase;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AMDAIEDialect>();
}
void runOnOperation() override;
};

void AMDAIESplitBuffersPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
IRRewriter rewriter(moduleOp.getContext());

SmallVector<AMDAIE::DmaCpyNdOp> l2ToL1DmaOps;
// We are currently walking through CoreOps gathering 3rd Input DmaOp (if
// applicable) from them.
// TODO: We will generalize this later.
moduleOp.walk([&](AMDAIE::CoreOp coreOp) {
SmallVector<Value> inputDmas = coreOp.getInputDmas();
if (inputDmas.size() < 3) return WalkResult::skip();
l2ToL1DmaOps.push_back(inputDmas[2].getDefiningOp<AMDAIE::DmaCpyNdOp>());
return WalkResult::advance();
});

DenseSet<Operation *> toBeErased;
for (AMDAIE::DmaCpyNdOp l2ToL1DmaOp : l2ToL1DmaOps) {
LogicalObjectFifoFromMemrefOp sourceObjectFifo =
l2ToL1DmaOp.getSourceObjectFifo();
auto sourceAllocOp =
sourceObjectFifo.getMemref().getDefiningOp<memref::AllocOp>();
uint64_t sourceMemrefSpace = sourceObjectFifo.getMemorySpaceAsUInt();
if (!sourceAllocOp || sourceMemrefSpace != 1) continue;
LogicalObjectFifoFromMemrefOp targetObjectFifo =
l2ToL1DmaOp.getTargetObjectFifo();
Value targetAllocOp = targetObjectFifo.getMemref();

// Now we'll create a narrowed L2 buffer.
rewriter.setInsertionPoint(sourceAllocOp);
auto oldSourceMemRefType = cast<MemRefType>(sourceAllocOp.getType());
auto targetMemRefType = cast<MemRefType>(targetAllocOp.getType());
MemRefType newAllocType = MemRefType::get(
targetMemRefType.getNumElements(), targetMemRefType.getElementType(),
MemRefLayoutAttrInterface{}, oldSourceMemRefType.getMemorySpace());
auto newAllocOp = rewriter.create<memref::AllocOp>(rewriter.getUnknownLoc(),
newAllocType);
auto newDeallocOp = rewriter.create<memref::DeallocOp>(
rewriter.getUnknownLoc(), newAllocOp);
newDeallocOp->moveBefore(&newAllocOp->getBlock()->back());

// Fetch the L3 -> L2 Dma Op corresponding to the L2 buffer as target.
AMDAIE::DmaCpyNdOp l3ToL2DmaOp;
for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) {
if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp);
dmaOp.getTargetObjectFifo() == sourceObjectFifo) {
l3ToL2DmaOp = dmaOp;
toBeErased.insert(dmaOp);
break;
}
}
toBeErased.insert(sourceAllocOp);
toBeErased.insert(sourceObjectFifo);

auto type = cast<MemRefType>(newAllocOp.getType());
// Create new logicalobjectfifo.from_memref for the newly created L2 buffer.
rewriter.setInsertionPoint(l2ToL1DmaOp.getSourceObjectFifo());
auto source = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(type),
newAllocOp.getResult(), sourceObjectFifo.getTiles());

// Create new L3 -> L2 Dma Op.
rewriter.setInsertionPoint(l3ToL2DmaOp);
rewriter.create<AMDAIE::DmaCpyNdOp>(
l3ToL2DmaOp.getLoc(), source, l3ToL2DmaOp.getTargetMixedOffsets(),
l3ToL2DmaOp.getTargetMixedSizes(), l3ToL2DmaOp.getTargetMixedStrides(),
l3ToL2DmaOp.getSource(), l3ToL2DmaOp.getSourceMixedOffsets(),
l3ToL2DmaOp.getSourceMixedSizes(), l3ToL2DmaOp.getSourceMixedStrides());

// Create new L2 -> L1 Input DmaOp.
rewriter.setInsertionPoint(l2ToL1DmaOp);
auto newL2ToL1DmaOp = rewriter.create<AMDAIE::DmaCpyNdOp>(
l2ToL1DmaOp.getLoc(), l2ToL1DmaOp.getTarget(),
l2ToL1DmaOp.getTargetMixedOffsets(), l2ToL1DmaOp.getTargetMixedSizes(),
l2ToL1DmaOp.getTargetMixedStrides(), source,
l2ToL1DmaOp.getSourceMixedOffsets(), l2ToL1DmaOp.getSourceMixedSizes(),
l2ToL1DmaOp.getSourceMixedStrides());
rewriter.replaceOp(l2ToL1DmaOp, newL2ToL1DmaOp);
// We have to discard non-zero offsets as subview has been replaced by a
// dedicated allocated memref.
SmallVector<int64_t> allocShape(type.getShape());
(void)discardAllNonZeroOffsets<CopyOpOperateOn::Source>(
rewriter,
cast<AMDAIE::DoublyStridedOpInterface>(newL2ToL1DmaOp.getOperation()),
allocShape);

// Remove old dealloc.
memref::DeallocOp oldDeallocOp;
for (Operation *userOp : sourceAllocOp->getUsers()) {
if (auto deallocUser = dyn_cast<memref::DeallocOp>(userOp)) {
oldDeallocOp = deallocUser;
}
}
if (oldDeallocOp) {
rewriter.eraseOp(oldDeallocOp);
}
}

for (Operation *op : toBeErased) {
op->dropAllUses();
rewriter.eraseOp(op);
}
}

} // namespace

std::unique_ptr<Pass> createAMDAIESplitBuffersPass() {
return std::make_unique<AMDAIESplitBuffersPass>();
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ iree_cc_library(
"AMDAIEPad.cpp"
"AMDAIEPeelForLoop.cpp"
"AMDAIEPropagateDataLayout.cpp"
"AMDAIESplitBuffers.cpp"
"AMDAIETile.cpp"
"AMDAIETileAndFuse.cpp"
"AMDAIEUtils.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEPACKANDTRANSPOSE
#define GEN_PASS_DEF_AMDAIEPACKTODMA
#define GEN_PASS_DEF_AMDAIEPAD
#define GEN_PASS_DEF_AMDAIEVECTORIZATION
#define GEN_PASS_DEF_AMDAIEPEELFORLOOP
#define GEN_PASS_DEF_AMDAIEPROPAGATEDATALAYOUT
#define GEN_PASS_DEF_AMDAIESPLITBUFFERS
#define GEN_PASS_DEF_AMDAIETILE
#define GEN_PASS_DEF_AMDAIETILEANDFUSE
#define GEN_PASS_DEF_AMDAIEVECTORIZATION
#include "iree-amd-aie/Transforms/Passes.h.inc"

} // namespace mlir::iree_compiler::AMDAIE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ std::unique_ptr<Pass> createAMDAIEPadPass(AMDAIEPadOptions options = {});
std::unique_ptr<Pass> createAMDAIEPeelForLoopPass(
AMDAIEPeelForLoopOptions options = {});

/// Create a pass to split buffers.
std::unique_ptr<Pass> createAMDAIESplitBuffersPass();

/// Create pass to tile TilingInterface operations.
std::unique_ptr<Pass> createAMDAIETilePass(AMDAIETileOptions options = {});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,12 @@ def AMDAIEPropagateDataLayout :
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEPropagateDataLayoutPass()";
}

def AMDAIESplitBuffers :
Pass<"iree-amdaie-split-buffers", "ModuleOp"> {
let summary = "Split buffers.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIESplitBuffersPass()";
}

def AMDAIETile :
InterfacePass<"iree-amdaie-tile", "mlir::FunctionOpInterface"> {
let summary = "Pass to tile TilingInterface operations.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ iree_lit_test_suite(
"pad.mlir"
"peel_for_loop.mlir"
"propagate_data_layout.mlir"
"split_buffers.mlir"
"tile_and_fuse_using_scf_for.mlir"
"tile_and_fuse_using_scf_forall.mlir"
"tile_copy_using_scf_for.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// RUN: iree-opt --pass-pipeline="builtin.module(iree-amdaie-split-buffers,cse)" --split-input-file --verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: @split_l2_buffer
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[L3_ALLOC:.*]] = memref.alloc() : memref<128x128xi32>
// CHECK-DAG: %[[L2_ALLOC:.*]] = memref.alloc() : memref<1024xi32, 1 : i32>
// CHECK-DAG: %[[L1_ALLOC:.*]] = memref.alloc() : memref<1x1x8x8x4x4xi32, 2 : i32>
// CHECK: %[[TILE:.*]] = amdaie.tile(%[[C1]], %[[C3]])
// CHECK: %[[L2_OBJECTFIFO:.*]] = amdaie.logicalobjectfifo.from_memref %[[L2_ALLOC]], {%[[TILE]]} :
// CHECK-SAME: memref<1024xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<1024xi32, 1 : i32>>
// CHECK: %[[L3_OBJECTFIFO:.*]] = amdaie.logicalobjectfifo.from_memref %[[L3_ALLOC]], {%[[TILE]]} :
// CHECK-SAME: memref<128x128xi32> -> !amdaie.logicalobjectfifo<memref<128x128xi32>>
// CHECK: scf.forall
// CHECK: %[[DMA_CPY_ND_L3_TO_L2:.*]] = amdaie.dma_cpy_nd(%[[L2_OBJECTFIFO]]
// CHECK-SAME: %[[L3_OBJECTFIFO]]
// CHECK: amdaie.logicalobjectfifo.from_memref
// CHECK: amdaie.logicalobjectfifo.from_memref
// CHECK: amdaie.dma_cpy_nd
// CHECK: amdaie.dma_cpy_nd
// CHECK: %[[L1_OBJECTFIFO:.*]] = amdaie.logicalobjectfifo.from_memref %[[L1_ALLOC]]
// CHECK: %[[DMA_CPY_ND_L2_TO_L1:.*]] = amdaie.dma_cpy_nd(%[[L1_OBJECTFIFO]]
// CHECK-SAME: %[[L2_OBJECTFIFO]]
// CHECK: amdaie.core(%[[TILE]], in : [%{{.*}}, %{{.*}}, %[[DMA_CPY_ND_L2_TO_L1]]], out :
// CHECK: linalg.generic
// CHECK: }
// CHECK: memref.dealloc %[[L2_ALLOC]] : memref<1024xi32, 1 : i32>
#map = affine_map<(d0) -> (d0 * 64)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
module {
func.func @split_l2_buffer(%arg0: !amdaie.logicalobjectfifo<memref<1x1x4x8x4x8xi32, 2 : i32>>, %arg1: !amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>>, %arg2: !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>>, %arg3: !amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>) {
%c3 = arith.constant 3 : index
%c16 = arith.constant 16 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c2048 = arith.constant 2048 : index
%c256 = arith.constant 256 : index
%c1024 = arith.constant 1024 : index
%c4096 = arith.constant 4096 : index
%c32 = arith.constant 32 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%alloc = memref.alloc() : memref<2x1x32x32xi32, 1 : i32>
%alloc_0 = memref.alloc() : memref<1x2x32x32xi32, 1 : i32>
%alloc_1 = memref.alloc() : memref<2x2x32x32xi32, 1 : i32>
%alloc_2 = memref.alloc() : memref<128x128xi32>
%alloc_3 = memref.alloc() : memref<1x1x8x8x4x4xi32, 2 : i32>
%tile = amdaie.tile(%c1, %c3)
%0 = amdaie.logicalobjectfifo.from_memref %alloc_1, {%tile} : memref<2x2x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>
%1 = amdaie.logicalobjectfifo.from_memref %alloc_2, {%tile} : memref<128x128xi32> -> !amdaie.logicalobjectfifo<memref<128x128xi32>>
scf.forall (%arg4, %arg5) in (2, 2) {
%2 = affine.apply #map(%arg5)
%3 = affine.apply #map(%arg4)
%4 = amdaie.dma_cpy_nd(%0[%c0, %c0, %c0, %c0] [%c2, %c2, %c32, %c32] [%c2048, %c1024, %c32, %c1], %1[%c0, %c0, %3, %2] [%c2, %c2, %c32, %c32] [%c4096, %c32, %c128, %c1]) : (!amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>, !amdaie.logicalobjectfifo<memref<128x128xi32>>)
%tile_4 = amdaie.tile(%c1, %c3)
%5 = amdaie.logicalobjectfifo.from_memref %alloc, {%tile} : memref<2x1x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<2x1x32x32xi32, 1 : i32>>
%6 = amdaie.logicalobjectfifo.from_memref %alloc_0, {%tile} : memref<1x2x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<1x2x32x32xi32, 1 : i32>>
%7 = amdaie.dma_cpy_nd(%arg0[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c1024, %c1024, %c256, %c32, %c8, %c1], %5[%c1, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c1024, %c1024, %c8, %c128, %c32, %c1]) : (!amdaie.logicalobjectfifo<memref<1x1x4x8x4x8xi32, 2 : i32>>, !amdaie.logicalobjectfifo<memref<2x1x32x32xi32, 1 : i32>>)
%8 = amdaie.dma_cpy_nd(%arg1[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c1024, %c1024, %c128, %c32, %c4, %c1], %6[%c0, %c1, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c1024, %c4, %c256, %c32, %c1]) : (!amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>>, !amdaie.logicalobjectfifo<memref<1x2x32x32xi32, 1 : i32>>)
%9 = amdaie.logicalobjectfifo.from_memref %alloc_3, {%tile} : memref<1x1x8x8x4x4xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>>
%10 = amdaie.dma_cpy_nd(%9[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c8, %c4, %c4] [%c1024, %c1024, %c128, %c16, %c4, %c1], %0[%c1, %c1, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c8, %c4, %c4] [%c2048, %c1024, %c4, %c128, %c32, %c1]) : (!amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>>, !amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>)
%11 = amdaie.dma_cpy_nd(%arg3[%c1, %c1, %c0, %c0] [%c1, %c1, %c32, %c32] [%c2048, %c1024, %c32, %c1], %arg2[%c0, %c0, %c0, %c0] [%c8, %c4, %c8, %c4] [%c16, %c4, %c128, %c1]) : (!amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>, !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>>)
%12 = amdaie.core(%tile_4, in : [%7, %8, %10], out : [%11]) {
%13 = amdaie.logicalobjectfifo.access(%arg0, Read) : !amdaie.logicalobjectfifo<memref<1x1x4x8x4x8xi32, 2 : i32>> -> memref<1x1x4x8x4x8xi32, 2 : i32>
%14 = amdaie.logicalobjectfifo.access(%arg1, Read) : !amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>> -> memref<1x1x8x4x8x4xi32, 2 : i32>
%15 = amdaie.logicalobjectfifo.access(%arg2, None) : !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> -> memref<1x1x8x8x4x4xi32, 2 : i32>
linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%13, %14 : memref<1x1x4x8x4x8xi32, 2 : i32>, memref<1x1x8x4x8x4xi32, 2 : i32>) outs(%15 : memref<1x1x8x8x4x4xi32, 2 : i32>) {
^bb0(%in: i32, %in_5: i32, %out: i32):
%18 = arith.muli %in, %in_5 : i32
%19 = arith.addi %out, %18 : i32
linalg.yield %19 : i32
}
%16 = amdaie.logicalobjectfifo.access(%arg2, Read) : !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> -> memref<1x1x8x8x4x4xi32, 2 : i32>
%17 = amdaie.logicalobjectfifo.access(%arg2, Write) : !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> -> memref<1x1x8x8x4x4xi32, 2 : i32>
linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%15, %16 : memref<1x1x8x8x4x4xi32, 2 : i32>, memref<1x1x8x8x4x4xi32, 2 : i32>) outs(%17 : memref<1x1x8x8x4x4xi32, 2 : i32>) {
^bb0(%in: i32, %in_5: i32, %out: i32):
%18 = arith.addi %in, %in_5 : i32
linalg.yield %18 : i32
}
amdaie.end
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
memref.dealloc %alloc : memref<2x1x32x32xi32, 1 : i32>
memref.dealloc %alloc_3 : memref<1x1x8x8x4x4xi32, 2 : i32>
memref.dealloc %alloc_0 : memref<1x2x32x32xi32, 1 : i32>
memref.dealloc %alloc_1 : memref<2x2x32x32xi32, 1 : i32>
memref.dealloc %alloc_2 : memref<128x128xi32>
return
}
}

0 comments on commit b2d08c7

Please sign in to comment.