diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td
index ad24532fa..36d74939f 100644
--- a/include/TPP/Passes.td
+++ b/include/TPP/Passes.td
@@ -53,6 +53,18 @@ def VectorizationPass : Pass<"vectorization-pass",
   let dependentDialects = [ "memref::MemRefDialect", "linalg::LinalgDialect",
                             "vector::VectorDialect" ];
 }
+
+
+def HoistVectorTransfers : Pass<"hoist-vector-transfer"> {
+  let summary = "Hoist vector transfer operations out of the reduction and K loops";
+  let description = [{
+    Hoists the vector transfer read and write operations on the accumulator
+    matrix out of the reduction and K loops of a brgemm loop nest.
+    This pass is expected to run after the BrgemmLinalgTiling pass.
+  }];
+  let dependentDialects = [ "vector::VectorDialect", "scf::SCFDialect" ];
+}
+
 def VectorContractToOuterproduct : Pass<
     "vector-contract-to-outerproduct"> {
   let summary = "Perform outerproduct lowering of vector contraction ops";
diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt
index b5e27f6c9..9f9442fee 100644
--- a/lib/TPP/Transforms/CMakeLists.txt
+++ b/lib/TPP/Transforms/CMakeLists.txt
@@ -27,6 +27,7 @@ add_mlir_library(TPPTransforms
   Vectorization.cpp
   SplitReductionDim.cpp
   VectorContractToOuterproduct.cpp
+  HoistVectorTransfers.cpp
 
   ADDITIONAL_HEADER_DIRS
     ${PROJECT_SOURCE_DIR}/include/TPP
diff --git a/lib/TPP/Transforms/HoistVectorTransfers.cpp b/lib/TPP/Transforms/HoistVectorTransfers.cpp
new file mode 100644
index 000000000..48e07dec8
--- /dev/null
+++ b/lib/TPP/Transforms/HoistVectorTransfers.cpp
@@ -0,0 +1,164 @@
+//===- HoistVectorTransfers.cpp ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements hoisting of the accumulator's vector transfer read and
+// write operations out of the reduction and K loops of a tiled brgemm.
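+//
+// Conceptually, for a loop nest matched by this pattern (a simplified sketch;
+// SSA names are illustrative, not verbatim pass output):
+//
+//   %sub = memref.subview ...                    // accumulator tile
+//   scf.for %r ... {                             // reduction (batch) loop
+//     scf.for %k ... {                           // K loop
+//       %acc = vector.transfer_read %sub ...
+//       %res = vector.contract ..., %acc
+//       vector.transfer_write %res, %sub ...
+//     }
+//   }
+//
+// the accumulator transfers are hoisted and threaded through iter_args:
+//
+//   %acc0 = vector.transfer_read %sub ...
+//   %res = scf.for %r ... iter_args(%a = %acc0) -> (...) {
+//     %inner = scf.for %k ... iter_args(%b = %a) -> (...) {
+//       %c = vector.contract ..., %b
+//       scf.yield %c
+//     }
+//     scf.yield %inner
+//   }
+//   vector.transfer_write %res, %sub ...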
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace tpp {
+#define GEN_PASS_DEF_HOISTVECTORTRANSFERS
+#include "TPP/Passes.h.inc"
+} // namespace tpp
+} // namespace mlir
+
+using namespace mlir;
+using namespace vector;
+
+namespace mlir {
+namespace tpp {
+
+struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Check whether the linalg tiling + vector contract pattern matches.
+    auto retrievedVectorReadOp =
+        contractOp.getAcc().getDefiningOp<vector::TransferReadOp>();
+    if (!retrievedVectorReadOp)
+      return rewriter.notifyMatchFailure(
+          contractOp, "not a linalg tile + vector contract operation");
+
+    auto subviewOp =
+        retrievedVectorReadOp.getOperand(0).getDefiningOp<memref::SubViewOp>();
+    if (!subviewOp)
+      return rewriter.notifyMatchFailure(
+          contractOp, "not a linalg tile + vector contract operation");
+
+    auto reductionForOp = llvm::dyn_cast<scf::ForOp>(subviewOp->getNextNode());
+    if (!reductionForOp)
+      return rewriter.notifyMatchFailure(
+          contractOp, "not a linalg tile + vector contract operation");
+
+    auto kForOp = llvm::dyn_cast<scf::ForOp>(reductionForOp.getBody()->front());
+    if (!kForOp)
+      return rewriter.notifyMatchFailure(
+          contractOp, "not a linalg tile + vector contract operation");
+
+    // Move the vector transfer read before the reduction and K loops.
+    rewriter.setInsertionPointAfter(subviewOp);
+    auto *cloneVectorReadOp = rewriter.clone(*retrievedVectorReadOp);
+    retrievedVectorReadOp.replaceAllUsesWith(cloneVectorReadOp);
+
+    // Re-create the reduction and K loops with the accumulator as an iter arg.
+    auto *nextOp = cloneVectorReadOp->getNextNode();
+    auto oldReductionForOp = llvm::dyn_cast<scf::ForOp>(*nextOp);
+    auto oldKForOp =
+        llvm::dyn_cast<scf::ForOp>(oldReductionForOp.getBody()->front());
+
+    auto vectorReadOpValue = cloneVectorReadOp->getResult(0);
+    rewriter.setInsertionPoint(oldReductionForOp);
+
+    auto newReductionForOp = rewriter.create<scf::ForOp>(
+        oldReductionForOp.getLoc(), oldReductionForOp.getLowerBound(),
+        oldReductionForOp.getUpperBound(), oldReductionForOp.getStep(),
+        ValueRange{vectorReadOpValue},
+        [&](OpBuilder &rewriterNewReductionForOp,
+            Location locNewReductionForOp, Value ivNewReductionForOp,
+            ValueRange iterArgsNewReductionForOp) {
+
+          auto newKForOp = rewriter.create<scf::ForOp>(
+              oldKForOp.getLoc(), oldKForOp.getLowerBound(),
+              oldKForOp.getUpperBound(), oldKForOp.getStep(),
+              iterArgsNewReductionForOp,
+              [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
+                  Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
+
+                // Remap the old induction variables and iter args to the new
+                // ones while cloning the loop bodies.
+                mlir::IRMapping mapper;
+                mapper.map(oldReductionForOp.getInductionVar(),
+                           ivNewReductionForOp);
+                mapper.map(oldKForOp.getInductionVar(), ivNewKForOp);
+
+                for (auto [origArgReduction, newArgReduction] :
+                     llvm::zip(oldReductionForOp.getRegionIterArgs(),
+                               iterArgsNewReductionForOp)) {
+                  mapper.map(origArgReduction, newArgReduction);
+                }
+
+                for (auto [origArgK, newArgK] :
+                     llvm::zip(oldKForOp.getRegionIterArgs(),
+                               iterArgsNewKForOp)) {
+                  mapper.map(origArgK, newArgK);
+                }
+
+                for (auto &op : oldKForOp.getBody()->without_terminator()) {
+                  rewriterNewKForOp.clone(op, mapper);
+                }
+
+                // Placeholder yield; the operand is patched below once the
+                // cloned contraction result is known.
+                rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp,
+                                                       iterArgsNewKForOp);
+              });
+          rewriterNewReductionForOp.create<scf::YieldOp>(
+              locNewReductionForOp, newKForOp.getResult(0));
+        });
+
+    // Hoist the vector transfer write after the reduction loop and update the
+    // yield of the K loop.
+    auto newKForOp =
+        llvm::dyn_cast<scf::ForOp>(newReductionForOp.getBody()->front());
+    Value newContractOpValue;
+    mlir::vector::TransferWriteOp vectorWriteOperation;
+    mlir::Block *bodyBlock = newKForOp.getBody();
+    for (auto &op : bodyBlock->getOperations()) {
+      if (auto vectorContractOp = llvm::dyn_cast<vector::ContractionOp>(op)) {
+        vectorContractOp.setOperand(vectorContractOp.getNumOperands() - 1,
+                                    newKForOp.getRegionIterArgs()[0]);
+        newContractOpValue = vectorContractOp.getResult();
+      }
+      if (auto yieldOp = llvm::dyn_cast<scf::YieldOp>(op)) {
+        if (newContractOpValue)
+          yieldOp.setOperand(0, newContractOpValue);
+      }
+      if (auto vectorWriteOp = llvm::dyn_cast<vector::TransferWriteOp>(op)) {
+        vectorWriteOperation = vectorWriteOp;
+      }
+    }
+
+    if (vectorWriteOperation) {
+      vectorWriteOperation.setOperand(0, newReductionForOp.getResult(0));
+      vectorWriteOperation->moveBefore(oldReductionForOp);
+    }
+
+    // Erase the old vector contract operation and its users.
+    for (auto result : contractOp->getResults()) {
+      for (auto *userOp : result.getUsers()) {
+        userOp->erase();
+      }
+    }
+    contractOp.erase();
+
+    return success();
+  }
+};
+
+void populateHoistVectorTransferPatterns(RewritePatternSet &patterns) {
+  patterns.add<HoistVectorTransferOp>(patterns.getContext());
+}
+
+struct HoistVectorTransfers
+    : public impl::HoistVectorTransfersBase<HoistVectorTransfers> {
+  using HoistVectorTransfersBase::HoistVectorTransfersBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateHoistVectorTransferPatterns(patterns);
+    GreedyRewriteConfig config;
+    config.strictMode = GreedyRewriteStrictness::ExistingOps;
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                       config);
+  }
+};
+} // namespace tpp
+} // namespace mlir
diff --git a/test/Integration/hoist-vector-transfer-brgemm.mlir b/test/Integration/hoist-vector-transfer-brgemm.mlir
new file mode 100644
index 000000000..0d4cf34ed
--- /dev/null
+++ b/test/Integration/hoist-vector-transfer-brgemm.mlir
@@ -0,0 +1,59 @@
+// RUN: tpp-opt %s | tpp-run -e entry --entry-point-result=void -print > %t.1
+// RUN: tpp-opt %s --loop-invariant-code-motion --vectorization-pass --loop-invariant-code-motion --hoist-vector-transfer | tpp-run -e entry --entry-point-result=void -print > %t.2
+// RUN: diff %t.1 %t.2
+
+memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+func.func @entry(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+  %c1 = arith.constant 1 : index
+  %c24 = arith.constant 24 : index
+  %c64 = arith.constant 64 : index
+  %c4 = arith.constant 4 : index
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+  scf.forall (%arg1, %arg2) in (8, 24) {
+    %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+    linalg.fill ins(%cst : f32) outs(%subview : memref<32x64xf32, strided<[64, 1], offset: ?>>)
+    %subview_0 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+    scf.for %arg3 = %c0 to %c32 step %c4 {
+      scf.for %arg4 = %c0 to %c64 step %c64 {
+        %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+        scf.for %arg5 = %c0 to %c24 step %c1 {
+          scf.for %arg6 = %c0 to %c64 step %c1 {
+            %subview_2 = memref.subview %subview_0[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+            %subview_3 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+            linalg.batch_reduce_matmul ins(%subview_2, %subview_3 : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>) outs(%subview_1 : memref<4x64xf32, strided<[64, 1], offset: ?>>)
+          }
+        }
+      }
+    }
+  }
+  return %alloc : memref<8x24x32x64xf32>
+}
+
+// -----
+
+// RUN: tpp-opt %s | tpp-run -e nomatch --entry-point-result=void -seed 123 -print > %t.1
+// RUN: tpp-opt %s --hoist-vector-transfer | tpp-run -e nomatch --entry-point-result=void -seed 123 -print > %t.2
+// RUN: diff %t.1 %t.2
+
+#permA0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#permA1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#permA2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @nomatch(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
+  %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
+  %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
+  %3 = vector.contract {indexing_maps = [#permA0, #permA1, #permA2],
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %2
+    : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
+  %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
+  return %4 : tensor<4x4xf32>
+}
diff --git a/test/Passes/pass-hoist-vector-transfer-operation-brgemm.mlir b/test/Passes/pass-hoist-vector-transfer-operation-brgemm.mlir
new file mode 100644
index 000000000..45a68a9d7
--- /dev/null
+++ b/test/Passes/pass-hoist-vector-transfer-operation-brgemm.mlir
@@ -0,0 +1,376 @@
+// RUN: tpp-opt %s --hoist-vector-transfer --split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+module {
+  memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @entry(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+    %c1 = arith.constant 1 : index
+    %c24 = arith.constant 24 : index
+    %c64 = arith.constant 64 : index
+    %c4 = arith.constant 4 : index
+    %c32 = arith.constant 32 : index
+    %c0 = arith.constant 0 : index
+    %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+    scf.forall (%arg1, %arg2) in (8, 24) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+      %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+      scf.for %arg3 = %c0 to %c32 step %c4 {
+        scf.for %arg4 = %c0 to %c64 step %c64 {
+          %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+          scf.for %arg5 = %c0 to %c24 step %c1 {
+            scf.for %arg6 = %c0 to %c64 step %c1 {
+              %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+              %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+              %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+              %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+              %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+              %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+              vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+            }
+          }
+        }
+      }
+    }
+    return %alloc : memref<8x24x32x64xf32>
+  }
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+
+// CHECK-LABEL: func.func @entry(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 24 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+// CHECK: scf.forall (%[[VAL_11:.*]], %[[VAL_12:.*]]) in (8, 24) {
+// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VAL_2]], %[[VAL_13]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_11]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] {
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_5]] {
+// CHECK: %[[VAL_17:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_15]], %[[VAL_16]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_18:.*]] = vector.transfer_read %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+// CHECK: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (vector<4x64xf32>) {
+// CHECK: %[[VAL_22:.*]] = scf.for %[[VAL_23:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_24:.*]] = %[[VAL_21]]) -> (vector<4x64xf32>) {
+// CHECK: %[[VAL_25:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_20]], %[[VAL_15]], %[[VAL_23]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_26:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_20]], %[[VAL_23]], %[[VAL_16]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+// CHECK: %[[VAL_27:.*]] = vector.transfer_read %[[VAL_25]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+// CHECK: %[[VAL_28:.*]] = vector.transfer_read %[[VAL_26]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+// CHECK: %[[VAL_29:.*]] = vector.contract {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_27]], %[[VAL_28]], %[[VAL_24]] : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+// CHECK: scf.yield %[[VAL_29]] : vector<4x64xf32>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_22]] : vector<4x64xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[VAL_19]], %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[VAL_10]] : memref<8x24x32x64xf32>
+// CHECK: }
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+module {
+  memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @entry(%arg0: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant dense<0.000000e+00> : vector<32x32xf32>
+    %c1 = arith.constant 1 : index
+    %c48 = arith.constant 48 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    %c32 = arith.constant 32 : index
+    %c0 = arith.constant 0 : index
+    %0 = memref.get_global @__constant_48x32x32xf32 : memref<48x32x32xf32>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
+    scf.forall (%arg1, %arg2) in (8, 48) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>
+      %subview_2 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+      scf.for %arg3 = %c0 to %c32 step %c4 {
+        scf.for %arg4 = %c0 to %c32 step %c2 {
+          %subview_3 = memref.subview %subview[%arg3, %arg4] [4, 2] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<4x2xf32, strided<[32, 1], offset: ?>>
+          scf.for %arg5 = %c0 to %c48 step %c1 {
+            scf.for %arg6 = %c0 to %c32 step %c4 {
+              %subview_4 = memref.subview %subview_2[%arg5, %arg3, %arg6] [1, 4, 4] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>
+              %subview_5 = memref.subview %0[%arg5, %arg6, %arg4] [1, 4, 2] [1, 1, 1] : memref<48x32x32xf32> to memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>
+              %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x4xf32>
+              %2 = vector.transfer_read %subview_5[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x2xf32>
+              %3 = vector.transfer_read %subview_3[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x2xf32, strided<[32, 1], offset: ?>>, vector<4x2xf32>
+              %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x4xf32>, vector<1x4x2xf32> into vector<4x2xf32>
+              vector.transfer_write %4, %subview_3[%c0, %c0] {in_bounds = [true, true]} : vector<4x2xf32>, memref<4x2xf32, strided<[32, 1], offset: ?>>
+            }
+          }
+        }
+      }
+    }
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
+    scf.forall (%arg1, %arg2) in (8, 48) {
+      %subview = memref.subview %alloc_1[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>
+      %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+      scf.for %arg3 = %c0 to %c32 step %c4 {
+        scf.for %arg4 = %c0 to %c32 step %c2 {
+          %subview_3 = memref.subview %subview[%arg3, %arg4] [4, 2] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<4x2xf32, strided<[32, 1], offset: ?>>
+          scf.for %arg5 = %c0 to %c48 step %c1 {
+            scf.for %arg6 = %c0 to %c32 step %c4 {
+              %subview_4 = memref.subview %subview_2[%arg5, %arg3, %arg6] [1, 4, 4] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>
+              %subview_5 = memref.subview %0[%arg5, %arg6, %arg4] [1, 4, 2] [1, 1, 1] : memref<48x32x32xf32> to memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>
+              %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x4xf32>
+              %2 = vector.transfer_read %subview_5[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x2xf32>
+              %3 = vector.transfer_read %subview_3[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x2xf32, strided<[32, 1], offset: ?>>, vector<4x2xf32>
+              %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x4xf32>, vector<1x4x2xf32> into vector<4x2xf32>
+              vector.transfer_write %4, %subview_3[%c0, %c0] {in_bounds = [true, true]} : vector<4x2xf32>, memref<4x2xf32, strided<[32, 1], offset: ?>>
+            }
+          }
+        }
+      }
+    }
+    scf.forall (%arg1, %arg2) in (8, 48) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>
+      %subview_2 = memref.subview %alloc_1[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+      scf.for %arg3 = %c0 to %c32 step %c4 {
+        scf.for %arg4 = %c0 to %c32 step %c2 {
+          %subview_3 = memref.subview %subview[%arg3, %arg4] [4, 2] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<4x2xf32, strided<[32, 1], offset: ?>>
+          scf.for %arg5 = %c0 to %c48 step %c1 {
+            scf.for %arg6 = %c0 to %c32 step %c4 {
+              %subview_4 = memref.subview %subview_2[%arg5, %arg3, %arg6] [1, 4, 4] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>
+              %subview_5 = memref.subview %0[%arg5, %arg6, %arg4] [1, 4, 2] [1, 1, 1] : memref<48x32x32xf32> to memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>
+              %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x4xf32>
+              %2 = vector.transfer_read %subview_5[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x2xf32>
+              %3 = vector.transfer_read %subview_3[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x2xf32, strided<[32, 1], offset: ?>>, vector<4x2xf32>
+              %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x4xf32>, vector<1x4x2xf32> into vector<4x2xf32>
+              vector.transfer_write %4, %subview_3[%c0, %c0] {in_bounds = [true, true]} : vector<4x2xf32>, memref<4x2xf32, strided<[32, 1], offset: ?>>
+            }
+          }
+        }
+      }
+    }
+    return %alloc : memref<8x48x32x32xf32>
+  }
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+
+// CHECK-LABEL: func.func @entry(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x32xf32>
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 48 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = memref.get_global @__constant_48x32x32xf32 : memref<48x32x32xf32>
+// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
+// CHECK: scf.forall (%[[VAL_11:.*]], %[[VAL_12:.*]]) in (8, 48) {
+// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VAL_2]], %[[VAL_13]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>
+// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_11]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] {
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_5]] {
+// CHECK: %[[VAL_17:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_15]], %[[VAL_16]]] [4, 2] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<4x2xf32, strided<[32, 1], offset: ?>>
+// CHECK: %[[VAL_18:.*]] = vector.transfer_read %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true]} : memref<4x2xf32, strided<[32, 1], offset: ?>>, vector<4x2xf32>
+// CHECK: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (vector<4x2xf32>) {
+// CHECK: %[[VAL_22:.*]] = scf.for %[[VAL_23:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] iter_args(%[[VAL_24:.*]] = %[[VAL_21]]) -> (vector<4x2xf32>) {
+// CHECK: %[[VAL_25:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_20]], %[[VAL_15]], %[[VAL_23]]] [1, 4, 4] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK: %[[VAL_26:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_20]], %[[VAL_23]], %[[VAL_16]]] [1, 4, 2] [1, 1, 1] : memref<48x32x32xf32> to memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK: %[[VAL_27:.*]] = vector.transfer_read %[[VAL_25]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x4xf32>
+// CHECK: %[[VAL_28:.*]] = vector.transfer_read %[[VAL_26]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x2xf32>
+// CHECK: %[[VAL_29:.*]] = vector.contract {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_27]], %[[VAL_28]], %[[VAL_24]] : vector<1x4x4xf32>, vector<1x4x2xf32> into vector<4x2xf32>
+// CHECK: scf.yield %[[VAL_29]] : vector<4x2xf32>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_22]] : vector<4x2xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[VAL_19]], %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<4x2xf32>, memref<4x2xf32, strided<[32, 1], offset: ?>>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_30:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
+// CHECK: scf.forall (%[[VAL_31:.*]], %[[VAL_32:.*]]) in (8, 48) {
+// CHECK: %[[VAL_33:.*]] = memref.subview %[[VAL_30]]{{\[}}%[[VAL_31]], %[[VAL_32]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VAL_2]], %[[VAL_33]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>
+// CHECK: %[[VAL_34:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_31]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK: scf.for %[[VAL_35:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] {
+// CHECK: scf.for %[[VAL_36:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_5]] {
+// CHECK: %[[VAL_37:.*]] = memref.subview %[[VAL_33]]{{\[}}%[[VAL_35]], %[[VAL_36]]] [4, 2] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<4x2xf32, strided<[32, 1], offset: ?>>
+// CHECK: %[[VAL_38:.*]] = vector.transfer_read %[[VAL_37]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true]} : memref<4x2xf32, strided<[32, 1], offset: ?>>, vector<4x2xf32>
+// CHECK: %[[VAL_39:.*]] = scf.for %[[VAL_40:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_41:.*]] = %[[VAL_38]]) -> (vector<4x2xf32>) {
+// CHECK: %[[VAL_42:.*]] = scf.for %[[VAL_43:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] iter_args(%[[VAL_44:.*]] = %[[VAL_41]]) -> (vector<4x2xf32>) {
+// CHECK: %[[VAL_45:.*]] = memref.subview %[[VAL_34]]{{\[}}%[[VAL_40]], %[[VAL_35]], %[[VAL_43]]] [1, 4, 4] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK: %[[VAL_46:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_40]], %[[VAL_43]], %[[VAL_36]]] [1, 4, 2] [1, 1, 1] : memref<48x32x32xf32> to memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK: %[[VAL_47:.*]] = vector.transfer_read %[[VAL_45]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x4xf32>
+// CHECK: %[[VAL_48:.*]] = vector.transfer_read %[[VAL_46]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x2xf32>
+// CHECK: %[[VAL_49:.*]] = vector.contract {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_47]], %[[VAL_48]], %[[VAL_44]] : vector<1x4x4xf32>, vector<1x4x2xf32> into vector<4x2xf32>
+// CHECK: scf.yield %[[VAL_49]] : vector<4x2xf32>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_42]] : vector<4x2xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[VAL_39]], %[[VAL_37]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<4x2xf32>, memref<4x2xf32, strided<[32, 1], offset: ?>>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: scf.forall (%[[VAL_50:.*]], %[[VAL_51:.*]]) in (8, 48) {
+// CHECK: %[[VAL_52:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_50]], %[[VAL_51]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VAL_2]], %[[VAL_52]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>
+// CHECK: %[[VAL_53:.*]] = memref.subview %[[VAL_30]]{{\[}}%[[VAL_50]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK: scf.for %[[VAL_54:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] {
+// CHECK: scf.for %[[VAL_55:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_5]] {
+// CHECK: %[[VAL_56:.*]] = memref.subview %[[VAL_52]]{{\[}}%[[VAL_54]], %[[VAL_55]]] [4, 2] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<4x2xf32, strided<[32, 1], offset: ?>>
+// CHECK: %[[VAL_57:.*]] = vector.transfer_read %[[VAL_56]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true]} : memref<4x2xf32, strided<[32, 1], offset: ?>>, vector<4x2xf32>
+// CHECK: %[[VAL_58:.*]] = scf.for %[[VAL_59:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_60:.*]] = %[[VAL_57]]) -> (vector<4x2xf32>) {
+// CHECK: %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] iter_args(%[[VAL_63:.*]] = %[[VAL_60]]) -> (vector<4x2xf32>) {
+// CHECK: %[[VAL_64:.*]] = memref.subview %[[VAL_53]]{{\[}}%[[VAL_59]], %[[VAL_54]], %[[VAL_62]]] [1, 4, 4] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK: %[[VAL_65:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_59]], %[[VAL_62]], %[[VAL_55]]] [1, 4, 2] [1, 1, 1] : memref<48x32x32xf32> to memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK: %[[VAL_66:.*]] = vector.transfer_read %[[VAL_64]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x4xf32>
+// CHECK: %[[VAL_67:.*]] = vector.transfer_read %[[VAL_65]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>, vector<1x4x2xf32>
+// CHECK: %[[VAL_68:.*]] = vector.contract {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_66]], %[[VAL_67]], %[[VAL_63]] : vector<1x4x4xf32>, vector<1x4x2xf32> into vector<4x2xf32>
+// CHECK: scf.yield %[[VAL_68]] : vector<4x2xf32>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_61]] : vector<4x2xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[VAL_58]], %[[VAL_56]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<4x2xf32>, memref<4x2xf32, strided<[32, 1], offset: ?>>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[VAL_10]] : memref<8x48x32x32xf32>
+// CHECK: }
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+module {
+  memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @nomatch(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+    %c0 = arith.constant 0 : index
+    %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+    scf.forall (%arg1, %arg2) in (8, 24) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+      %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+      %1 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32>
+      %2 = vector.transfer_read %0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32>
+      %3 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32>
+      %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32>
+      vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+    }
+    return %alloc : memref<8x24x32x64xf32>
+  }
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+
+// CHECK-LABEL: func.func @nomatch(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+// CHECK: %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+// CHECK: scf.forall (%[[VAL_6:.*]], %[[VAL_7:.*]]) in (8, 24) {
+// CHECK: %[[VAL_8:.*]] = memref.subview %[[VAL_5]]{{\[}}%[[VAL_6]], %[[VAL_7]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VAL_2]], %[[VAL_8]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_9:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_6]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_10:.*]] = vector.transfer_read %[[VAL_9]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32>
+// CHECK: %[[VAL_11:.*]] = vector.transfer_read %[[VAL_4]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32>
+// CHECK: %[[VAL_12:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_1]] {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32>
+// CHECK: %[[VAL_13:.*]] = vector.contract {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32>
+// CHECK: vector.transfer_write %[[VAL_13]], %[[VAL_8]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: }
+// CHECK: return %[[VAL_5]] : memref<8x24x32x64xf32>
+// CHECK: }
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+module {
+  func.func @nomatch1(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+    %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
+    %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
+    %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
+    %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
+    return %4 : tensor<4x64xf32>
+  }
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @nomatch1(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x1xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x64xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x64xf32>) -> tensor<4x64xf32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
+// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
+// CHECK: %[[VAL_8:.*]] = vector.contract {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
+// CHECK: return %[[VAL_9]] : tensor<4x64xf32>
+// CHECK: }