From 5ddbfd6025ac81dd78d0ac6fa5c8bf5941a2da1b Mon Sep 17 00:00:00 2001
From: James Newling
Date: Mon, 7 Oct 2024 10:55:14 -0700
Subject: [PATCH 1/4] resquash

---
 build_tools/ci/cpu_comparison/run.py          |   4 +-
 .../iree-amd-aie/Test/samples/CMakeLists.txt  |   2 +-
 ...e.mlir => conv2d_nhwc_objectfifo_e2e.mlir} |   2 +-
 .../Transforms/AMDAIEPackAndTranspose.cpp     |  20 ++-
 .../Transforms/KernelDispatch.cpp             | 131 +++++++++++++-----
 .../iree-amd-aie/Transforms/Passes.cpp        |  65 ++++-----
 .../test/lowering_strategy_conv.mlir          |  47 +++----
 7 files changed, 164 insertions(+), 107 deletions(-)
 rename compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/{conv2d_nhwc_air_e2e.mlir => conv2d_nhwc_objectfifo_e2e.mlir} (94%)

diff --git a/build_tools/ci/cpu_comparison/run.py b/build_tools/ci/cpu_comparison/run.py
index ffa148c38..faf52b6eb 100755
--- a/build_tools/ci/cpu_comparison/run.py
+++ b/build_tools/ci/cpu_comparison/run.py
@@ -680,7 +680,7 @@ def run(self, config):
             config,
             test_name,
             tile_pipeline="conv-decompose",
-            lower_to_aie_pipeline="air",
+            lower_to_aie_pipeline="objectFifo",
             n_repeats=n_conv_repeats,
         )
 
@@ -700,7 +700,7 @@ def run(self, config):
                 config,
                 test_files_dir / f"{name}.mlir",
                 tile_pipeline="conv-decompose",
-                lower_to_aie_pipeline="air",
+                lower_to_aie_pipeline="objectFifo",
                 n_repeats=n_conv_repeats,
             )
 
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/CMakeLists.txt b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/CMakeLists.txt
index 618409664..0cf6bc133 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/CMakeLists.txt
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/CMakeLists.txt
@@ -8,7 +8,7 @@ iree_lit_test_suite(
   NAME
     lit
   SRCS
-    "conv2d_nhwc_air_e2e.mlir"
+    "conv2d_nhwc_objectfifo_e2e.mlir"
    "matmul_elementwise_pack_peel_air_e2e.mlir"
    "matmul_pack_peel_air_e2e.mlir"
    "matmul_pack_peel_objectfifo.mlir"
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_air_e2e.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_objectfifo_e2e.mlir
similarity index 94%
rename from compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_air_e2e.mlir
rename to compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_objectfifo_e2e.mlir
index 2b005150a..171667038 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_air_e2e.mlir
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Test/samples/conv2d_nhwc_objectfifo_e2e.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-compile --iree-hal-target-backends=amd-aie --compile-to=executable-sources %s | iree-opt --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-translate-target-executable-variants{target=amd-aie})))" --iree-amdaie-tile-pipeline=conv-decompose --iree-amdaie-lower-to-aie-pipeline=air --split-input-file | FileCheck %s
+// RUN: iree-compile --iree-hal-target-backends=amd-aie --compile-to=executable-sources %s | iree-opt --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-translate-target-executable-variants{target=amd-aie})))" --iree-amdaie-tile-pipeline=conv-decompose --iree-amdaie-lower-to-aie-pipeline=objectFifo --split-input-file | FileCheck %s
 
 func.func @conv_2d_nhwc_hwcf(%arg0: tensor<2x14x14x32xi32>, %arg1: tensor<3x3x32x64xi32>) -> tensor<2x12x12x64xi32> {
   %cst = arith.constant 0 : i32
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp
b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp
index 8f846110d..0654b9dd1 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp
@@ -7,8 +7,11 @@
 #include "iree-amd-aie/IR/AMDAIEAttrs.h"
 #include "iree-amd-aie/Transforms/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/Pass.h"
 
 #define DEBUG_TYPE "iree-amdaie-pack-and-transpose"
@@ -20,7 +23,6 @@ namespace {
 static FailureOr<linalg::PackResult> applyPackOnLinalgOp(
     RewriterBase &rewriter, linalg::LinalgOp op,
     SmallVector<OpFoldResult> packedSizes) {
-  // Fail on mismatched number of pack sizes.
   if (packedSizes.size() != op.getNumLoops()) {
     op->emitOpError(
         "requires number of packed sizes match the number of loops (")
@@ -29,12 +31,14 @@ static FailureOr<linalg::PackResult> applyPackOnLinalgOp(
   }
 
   rewriter.setInsertionPoint(op);
-  FailureOr<linalg::PackResult> packResult =
+  FailureOr<linalg::PackResult> maybePackResult =
       linalg::pack(rewriter, op, packedSizes);
-  if (failed(packResult)) {
+  if (failed(maybePackResult)) {
     op->emitOpError("failed to pack the operation");
     return failure();
   }
+
+  linalg::PackResult packResult = maybePackResult.value();
   return packResult;
 }
 
@@ -60,7 +64,8 @@ void AMDAIEPackAndTransposePass::runOnOperation() {
   // Find the linalg op for packing, currently only consider contraction ops
   linalg::LinalgOp linalgOp;
   funcOp->walk([&](linalg::LinalgOp op) {
-    if (linalg::isaContractionOpInterface(op)) {
+    if (linalg::isaContractionOpInterface(op) ||
+        isa<linalg::ConvolutionOpInterface>(op.getOperation())) {
       linalgOp = op;
       return WalkResult::interrupt();
     }
@@ -75,6 +80,7 @@ void AMDAIEPackAndTransposePass::runOnOperation() {
   // Step 1. Before packing the operation, we will prefetch the lowering and
   // packing config.
   auto config = getLoweringConfig(linalgOp);
+
   auto packingConfig = getPackingConfig(linalgOp);
 
   if (!config || !packingConfig) {
@@ -87,6 +93,12 @@ void AMDAIEPackAndTransposePass::runOnOperation() {
 
   // Extract packing config from the `linalgOp`.
   PackingConfigPackingLevelAttr packCfg =
       packingConfig.getPackingConfigVals(packLevel);
+
+  if (!packCfg) {
+    funcOp->emitOpError("failed to get pack config for pack level ")
+        << packLevel;
+    return signalPassFailure();
+  }
   SmallVector<OpFoldResult> packedSizes =
       getAsIndexOpFoldResult(context, packCfg.getPackedSizes());
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp
index ca657b796..34805dabd 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp
@@ -472,54 +472,96 @@ static LogicalResult setRootConfigForPadPackPipeline(
 
 static LogicalResult setRootConfigForConvDecomposePipeline(
     mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp linalgOp) {
+  MLIRContext *context = entryPointFn.getContext();
+
   FailureOr<std::array<uint32_t, 3>> maybeInstructionSize =
       getMatmulInstructionSize(linalgOp);
   int64_t OW = 4;
   int64_t OC = 4;
   int64_t IC = 8;
   if (succeeded(maybeInstructionSize)) {
-    auto instructionSize = maybeInstructionSize.value();
-    OW = instructionSize[0];
-    OC = instructionSize[1];
-    IC = instructionSize[2];
+    auto [m, n, k] = maybeInstructionSize.value();
+    OW = m;
+    OC = n;
+    IC = k;
   }
 
+  SmallVector<int64_t> transposePackIndices{0, 1, 2};
+  SmallVector<bool> unpackEmpty{false, false, true};
+
+  // Convolution type specific vectors:
+  SmallVector<SmallVector<int64_t>> innerPerm;
+  SmallVector<SmallVector<int64_t>> outerPerm;
   SmallVector<int64_t> tileSizeLevel0;
   SmallVector<int64_t> tileSizeLevel1;
   SmallVector<int64_t> tileSizeLevel2;
-  // Note: some of the tiling dimensions are hardcoded for now.
-  if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp) ||
-      isa<linalg::Conv2DNhwcHwcfQOp>(linalgOp)) {
-    // conv_2d_nhwc_hwcf tiling dims: [N, OH, OW, OC, KH, KW, IC].
-    tileSizeLevel0 = {0, 4, OW, OC, 0, 0, 0};
+  SmallVector<int64_t> packingSizes;
+
+  // [N, OH, OW, OC, KH, KW, IC].
+  if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp) ||
+      isa<linalg::Conv2DNhwcHwcfQOp>(linalgOp)) {
+    // The goal is to pack the input image and kernel as follows, when moving
+    // from L2 to L11:
+    //
+    // Example where input has 32 channels:
+    //
+    // %alloc_8 = memref.alloc() : memref<1x3x4x6x8xbf16, 2 : i32>
+    // iree_linalg_ext.pack %subview_5 outer_dims_perm = [0, 1, 3, 2]
+    //                                 inner_dims_pos = [3]
+    //                                 inner_tiles = [8] into %alloc_8 :
+    // (memref<1x3x6x32xbf16, strided<[576, 192, 32, 1], offset: ?>, 1 : i32>
+    //  memref<1x3x4x6x8xbf16, 2 : i32>)
+    //
+    // %alloc_9 = memref.alloc() : memref<3x3x4x1x8x4xbf16, 2 : i32>
+    // iree_linalg_ext.pack %subview_6 outer_dims_perm = [0, 1, 2, 3]
+    //                                 inner_dims_pos = [2, 3]
+    //                                 inner_tiles = [8, 4] into %alloc_9 :
+    // (memref<3x3x32x4xbf16, strided<[384, 128, 4, 1], offset: ?>, 1 : i32>
+    //  memref<3x3x4x1x8x4xbf16, 2 : i32>)
+    innerPerm = {{}, {{1, 0}}, {}};
+    outerPerm = {{0, 1, 3, 2}, {}, {0, 1, 2, 3}};
+    packingSizes = {0, 0, 0, OC, 0, 0, IC};
     tileSizeLevel1 = {1, 1, OW, OC, 0, 0, 0};
-    tileSizeLevel2 = {0, 0, 0, 0, 1, 1, IC};
-  } else if (isa<linalg::Conv2DNchwFchwOp>(linalgOp)) {
-    // conv_2d_nchw_fchw tiling dims: [N, OC, OH, OW, IC, KH, KW].
-    tileSizeLevel0 = {0, OC, 4, OW, 0, 0, 0};
-    tileSizeLevel1 = {1, OC, 1, OW, 0, 0, 0};
-    tileSizeLevel2 = {0, 0, 0, 0, IC, 1, 1};
-  } else if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(linalgOp)) {
+    // convert the kernel height, kernel width, and outer IC reduction into
+    // scf.for loops, leaving just a matmul of the instruction size inside
+    // the loops.
+    tileSizeLevel2 = {0, 0, 0, 0, 1, 1, 1, 0, 0};
+  }
+
+  // [N, OC, OH, OW, IC, KH, KW]
+  else if (isa<linalg::Conv2DNchwFchwOp>(linalgOp)) {
+    // The matmul reduction dimension is the input channel (IC) dimension.
+    // For Conv2DNhwcHwcfOp, this dimension is already the inner-most dimension
+    // of the input image, and the penultimate dimension of the kernel --
+    // exactly what we want. For Conv2DNchwFchwOp, can the tensor dimensions be
+    // permuted in DMA to get them in the correct positions? For the image
+    // tensor, only if H*W is a nice power of 2 (DMA constraint). For kernels,
+    // it requires h*w is a nice power of 2 -- unlikely, we typically have
+    // h=w=3. The dimension permutations will therefore often need to
+    // be done on the core. We leave this for future work; the expectation for
+    // now is that models have been transformed at a high level to avoid
+    // channel-first convolutions.
+    return linalgOp.emitError(
+        "Only channel-last convolution supported currently.");
+  }
+
+  // [N, OH, OW, C, KH, KW]
+  else if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(linalgOp)) {
     // Notes:
     // =====
     //
-    // An inherent property of depthwise convolutions is that they cannot be
-    // expressed in terms of matmuls, unlike the above (dense) conv-2ds. The
-    // tile sizes we choose below are therefore not constrained by the AIE
-    // matmul instructions.
+    // A property of depthwise convolution is that it can't be expressed in
+    // terms of matmul, unlike the above (dense) conv-2ds. The tile sizes we
+    // choose below are therefore not constrained by AIE matmul instructions.
     //
     // The logic is currently fragile, and there are no guardrails: there are
     // no checks that the data tiles are not too large, or that the input
     // dimensions are perfectly tiled by the hard-coded tile dimensions below.
     // These will be done as a follow-up task.
-    //
-    //
-    // Below we target a 4x4 array of AIE cores.
 
     auto getElementType = [](Value v) {
       return cast<ShapedType>(v.getType()).getElementType();
     };
 
     const uint16_t OW_0 = 4;
-    const uint16_t OH_0 = 4;
     const uint16_t OH_1 = 1;
 
     auto operandType = getElementType(linalgOp->getOperand(0));
@@ -530,8 +572,8 @@ static LogicalResult setRootConfigForConvDecomposePipeline(
       OC_0 = maybeMacNumElements.value();
     }
     // If the operand type has fewer than 32-bits, we really should be able to
-    // get a mac-width for it Bail because we didn't, and there's probably just
-    // something missing in the table.
+    // get a mac-width for it. Bail because we didn't, there's probably just
+    // something missing in a table.
     else if (operandType.getIntOrFloatBitWidth() < 32) {
       return linalgOp.emitError(
           "has an operand type with fewer than 32-bits, but no mac-width "
     }
 
     const uint16_t OC_1 = OC_0 / 4;
-
-    // depthwise_conv2d_nhwc_hwc tiling dims:
-    // [N, OH, OW, OC, KH,KW]
-    tileSizeLevel0 = {1, OH_0, OW_0, OC_0, 0, 0};
+    packingSizes = {0, 0, 0, OC_1, 0, 0};
+    innerPerm = {{}, {}, {}};
+    outerPerm = {{0, 1, 2, 3}, {0, 1, 2}, {0, 1, 2, 3}};
     tileSizeLevel1 = {1, OH_1, OW_0, OC_1, 0, 0};
-    tileSizeLevel2 = {0, 0, 0, 0, 1, 1};
-  } else {
-    assert(false && "Support must be added for this convolution op");
+    tileSizeLevel2 = {0, 0, 0, 0, 1, 1, 0};
   }
+
+  else {
+    return linalgOp.emitError(
+        "unrecognised convolution op, cannot set packing config. ");
+  }
+
+  // For the objectFifo backend we currently target a single core.
+  // FIXME(newling) https://github.com/nod-ai/iree-amd-aie/issues/821
+  tileSizeLevel0 = tileSizeLevel1;
+
+  assert(!innerPerm.empty() && !outerPerm.empty() && !packingSizes.empty() &&
+         !tileSizeLevel0.empty() && !tileSizeLevel1.empty() &&
+         "not all vectors for initializing config are non-empty");
+
+  auto packingConfigLevel1Attr = getPackingConfigPackingLevelAttr(
+      context, packingSizes, transposePackIndices, unpackEmpty, innerPerm,
+      outerPerm);
+  SmallVector<PackingConfigPackingLevelAttr> packingConfigLevelsVal{
+      packingConfigLevel1Attr};
+
+  auto packingConfigLevels =
+      PackingConfigPackingLevelsAttr::get(context, packingConfigLevelsVal);
+  auto config = PackingConfigAttr::get(context, packingConfigLevels);
+  setPackingConfig(linalgOp, config);
+
   TileSizesListType tileSizes = {tileSizeLevel0, tileSizeLevel1,
                                  tileSizeLevel2};
+
   return setOpConfigAndEntryPointFnTranslation(
       entryPointFn, linalgOp, tileSizes,
       IREE::Codegen::DispatchLoweringPassPipeline::Custom);
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp
index 75a40545c..785a26796 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp
@@ -6,6 +6,8 @@
 
 #include "iree-amd-aie/Transforms/Passes.h"
 
+#include <memory>
+
 #include "aie/Passes.h"
 #include "aievec/Passes.h"
 #include "air/Conversion/AIRLoweringPass.h"
@@ -34,6 +36,7 @@
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
@@ -408,16 +411,20 @@ void addConvDecomposePassPipeline(OpPassManager &funcPassManager,
                                   TilingConfig &tilingConfig,
                                   bool enableVectorizationPasses,
                                   TilePassPipeline useTilePipeline) {
+  auto addCleanups = [&]() {
+    funcPassManager.addPass(createAMDAIECleanupPass());
+    funcPassManager.addPass(createCanonicalizerPass());
+    funcPassManager.addPass(createCSEPass());
+  };
+
   // First level tiling using scf.forall
   {
     AMDAIETileAndFuseOptions tileFuseOptions;
     tileFuseOptions.tilingLevel = 0;
     tileFuseOptions.useSCFFor = false;
     funcPassManager.addPass(createAMDAIETileAndFusePass(tileFuseOptions));
+    addCleanups();
   }
-  funcPassManager.addPass(createAMDAIECleanupPass());
-  funcPassManager.addPass(createCanonicalizerPass());
-  funcPassManager.addPass(createCSEPass());
 
   // Pad the linalg operation
   {
@@ -441,67 +448,50 @@ void addConvDecomposePassPipeline(OpPassManager &funcPassManager,
     tileFuseOptions.tilingLevel = 1;
     tileFuseOptions.useSCFFor = false;
     funcPassManager.addPass(createAMDAIETileAndFusePass(tileFuseOptions));
+    addCleanups();
   }
-  funcPassManager.addPass(createAMDAIECleanupPass());
-  funcPassManager.addPass(createCanonicalizerPass());
-  funcPassManager.addPass(createCSEPass());
 
   // Fuse fill op into the inner forall loop
   funcPassManager.addPass(createAMDAIEFuseFillIntoForallPass());
-  funcPassManager.addPass(createCanonicalizerPass());
 
-  // Pad the linalg operation
+  // Pack the linalg operation
   {
-    AMDAIEPadOptions padOptions;
-    padOptions.paddingLevel = 1;
-    funcPassManager.addPass(createAMDAIEPadPass(padOptions));
+    AMDAIEPackAndTransposeOptions packOptions;
+    packOptions.packLevel = 0;
+    funcPassManager.addPass(createAMDAIEPackAndTransposePass(packOptions));
   }
 
-  // Only promote the result to local memory
+  // Promote the inputs and results to local memory
   {
AMDAIEBufferizeToAllocationOptions bufferizeOptions; bufferizeOptions.memorySpace = 2; - bufferizeOptions.bufferizeOperand = BufferizeOperand::Output; + bufferizeOptions.bufferizeOperand = BufferizeOperand::InputOutput; funcPassManager.addPass( createAMDAIEBufferizeToAllocationPass(bufferizeOptions)); + addCleanups(); } - // Tile the reduction dimension using scf.for { AMDAIETileAndFuseOptions tileFuseOptions; tileFuseOptions.tilingLevel = 2; tileFuseOptions.useSCFFor = true; funcPassManager.addPass(createAMDAIETileAndFusePass(tileFuseOptions)); - } - funcPassManager.addPass(createAMDAIECleanupPass()); - funcPassManager.addPass(createCanonicalizerPass()); - funcPassManager.addPass(createCSEPass()); - - // Pad the linalg operation - { - AMDAIEPadOptions padOptions; - padOptions.paddingLevel = 2; - funcPassManager.addPass(createAMDAIEPadPass(padOptions)); + addCleanups(); } - // Promote the inputs to local memory - { - AMDAIEBufferizeToAllocationOptions bufferizeOptions; - bufferizeOptions.memorySpace = 2; - bufferizeOptions.bufferizeOperand = BufferizeOperand::Input; - funcPassManager.addPass( - createAMDAIEBufferizeToAllocationPass(bufferizeOptions)); - } - - // Decompose Conv2d ops to Conv1d ops - funcPassManager.addPass(createDecomposeConvolutionToLowerDimOpsPass()); + LinalgFoldUnitExtentDimsPassOptions opts; + opts.useRankReducingSlices = true; + funcPassManager.addPass(mlir::createLinalgFoldUnitExtentDimsPass(opts)); // Vectorization passes + // FIXME(newling) https://github.com/nod-ai/iree-amd-aie/issues/820 + enableVectorizationPasses = false; appendVectorizationToPipeline(funcPassManager, enableVectorizationPasses); funcPassManager.addPass(createCanonicalizerPass()); // Comprehensive bufferization addAMDAIEBufferizePasses(funcPassManager, useTilePipeline); + funcPassManager.addPass(createHoistStaticallyBoundAllocationsPass()); } void buildAMDAIETransformPassPipeline( @@ -557,6 +547,9 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager, bool enablePacketFlow) { passManager.addPass(createEraseHALDescriptorTypeFromMemRefPass()); passManager.addPass(memref::createFoldMemRefAliasOpsPass()); + // passManager.addPass(std::make_unique()); + passManager.addPass(createCanonicalizerPass()); + passManager.addPass(createAMDAIEConvertToDmaPass()); passManager.addPass(createAMDAIENormalizeLoopBoundsPass()); @@ -582,6 +575,7 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager, passManager.addPass(createAMDAIEAssignLogicalObjectFifoDepthPass()); passManager.addPass(createAMDAIEAccessToAcquireReleasePass()); passManager.addPass(createAMDAIENoneAccessToTemporaryBufferPass()); + passManager.addPass( createAMDAIEAssignConnectionTypesPass({enablePacketFlow})); passManager.addPass(createCSEPass()); @@ -612,6 +606,7 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager, passManager.addPass(createCanonicalizerPass()); passManager.addPass(createAMDAIEObjFifoBufferizationPass()); + passManager.addPass(createAMDAIETemporaryAllocBufferizationPass()); passManager.addPass(createAMDAIEConnectionToFlowPass()); passManager.addPass(createAMDAIEAssignPacketIdsPass()); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir index ad7b127ec..419fd4f05 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir +++ 
b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir @@ -1,32 +1,9 @@ // RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-amdaie-lowering-strategy{use-pass-pipeline=conv-decompose})' %s | FileCheck %s -// CHECK{LITERAL}: #config = #iree_codegen.lowering_config -#pipeline_layout = #hal.pipeline.layout, - , - -]> -builtin.module { - func.func @conv_2d_nchw_fchw_2x64x12x12x32x3x3_i32() { - %cst = arith.constant 0 : i32 - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 32, 14, 14], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x32x14x14xi32> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [64, 32, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<64x32x3x3xi32> - %5 = tensor.empty() : tensor<2x64x12x12xi32> - %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<2x64x12x12xi32>) -> tensor<2x64x12x12xi32> - %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<2x32x14x14xi32>, tensor<64x32x3x3xi32>) outs(%6 : tensor<2x64x12x12xi32>) -> tensor<2x64x12x12xi32> - // CHECK: linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, lowering_config = #config, strides = dense<1> : vector<2xi64>} - flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 64, 12, 12], strides = [1, 1, 1, 1] : tensor<2x64x12x12xi32> -> !flow.dispatch.tensor> - return - } -} -// ----- -// CHECK{LITERAL}: #iree_codegen.lowering_config +// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #packingConfig = #amdaie.packing_config #pipeline_layout = #hal.pipeline.layout, , @@ -43,14 +20,17 @@ func.func @conv_static_dispatch_0_conv_2d_nhwc_hwcf_2x12x12x64x3x3x32_bf16xbf16x %5 = tensor.empty() : tensor<2x12x12x64xf32> %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x12x12x64xf32>) -> tensor<2x12x12x64xf32> %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<2x14x14x32xbf16>, tensor<3x3x32x64xbf16>) outs(%6 : tensor<2x12x12x64xf32>) -> tensor<2x12x12x64xf32> - // CHECK: linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #config, strides = dense<1> : vector<2xi64>} + // CHECK: linalg.conv_2d_nhwc_hwcf + // CHECK-SAME: lowering_config = #config, + // CHECK-SAME: packing_config = #packingConfig, flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 12, 12, 64], strides = [1, 1, 1, 1] : tensor<2x12x12x64xf32> -> !flow.dispatch.tensor> return } // ----- -// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #packingConfig = #amdaie.packing_config #pipeline_layout = #hal.pipeline.layout, , @@ -67,15 +47,18 @@ func.func @conv_depthwise_channel_last_bf16(){ %5 = tensor.empty() : tensor<2x12x12x64xf32> %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x12x12x64xf32>) -> tensor<2x12x12x64xf32> %7 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : 
vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<2x14x14x64xbf16>, tensor<3x3x64xbf16>) outs(%6 : tensor<2x12x12x64xf32>) -> tensor<2x12x12x64xf32> - // CHECK: linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, lowering_config = #config, strides = dense<1> : vector<2xi64>} + // CHECK: linalg.depthwise_conv_2d_nhwc_hwc + // CHECK-SAME: lowering_config = #config, + // CHECK-SAME: packing_config = #packingConfig, flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 12, 12, 64], strides = [1, 1, 1, 1] : tensor<2x12x12x64xf32> -> !flow.dispatch.tensor> return } // ----- -// Same test as above, but where the operand type is i8. In this case we expect OC tile size of 32 (not 16) at level 0, and 8 at levels 1 and 2. This is because of the instruction size of AIE. +// Same test as above, but where the operand type is i8. In this case we expect OC tile size 8 (not 4) at level 1. This is because of the instruction size of AIE. -// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #packingConfig = #amdaie.packing_config #pipeline_layout = #hal.pipeline.layout, , @@ -92,7 +75,9 @@ func.func @conv_depthwise_channel_last_i8(){ %5 = tensor.empty() : tensor<2x12x12x64xi32> %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<2x12x12x64xi32>) -> tensor<2x12x12x64xi32> %7 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<2x14x14x64xi8>, tensor<3x3x64xi8>) outs(%6 : tensor<2x12x12x64xi32>) -> tensor<2x12x12x64xi32> - // CHECK: linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, lowering_config = #config, strides = dense<1> : vector<2xi64>} + // CHECK: linalg.depthwise_conv_2d_nhwc_hwc + // CHECK-SAME: lowering_config = #config, + // CHECK-SAME: packing_config = #packingConfig, flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 12, 12, 64], strides = [1, 1, 1, 1] : tensor<2x12x12x64xi32> -> !flow.dispatch.tensor> return } From 4dc5ec23bd482e90506ebcee5acd88232c9c81c6 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 8 Oct 2024 14:13:53 -0700 Subject: [PATCH 2/4] cosmetic --- .../Transforms/KernelDispatch.cpp | 27 ++++--------------- .../iree-amd-aie/Transforms/Passes.cpp | 2 -- 2 files changed, 5 insertions(+), 24 deletions(-) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp index 34805dabd..440d77d92 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp @@ -501,30 +501,14 @@ static LogicalResult setRootConfigForConvDecomposePipeline( if (isa(linalgOp) || isa(linalgOp)) { // The goal is to pack the input image and kernel as follows, when moving - // from L2 to L11: - // - // Example where input has 32 channels: - // - // %alloc_8 = memref.alloc() : memref<1x3x4x6x8xbf16, 2 : i32> - // iree_linalg_ext.pack %subview_5 outer_dims_perm = [0, 1, 3, 2] - // inner_dims_pos = [3] - // inner_tiles = [8] into %alloc_8 : - // (memref<1x3x6x32xbf16, strided<[576, 192, 32, 1], offset: ?>, 1 : i32> - // memref<1x3x4x6x8xbf16, 2 : i32>) - // - // %alloc_9 = memref.alloc() : memref<3x3x4x1x8x4xbf16, 2 : i32> - // iree_linalg_ext.pack %subview_6 outer_dims_perm = [0, 1, 2, 3] - // inner_dims_pos = [2, 3] - // inner_tiles = [8, 4] 
into %alloc_9 : - // (memref<3x3x32x4xbf16, strided<[384, 128, 4, 1], offset: ?>, 1 : i32> - // memref<3x3x4x1x8x4xbf16, 2 : i32>) + // from L2 to L1 (example where there are 32 input channels): + // Image: memref<1x3x6x32xbf16> -> memref<1x3x4x6x8xbf16> + // Kernel: memref<3x3x32x4xbf16> -> memref<3x3x4x1x8x4xbf16> innerPerm = {{}, {{1, 0}}, {}}; outerPerm = {{0, 1, 3, 2}, {}, {0, 1, 2, 3}}; packingSizes = {0, 0, 0, OC, 0, 0, IC}; tileSizeLevel1 = {1, 1, OW, OC, 0, 0, 0}; - // convert the kernel height, kernel width, and outer IC reduction into - // scf.for loops, leaving just a matmul of the instruction size inside - // the loops. + // scf.for tiling of KH, KW, and (packed) IC dimensions: tileSizeLevel2 = {0, 0, 0, 0, 1, 1, 1, 0, 0}; } @@ -547,9 +531,8 @@ static LogicalResult setRootConfigForConvDecomposePipeline( // [N, OH, OW, C, KW, HW] else if (isa(linalgOp)) { - // Notes: + // Notes // ===== - // // A property of depthwise convolution is that it can't be expressed in // terms of matmul, unlike the above (dense) conv-2ds. The tile sizes we // choose below are therefore not constrained by AIE matmul instructions. diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp index 785a26796..0a3191a0f 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp @@ -547,9 +547,7 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager, bool enablePacketFlow) { passManager.addPass(createEraseHALDescriptorTypeFromMemRefPass()); passManager.addPass(memref::createFoldMemRefAliasOpsPass()); - // passManager.addPass(std::make_unique()); passManager.addPass(createCanonicalizerPass()); - passManager.addPass(createAMDAIEConvertToDmaPass()); passManager.addPass(createAMDAIENormalizeLoopBoundsPass()); From 378491819cce635d43a14aa8d3b63d3fc40e8e99 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 8 Oct 2024 16:26:14 -0700 Subject: [PATCH 3/4] full column --- .../iree-amd-aie/Transforms/KernelDispatch.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp index 440d77d92..490175d35 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp @@ -507,6 +507,10 @@ static LogicalResult setRootConfigForConvDecomposePipeline( innerPerm = {{}, {{1, 0}}, {}}; outerPerm = {{0, 1, 3, 2}, {}, {0, 1, 2, 3}}; packingSizes = {0, 0, 0, OC, 0, 0, IC}; + // Target one column of 4 cores, each core processing a different + // output image row. TODO(newling) use 4x4 array. + // https://github.com/nod-ai/iree-amd-aie/issues/821 + tileSizeLevel0 = {1, 4, OW, OC, 0, 0, 0}; tileSizeLevel1 = {1, 1, OW, OC, 0, 0, 0}; // scf.for tiling of KH, KW, and (packed) IC dimensions: tileSizeLevel2 = {0, 0, 0, 0, 1, 1, 1, 0, 0}; @@ -567,6 +571,10 @@ static LogicalResult setRootConfigForConvDecomposePipeline( packingSizes = {0, 0, 0, OC_1, 0, 0}; innerPerm = {{}, {}, {}}; outerPerm = {{0, 1, 2, 3}, {0, 1, 2}, {0, 1, 2, 3}}; + // Target one column of 4 cores, each core processing a different + // output image row. TODO(newling) use 4x4 array. 
+ // https://github.com/nod-ai/iree-amd-aie/issues/821 + tileSizeLevel0 = {1, 4 * OH_1, OW_0, OC_0, 0, 0}; tileSizeLevel1 = {1, OH_1, OW_0, OC_1, 0, 0}; tileSizeLevel2 = {0, 0, 0, 0, 1, 1, 0}; } @@ -576,10 +584,6 @@ static LogicalResult setRootConfigForConvDecomposePipeline( "unrecognised convolution op, cannot set packing config. "); } - // For the objectFifo backend we currently target a single core. - // FIXME(newling) https://github.com/nod-ai/iree-amd-aie/issues/821 - tileSizeLevel0 = tileSizeLevel1; - assert(!innerPerm.empty() && !outerPerm.empty() && !packingSizes.empty() && !tileSizeLevel0.empty() && !tileSizeLevel1.empty() && "not all vectors for initializing config are non-empty"); From f51056d30052d7bf0fe50a16251267e3b904b300 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 8 Oct 2024 16:51:00 -0700 Subject: [PATCH 4/4] fixes for full column, review comment addressing --- .../iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp | 3 --- .../AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp | 2 +- .../target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp | 3 --- .../Transforms/test/lowering_strategy_conv.mlir | 6 +++--- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp index 0654b9dd1..62544391e 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEPackAndTranspose.cpp @@ -7,11 +7,8 @@ #include "iree-amd-aie/IR/AMDAIEAttrs.h" #include "iree-amd-aie/Transforms/Passes.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/Diagnostics.h" #include "mlir/Pass/Pass.h" #define DEBUG_TYPE "iree-amdaie-pack-and-transpose" diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp index 490175d35..398298120 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp @@ -574,7 +574,7 @@ static LogicalResult setRootConfigForConvDecomposePipeline( // Target one column of 4 cores, each core processing a different // output image row. TODO(newling) use 4x4 array. 
// https://github.com/nod-ai/iree-amd-aie/issues/821 - tileSizeLevel0 = {1, 4 * OH_1, OW_0, OC_0, 0, 0}; + tileSizeLevel0 = {1, 4 * OH_1, OW_0, OC_1, 0, 0}; tileSizeLevel1 = {1, OH_1, OW_0, OC_1, 0, 0}; tileSizeLevel2 = {0, 0, 0, 0, 1, 1, 0}; } diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp index 0a3191a0f..b5044f592 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp @@ -6,8 +6,6 @@ #include "iree-amd-aie/Transforms/Passes.h" -#include - #include "aie/Passes.h" #include "aievec/Passes.h" #include "air/Conversion/AIRLoweringPass.h" @@ -36,7 +34,6 @@ #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir index 419fd4f05..d07c3d136 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_conv.mlir @@ -2,7 +2,7 @@ -// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #config = #iree_codegen.lowering_config // CHECK{LITERAL}: #packingConfig = #amdaie.packing_config #pipeline_layout = #hal.pipeline.layout, @@ -29,7 +29,7 @@ func.func @conv_static_dispatch_0_conv_2d_nhwc_hwcf_2x12x12x64x3x3x32_bf16xbf16x // ----- -// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #config = #iree_codegen.lowering_config // CHECK{LITERAL}: #packingConfig = #amdaie.packing_config #pipeline_layout = #hal.pipeline.layout, @@ -57,7 +57,7 @@ func.func @conv_depthwise_channel_last_bf16(){ // ----- // Same test as above, but where the operand type is i8. In this case we expect OC tile size 8 (not 4) at level 1. This is because of the instruction size of AIE. -// CHECK{LITERAL}: #config = #iree_codegen.lowering_config +// CHECK{LITERAL}: #config = #iree_codegen.lowering_config // CHECK{LITERAL}: #packingConfig = #amdaie.packing_config #pipeline_layout = #hal.pipeline.layout,
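
Note on the packed layouts quoted in the KernelDispatch.cpp comments above (image memref<1x3x6x32xbf16> -> memref<1x3x4x6x8xbf16>, kernel memref<3x3x32x4xbf16> -> memref<3x3x4x1x8x4xbf16>): these follow directly from iree_linalg_ext.pack / tensor.pack semantics, where each tiled dimension is divided by its tile size, the outer dimensions are permuted by outer_dims_perm, and the inner tiles are appended. The standalone sketch below reproduces that arithmetic; the helper packedShape is illustrative only and is not part of the patch or of MLIR.

#include <cassert>
#include <cstdint>
#include <vector>

// Result shape of a pack op: divide tiled dims by their tile sizes,
// permute the outer dims, then append the inner tiles.
std::vector<int64_t> packedShape(std::vector<int64_t> shape,
                                 const std::vector<int64_t> &innerDimsPos,
                                 const std::vector<int64_t> &innerTiles,
                                 const std::vector<int64_t> &outerDimsPerm) {
  for (size_t i = 0; i < innerDimsPos.size(); ++i)
    shape[innerDimsPos[i]] /= innerTiles[i];
  std::vector<int64_t> result;
  for (int64_t d : outerDimsPerm) result.push_back(shape[d]);
  result.insert(result.end(), innerTiles.begin(), innerTiles.end());
  return result;
}

int main() {
  // Image: 1x3x6x32, inner_dims_pos = [3], inner_tiles = [8],
  // outer_dims_perm = [0, 1, 3, 2]  ->  1x3x4x6x8.
  assert((packedShape({1, 3, 6, 32}, {3}, {8}, {0, 1, 3, 2}) ==
          std::vector<int64_t>{1, 3, 4, 6, 8}));
  // Kernel: 3x3x32x4, inner_dims_pos = [2, 3], inner_tiles = [8, 4],
  // outer_dims_perm = [0, 1, 2, 3]  ->  3x3x4x1x8x4.
  assert((packedShape({3, 3, 32, 4}, {2, 3}, {8, 4}, {0, 1, 2, 3}) ==
          std::vector<int64_t>{3, 3, 4, 1, 8, 4}));
}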
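The depthwise tile sizes hinge on getAIEMacNumElements: OC_0 is the number of elements one MAC instruction processes for the operand element type, and OC_1 = OC_0 / 4 becomes the level-1 OC tile size and OC packing size. The sketch below assumes one MAC operates on a 256-bit vector; that figure and the helper macNumElements are assumptions standing in for the real getAIEMacNumElements, chosen because they are consistent with the lowering_strategy_conv.mlir expectations (OC tile 4 for bf16, 8 for i8).

#include <cassert>
#include <cstdint>

// Hypothetical stand-in for getAIEMacNumElements: assumes one MAC
// consumes a 256-bit vector of the operand element type.
uint32_t macNumElements(uint32_t elementBitWidth) {
  return 256 / elementBitWidth;
}

int main() {
  // bf16: OC_0 = 16, so OC_1 = 16 / 4 = 4
  // (matches @conv_depthwise_channel_last_bf16, OC tile 4 at level 1).
  assert(macNumElements(16) / 4 == 4);
  // i8: OC_0 = 32, so OC_1 = 8
  // (matches @conv_depthwise_channel_last_i8, OC tile 8 at level 1).
  assert(macNumElements(8) / 4 == 8);
}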
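For the dense nhwc path, tileSizeLevel2 = {0, 0, 0, 0, 1, 1, 1, 0, 0} tiles KH, KW, and the outer (packed) IC dimension by 1, so each output tile is computed by a triple scf.for nest whose body is a single OW x OC x IC matmul. A rough trip-count check for the conv2d_nhwc_objectfifo_e2e.mlir sample (32 input channels), assuming the fallback instruction size OW=4, OC=4, IC=8 applies; the loop below is a model of the generated nest, not compiler code.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t KH = 3, KW = 3, IC = 32, icTile = 8;
  int64_t matmuls = 0;
  // scf.for over KH, KW, and IC/icTile packed-IC steps; the body after
  // decomposition is one small (4x4x8) matmul per iteration.
  for (int64_t kh = 0; kh < KH; ++kh)
    for (int64_t kw = 0; kw < KW; ++kw)
      for (int64_t ic = 0; ic < IC / icTile; ++ic) ++matmuls;
  std::printf("small matmuls per output tile: %lld\n",
              (long long)matmuls);  // 3 * 3 * 4 = 36
}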