Skip to content

Commit

Permalink
add fix for operand of AccessOp, general dump of required changes
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Sep 25, 2024
1 parent d401b1e commit b86e1ec
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 51 deletions.
64 changes: 32 additions & 32 deletions build_tools/ci/cpu_comparison/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,16 +605,16 @@ def __init__(self):
def run(self, config):

testSet = [
{
"conv_type": "conv_2d_nhwc_hwcf",
"N": 2,
"IH": 14,
"IC": 32,
"OC": 64,
"KH": 3,
"input_element_type": "i32",
"output_element_type": "i32",
},
# {
# "conv_type": "conv_2d_nhwc_hwcf",
# "N": 2,
# "IH": 14,
# "IC": 32,
# "OC": 64,
# "KH": 3,
# "input_element_type": "i32",
# "output_element_type": "i32",
# },
{
"conv_type": "conv_2d_nhwc_hwcf",
"N": 2,
Expand All @@ -625,39 +625,39 @@ def run(self, config):
"input_element_type": "bf16",
"output_element_type": "f32",
},
{
"conv_type": "conv_2d_nhwc_hwcf",
"N": 2,
"IH": 14,
"IC": 32,
"OC": 64,
"KH": 3,
"input_element_type": "i8",
"output_element_type": "i32",
},
{
"conv_type": "depthwise_conv_2d_nhwc_hwc",
"N": 1,
"IH": 14,
"IC": 64,
"KH": 3,
"input_element_type": "i32",
"output_element_type": "i32",
},
# {
# "conv_type": "conv_2d_nhwc_hwcf",
# "N": 2,
# "IH": 14,
# "IC": 32,
# "OC": 64,
# "KH": 3,
# "input_element_type": "i8",
# "output_element_type": "i32",
# },
# {
# "conv_type": "depthwise_conv_2d_nhwc_hwc",
# "N": 1,
# "IH": 14,
# "IC": 64,
# "KH": 3,
# "input_element_type": "i32",
# "output_element_type": "i32",
# },
]

output_dir = config.output_dir
test_name = output_dir / "test_from_template.mlir"
for testMap in testSet:
convGen = ConvolutionMlirGenerator(**testMap)
convGen.write_to_file(test_name)
n_conv_repeats = 4
n_conv_repeats = 1

aie_vs_llvm_cpu(
config,
test_name,
tile_pipeline="conv-decompose",
lower_to_aie_pipeline="air",
lower_to_aie_pipeline="objectFifo",
n_repeats=n_conv_repeats,
)

Expand All @@ -677,7 +677,7 @@ def run(self, config):
config,
test_files_dir / f"{name}.mlir",
tile_pipeline="conv-decompose",
lower_to_aie_pipeline="air",
lower_to_aie_pipeline="objectFifo",
n_repeats=n_conv_repeats,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-amdaie-convert-to-dma"

Expand Down Expand Up @@ -209,30 +212,30 @@ SmallVector<OpFoldResult> getIndexOpFoldResultSum(OpBuilder &builder,
};

