From cc2ddb1ded37868ce2015d9511dc9bba9fe115eb Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Fri, 21 Jan 2022 11:17:21 +0900
Subject: [PATCH] Add one rule for swapping ReshapeOp and MatMulOp (#1105)

* Add one rule for removing reshapeop

Signed-off-by: Tung D. Le

* Use a simple rule

Signed-off-by: Tung D. Le

* Remove an obsolete comment

Signed-off-by: Tung D. Le

* Extra comment removed

Signed-off-by: Tung D. Le
---
 src/Dialect/ONNX/Rewrite.cpp              | 11 +++++++
 src/Dialect/ONNX/Rewrite.td               | 23 +++++++++++++++
 test/mlir/onnx/onnx_canonicalization.mlir | 36 +++++++++++++++++++++++
 3 files changed, 70 insertions(+)

diff --git a/src/Dialect/ONNX/Rewrite.cpp b/src/Dialect/ONNX/Rewrite.cpp
index 10f7393d3e6..502c1d64b89 100644
--- a/src/Dialect/ONNX/Rewrite.cpp
+++ b/src/Dialect/ONNX/Rewrite.cpp
@@ -53,6 +53,16 @@ ArrayAttr createArrayAttrOfNToM(PatternRewriter &rewriter, int N, int M) {
   return rewriter.getI64ArrayAttr(vals);
 }
 
+// Get return type for a MatMulOp whose A's rank is N (>2) and B's rank is 2.
+Type getReturnTypeForMatMulOpND2D(Value A, Value B) {
+  ArrayRef<int64_t> aShape = A.getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> bShape = B.getType().cast<ShapedType>().getShape();
+  SmallVector<int64_t> resShape(aShape.begin(), aShape.end() - 1);
+  resShape.emplace_back(bShape[bShape.size() - 1]);
+  return RankedTensorType::get(
+      resShape, A.getType().cast<ShapedType>().getElementType());
+}
+
 /// Include the patterns defined in the Declarative Rewrite framework.
 #include "src/Dialect/ONNX/ONNXRewrite.inc"
 
@@ -93,6 +103,7 @@ void ONNXReshapeOp::getCanonicalizationPatterns(
     RewritePatternSet &result, MLIRContext *context) {
   result.insert<FuseReshapePattern>(context);
   result.insert<RemoveIdentityReshapePattern>(context);
+  result.insert<SwapReshapeMatMulPattern>(context);
 }
 
 /// on the ONNXDropoutOp.
diff --git a/src/Dialect/ONNX/Rewrite.td b/src/Dialect/ONNX/Rewrite.td
index 2c0be258c18..badd4de315c 100644
--- a/src/Dialect/ONNX/Rewrite.td
+++ b/src/Dialect/ONNX/Rewrite.td
@@ -109,6 +109,14 @@ class HasRankOf<int rank> : Constraint<
   CPred<"$0.getType().isa<ShapedType>() && "
         "$0.getType().cast<ShapedType>().getRank() == " # rank>>;
 
+def HaveSameLastDim: Constraint<
+  CPred<"hasShapeAndRank($0) && hasShapeAndRank($1) && "
+        "($0.getType().cast<ShapedType>().getShape()"
+        "[$0.getType().cast<ShapedType>().getRank() - 1] == "
+        "$1.getType().cast<ShapedType>().getShape()"
+        "[$1.getType().cast<ShapedType>().getRank() - 1])">,
+  "Two tensors have the same last dimension">;
+
 def GetUnitAttr: NativeCodeCall<"$_builder.getUnitAttr()">;
 
 def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
@@ -312,6 +320,21 @@ def RemoveIdentityReshapePattern: Pat<
   // Check that val has the specified shape.
   [(HasSpecifiedConstantShape $val, $shape)]>;
 
+def GetReturnTypeForMatMulOpND2D: NativeCodeCall<
+  "getReturnTypeForMatMulOpND2D($0, $1)"
+>;
+
+def SwapReshapeMatMulPattern: Pattern<
+  // If the input of Reshape is suitable for the next MatMul, use the input
+  // directly for MatMul. In other words, swap Reshape and MatMul.
+  // This rule puts reshape ops together so that the reshape fusion rule can apply.
+  (ONNXMatMulOp:$res2 (ONNXReshapeOp:$res1 $A, $_), $B),
+  [(ONNXReshapeOp (ONNXMatMulOp $A, $B, (returnType (GetReturnTypeForMatMulOpND2D $A, $B))),
+     (ONNXConstantOpFromDenseAttr (createDenseElementsAttrFromShape $res2)))],
+  [(HasRankGT<2> $A), (HasRankOf<2> $res1), (HasRankOf<2> $B), // A is reshaped to 2D.
+   (HaveSameLastDim $A, $res1) // The last dim of A is unchanged by reshape.
+  ]
+>;
 
 //===----------------------------------------------------------------------===//
 // Canonicalization for ONNXSqueezeOp
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 488be985410..a996449a9e2 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -180,6 +180,42 @@ func @test_reshape_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13
 
 // -----
 
+// Check the removal of reshapes that are used for matmul's input and output.
+// The 1st reshape is moved down then fused with the 2nd reshape to become an identity.
+
+// CHECK-LABEL: func @test_reshape_removal_with_matmul_4D(%arg0: tensor<3x5x10x20xf32>, %arg1: tensor<20x1xf32>) -> tensor<3x5x10x1xf32> {
+func @test_reshape_removal_with_matmul_4D(%arg0: tensor<3x5x10x20xf32>, %arg1: tensor<20x1xf32>) -> tensor<3x5x10x1xf32> {
+  %shape1 = "onnx.Constant"() { value = dense<[150, 20]> : tensor<2xi64> } : () -> tensor<2xi64>
+  %shape2 = "onnx.Constant"() { value = dense<[3, 5, 10, 1]> : tensor<4xi64> } : () -> tensor<4xi64>
+  %0 = "onnx.Reshape"(%arg0, %shape1) : (tensor<3x5x10x20xf32>, tensor<2xi64>) -> tensor<150x20xf32>
+  %1 = "onnx.MatMul"(%0, %arg1) : (tensor<150x20xf32>, tensor<20x1xf32>) -> tensor<150x1xf32>
+  %2 = "onnx.Reshape"(%1, %shape2) : (tensor<150x1xf32>, tensor<4xi64>) -> tensor<3x5x10x1xf32>
+  return %2 : tensor<3x5x10x1xf32>
+  // CHECK-NEXT: "onnx.MatMul"(%arg0, %arg1) : (tensor<3x5x10x20xf32>, tensor<20x1xf32>) -> tensor<3x5x10x1xf32>
+  // CHECK-NOT: "onnx.Reshape"
+}
+
+// -----
+
+// Check that reshapes around matmul are not fully removed: the 1st reshape is
+// moved down then fused with the 2nd reshape to become a single reshape.
+
+// CHECK-LABEL: func @test_reshape_should_not_remove(%arg0: tensor<3x5x10x20xf32>, %arg1: tensor<20x1xf32>) -> tensor<15x10x1xf32> {
+func @test_reshape_should_not_remove(%arg0: tensor<3x5x10x20xf32>, %arg1: tensor<20x1xf32>) -> tensor<15x10x1xf32> {
+  %shape1 = "onnx.Constant"() { value = dense<[150, 20]> : tensor<2xi64> } : () -> tensor<2xi64>
+  %shape2 = "onnx.Constant"() { value = dense<[15, 10, 1]> : tensor<3xi64> } : () -> tensor<3xi64>
+  %0 = "onnx.Reshape"(%arg0, %shape1) : (tensor<3x5x10x20xf32>, tensor<2xi64>) -> tensor<150x20xf32>
+  %1 = "onnx.MatMul"(%0, %arg1) : (tensor<150x20xf32>, tensor<20x1xf32>) -> tensor<150x1xf32>
+  %2 = "onnx.Reshape"(%1, %shape2) : (tensor<150x1xf32>, tensor<3xi64>) -> tensor<15x10x1xf32>
+  return %2 : tensor<15x10x1xf32>
+  // CHECK: [[CST:%.+]] = "onnx.Constant"() {value = dense<[15, 10, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
+  // CHECK: [[MATMUL:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<3x5x10x20xf32>, tensor<20x1xf32>) -> tensor<3x5x10x1xf32>
+  // CHECK: [[RES:%.+]] = "onnx.Reshape"([[MATMUL]], [[CST]]) : (tensor<3x5x10x1xf32>, tensor<3xi64>) -> tensor<15x10x1xf32>
+  // CHECK: return [[RES]] : tensor<15x10x1xf32>
+}
+
+// -----
+
 // Check the combining of transposes into a simple transpose.
 // CHECK-LABEL: func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> {
 func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> {
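
Why the swap is value-preserving, checked numerically: reshaping a row-major N-D tensor A to 2-D while keeping its last (contraction) dimension only regroups the leading dimensions, so a 2-D MatMul on the reshaped input fills exactly the same buffer as the N-D x 2-D MatMul followed by a reshape. The sketch below is a minimal standalone illustration, not code from the patch: it uses plain std::vector in place of MLIR types, and matmul2D, the variable names, and the shapes (a [15, 10, 20] tensor mirroring the 3x5x10x20 lit test with its two outermost dims collapsed, times an illustrative 20x4 weight) are all assumptions made for the example.

// Sanity check for the Reshape/MatMul swap (illustrative sketch, not part of
// the patch). A is a row-major [B, M, K] tensor; W is a [K, N] matrix.
// Path 1 reshapes A to [B*M, K] and then does one 2-D matmul (the IR before
// the rewrite). Path 2 does the matmul per [M, K] slice, producing [B, M, N],
// whose row-major buffer is already the flat [B*M, N] result (the IR after
// the rewrite, where the trailing Reshape only changes metadata).
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major [rows, K] x [K, N] -> [rows, N].
std::vector<float> matmul2D(const std::vector<float> &a, std::size_t rows,
    const std::vector<float> &b, std::size_t K, std::size_t N) {
  std::vector<float> c(rows * N, 0.0f);
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t k = 0; k < K; ++k)
      for (std::size_t j = 0; j < N; ++j)
        c[i * N + j] += a[i * K + k] * b[k * N + j];
  return c;
}

int main() {
  const std::size_t B = 15, M = 10, K = 20, N = 4;
  std::vector<float> A(B * M * K), W(K * N);
  for (std::size_t i = 0; i < A.size(); ++i)
    A[i] = static_cast<float>(i % 7) - 3.0f;
  for (std::size_t i = 0; i < W.size(); ++i)
    W[i] = static_cast<float>(i % 5) - 2.0f;

  // Path 1: Reshape [B, M, K] -> [B*M, K], then one 2-D MatMul.
  std::vector<float> reshapeThenMatMul = matmul2D(A, B * M, W, K, N);

  // Path 2: MatMul each [M, K] slice, then flatten [B, M, N] -> [B*M, N].
  std::vector<float> matMulThenReshape(B * M * N);
  for (std::size_t b = 0; b < B; ++b) {
    std::vector<float> slice(
        A.begin() + b * M * K, A.begin() + (b + 1) * M * K);
    std::vector<float> out = matmul2D(slice, M, W, K, N);
    std::copy(out.begin(), out.end(), matMulThenReshape.begin() + b * M * N);
  }

  // Bit-exact: both paths perform the same multiplies in the same order.
  assert(reshapeThenMatMul == matMulThenReshape);
  return 0;
}

This bit-exact equality is why SwapReshapeMatMulPattern needs only structural constraints (the rank checks plus HaveSameLastDim) and no numeric consideration: if the reshape altered the last dimension, the contraction rows would be regrouped and the identity would no longer hold.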