for (uint64_t i = 0; i < lhs.size(); ++i) {
IntegerAttr aAttr;
if (auto aAttr_ = dyn_cast<Attribute>(lhs[i])) {
aAttr = dyn_cast<IntegerAttr>(aAttr_);
assert(aAttr && "Expected an IntegerAttr");
IntegerAttr lhsAttr;
if (auto lhsAttr_ = dyn_cast<Attribute>(lhs[i])) {
lhsAttr = dyn_cast<IntegerAttr>(lhsAttr_);
assert(lhsAttr && "Expected an IntegerAttr");
}

IntegerAttr bAttr;
if (auto bAttr_ = dyn_cast<Attribute>(rhs[i])) {
bAttr = dyn_cast<IntegerAttr>(bAttr_);
assert(bAttr && "Expected an IntegerAttr");
IntegerAttr rhsAttr;
if (auto rhsAttr_ = dyn_cast<Attribute>(rhs[i])) {
rhsAttr = dyn_cast<IntegerAttr>(rhsAttr_);
assert(rhsAttr && "Expected an IntegerAttr");
}

if (aAttr && bAttr) {
sum.push_back(getAsIndexOpFoldResult(builder.getContext(),
aAttr.getInt() + bAttr.getInt()));
} else if (!aAttr && !bAttr) {
if (lhsAttr && rhsAttr) {
sum.push_back(getAsIndexOpFoldResult(
builder.getContext(), lhsAttr.getInt() + rhsAttr.getInt()));
} else if (!lhsAttr && !rhsAttr) {
sum.push_back(builder
.create<arith::AddIOp>(loc, cast<Value>(lhs[i]),
cast<Value>(rhs[i]))
.getResult());
} else if (!aAttr && bAttr) {
sum.push_back(add(cast<Value>(lhs[i]), bAttr));
} else if (aAttr && !bAttr) {
sum.push_back(add(cast<Value>(rhs[i]), aAttr));
} else if (!lhsAttr && rhsAttr) {
sum.push_back(add(cast<Value>(lhs[i]), rhsAttr));
} else if (lhsAttr && !rhsAttr) {
sum.push_back(add(cast<Value>(rhs[i]), lhsAttr));
} else {
assert(false && "unreachable");
}
Expand Down Expand Up @@ -297,7 +300,6 @@ OpFoldResult getLinearCombination(OpBuilder &builder, Location loc,
return combination;
}


/// Update the offsets, sizes, and strides from a collapse shape operation.
LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
SmallVector<OpFoldResult> &offsets,
Expand Down Expand Up @@ -399,7 +401,6 @@ LogicalResult updateFromExpandShape(memref::ExpandShapeOp expandShapeOp,
return success();
}


/// Update the offsets, sizes, and strides from a subview operation.
LogicalResult updateFromSubView(memref::SubViewOp subviewOp,
SmallVector<OpFoldResult> &offsets,
Expand Down Expand Up @@ -671,6 +672,10 @@ void AMDAIEConvertToDmaPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);





// Convert all linalg.copy to iree_linalg_ext.pack/unpack ops. We then
// bootstrap the work done for lowering the pack/unpack op to dmas as the next
// step. This is easy to implement, but not the most direct lowering, so
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,8 @@ LogicalResult insertLogicalObjectFifoAccess(ModuleOp moduleOp) {
opToInsertRewriterPoint = opToInsertRewriterPoint->getParentOp();
}
for (auto &&[idx, operand] : llvm::enumerate(op->getOpOperands())) {
llvm::errs() << "Considering operand #" << idx << " of op: " << *op
<< "\n";
if (memrefToLogicalObjectFifoAccess.contains(operand.get())) {
op->setOperand(idx, memrefToLogicalObjectFifoAccess[operand.get()]);
} else if (memrefToLogicalObjectFifo.contains(operand.get())) {
Expand All @@ -672,12 +674,29 @@ LogicalResult insertLogicalObjectFifoAccess(ModuleOp moduleOp) {
op->setOperand(idx, accessOp);
} else if (auto type =
llvm::dyn_cast<MemRefType>(operand.get().getType())) {
Value memref = operand.get();
Operation *memrefAlloc = operand.get().getDefiningOp();
// Traverse back from memref until you find the AllocOp:
while (memrefAlloc && !isa<memref::AllocOp>(memrefAlloc)) {
memrefAlloc = memrefAlloc->getOperand(0).getDefiningOp();
}

if (!memrefAlloc) {
op->emitOpError("could not find AllocOp root");
return WalkResult::interrupt();
}

Value memref = memrefAlloc->getResult(0);

rewriter.setInsertionPoint(coreOp);

llvm::errs() << "Creating logical ObjectFifo from memref: " << *op
<< "\n";
auto logicalObjectFifo =
rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(type),
memref);

llvm::errs() << "Created " << *logicalObjectFifo << "\n";
rewriter.setInsertionPoint(opToInsertRewriterPoint);

AMDAIE::LogicalObjectFifoAccessOp accessOp;
Expand Down
70 changes: 70 additions & 0 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
Expand Down Expand Up @@ -572,9 +573,78 @@ void buildAMDAIETransformPassPipeline(OpPassManager &variantPassManager,
});
}

// Ad-hoc pass (PassWrapper, no TableGen registration) that canonicalizes
// memref ops into forms that the subsequent DMA-conversion pass can handle.
// TODO: promote this into a properly registered
// 'canonicalize-for-dma-conversion' pass.
namespace {
struct CanonicalizeMemref
    : public PassWrapper<CanonicalizeMemref, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CanonicalizeMemref);

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    IRRewriter rewriter(context);

    // Step 1: rewrite every rank-reducing memref.subview into a full-rank
    // subview (result type inferred from offsets/sizes/strides) followed by a
    // memref.collapse_shape that restores the original rank-reduced shape.
    // Later passes then only ever see full-rank subviews.
    WalkResult subviewWalk = getOperation()->walk(
        [&rewriter](memref::SubViewOp subViewOp) -> WalkResult {
          // Rank-preserving subviews (no dropped dims) need no rewriting.
          llvm::SmallBitVector droppedDims = subViewOp.getDroppedDims();
          if (droppedDims.count() == 0) return WalkResult::advance();

          rewriter.setInsertionPointAfter(subViewOp);
          auto newSubViewOp = rewriter.create<memref::SubViewOp>(
              subViewOp.getLoc(), subViewOp.getSource(),
              subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
              subViewOp.getMixedStrides());

          // Find how to collapse the full-rank result back down to the
          // original (rank-reduced) result shape.
          std::optional<SmallVector<ReassociationIndices>> reassociation =
              mlir::getReassociationIndicesForCollapse(
                  newSubViewOp.getType().getShape(),
                  subViewOp.getType().getShape());
          if (!reassociation.has_value()) {
            subViewOp.emitOpError(
                "could not get reassociation indices for collapse");
            return WalkResult::interrupt();
          }

          rewriter.replaceOpWithNewOp<memref::CollapseShapeOp>(
              subViewOp, newSubViewOp.getResult(), reassociation.value());
          return WalkResult::advance();
        });

    if (subviewWalk.wasInterrupted()) {
      signalPassFailure();
    }

    // Step 2: fold memref.expand_shape(memref.collapse_shape(x)) -> x when
    // the shapes round-trip exactly. This cleans up pairs created (directly
    // or indirectly) by step 1.
    WalkResult expandWalk = getOperation()->walk(
        [&rewriter](memref::ExpandShapeOp expandOp) -> WalkResult {
          Operation *def = expandOp.getSrc().getDefiningOp();
          if (!def) return WalkResult::advance();
          if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(def)) {
            auto collapseSource = collapseOp.getSrc();
            if (collapseSource.getType().getShape() ==
                expandOp.getType().getShape()) {
              rewriter.replaceOp(expandOp, collapseSource);
            }
          }
          return WalkResult::advance();
        });

    if (expandWalk.wasInterrupted()) {
      signalPassFailure();
    }
  }
};
}  // namespace

void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager) {
passManager.addPass(createEraseHALDescriptorTypeFromMemRefPass());
passManager.addPass(memref::createFoldMemRefAliasOpsPass());
passManager.addPass(std::make_unique<CanonicalizeMemref>());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createAMDAIEConvertToDmaPass());

passManager.addPass(createAMDAIENormalizeLoopBoundsPass());
Expand Down

0 comments on commit b86e1ec

Please sign in to comment.