Skip to content

Commit

Permalink
Add one rule for swapping ReshapeOp and MatMulOp (llvm#1105)
Browse files Browse the repository at this point in the history
* Add one rule for removing reshapeop

Signed-off-by: Tung D. Le <tung@jp.ibm.com>

* Use a simple rule

Signed-off-by: Tung D. Le <tung@jp.ibm.com>

* Remove an obsolete comment

Signed-off-by: Tung D. Le <tung@jp.ibm.com>

* Extra comment removed

Signed-off-by: Tung D. Le <tung@jp.ibm.com>
  • Loading branch information
tungld authored Jan 21, 2022
1 parent bf0176f commit cc2ddb1
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/Dialect/ONNX/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ ArrayAttr createArrayAttrOfNToM(PatternRewriter &rewriter, int N, int M) {
return rewriter.getI64ArrayAttr(vals);
}

// Compute the result type of a MatMulOp multiplying A (rank N > 2) by B
// (rank 2): the result keeps every dimension of A except the last one, and
// appends the last (column) dimension of B. Element type is taken from A.
Type getReturnTypeForMatMulOpND2D(Value A, Value B) {
  auto aType = A.getType().cast<RankedTensorType>();
  auto bType = B.getType().cast<RankedTensorType>();
  ArrayRef<int64_t> aShape = aType.getShape();
  // All but the last dim of A, then the last dim of B.
  SmallVector<int64_t> resultDims(aShape.begin(), aShape.end() - 1);
  resultDims.push_back(bType.getShape().back());
  return RankedTensorType::get(resultDims, aType.getElementType());
}

/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Dialect/ONNX/ONNXRewrite.inc"

Expand Down Expand Up @@ -93,6 +103,7 @@ void ONNXReshapeOp::getCanonicalizationPatterns(
RewritePatternSet &result, MLIRContext *context) {
result.insert<FuseReshapePattern>(context);
result.insert<RemoveIdentityReshapePattern>(context);
result.insert<SwapReshapeMatMulPattern>(context);
}

/// on the ONNXDropoutOp.
Expand Down
23 changes: 23 additions & 0 deletions src/Dialect/ONNX/Rewrite.td
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ class HasRankOf<int rank> :
Constraint<CPred<"$0.getType().isa<ShapedType>() && "
"$0.getType().cast<ShapedType>().getRank() == " # rank>>;

// Constraint: both values are ranked tensors with known shape and rank, and
// their last dimensions are equal. Used to verify that a reshape leaves the
// trailing (reduction) dimension of its input untouched.
def HaveSameLastDim: Constraint<
CPred<"hasShapeAndRank($0) && hasShapeAndRank($1) && "
"($0.getType().cast<RankedTensorType>().getShape()"
"[$0.getType().cast<RankedTensorType>().getRank() - 1] == "
"$1.getType().cast<RankedTensorType>().getShape()"
"[$1.getType().cast<RankedTensorType>().getRank() - 1])">,
"Two tensors have the same last dimension">;

// Materialize a UnitAttr via the pattern rewriter's builder.
def GetUnitAttr: NativeCodeCall<"$_builder.getUnitAttr()">;

// Constraint: the matched value has exactly one use.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
Expand Down Expand Up @@ -312,6 +320,21 @@ def RemoveIdentityReshapePattern: Pat<
// Check that val has the specified shape.
[(HasSpecifiedConstantShape $val, $shape)]>;

// Call the C++ helper getReturnTypeForMatMulOpND2D (Rewrite.cpp) to compute
// the static result type of MatMul(A, B) where A has rank > 2 and B rank 2.
def GetReturnTypeForMatMulOpND2D: NativeCodeCall<
"getReturnTypeForMatMulOpND2D($0, $1)"
>;

// Swap Reshape and MatMul: MatMul(Reshape(A), B) => Reshape(MatMul(A, B)).
def SwapReshapeMatMulPattern: Pattern<
// If the input of Reshape is suitable for the next MatMul, use the input directly
// for MatMul. In other words, swapping Reshape and MatMul.
// This rule will put reshape ops together, then the reshape fusion rule can be
// applied, potentially cancelling the reshapes entirely.
// The moved-down reshape targets the original MatMul result shape ($res2).
(ONNXMatMulOp:$res2 (ONNXReshapeOp:$res1 $A, $_), $B),
[(ONNXReshapeOp (ONNXMatMulOp $A, $B, (returnType (GetReturnTypeForMatMulOpND2D $A, $B))),
(ONNXConstantOpFromDenseAttr (createDenseElementsAttrFromShape $res2)))],
[(HasRankGT<2> $A), (HasRankOf<2> $res1), (HasRankOf<2> $B), // A is reshaped to 2D.
(HaveSameLastDim $A, $res1) // The last dim of A is unchanged by reshape.
]
>;

//===----------------------------------------------------------------------===//
// Canonicalization for ONNXSqueezeOp
Expand Down
36 changes: 36 additions & 0 deletions test/mlir/onnx/onnx_canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,42 @@ func @test_reshape_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13

// -----

// Check the removal of reshapes that are used for matmul's input and output.
// The 1st reshape is moved down then fused with the 2nd reshape to become an identity.
// Here the 2nd reshape restores exactly the shape produced by the 4-D matmul
// (3x5x10x1), so both reshapes cancel and only the matmul remains.

// CHECK-LABEL: func @test_reshape_removal_with_matmul_4D(%arg0: tensor<3x5x10x20xf32>, %arg1: tensor<20x1xf32>) -> tensor<3x5x10x1xf32> {
func @test_reshape_removal_with_matmul_4D(%arg0: tensor<3x5x10x20xf32>, %arg1: tensor<20x1xf32>) -> tensor<3x5x10x1xf32> {
%shape1 = "onnx.Constant"() { value = dense<[150, 20]> : tensor<2xi64> } : () -> tensor<2xi64>
%shape2 = "onnx.Constant"() { value = dense<[3, 5, 10, 1]> : tensor<4xi64> } : () -> tensor<4xi64>
%0 = "onnx.Reshape"(%arg0, %shape1) : (tensor<3x5x10x20xf32>, tensor<2xi64>) -> tensor<150x20xf32>
%1 = "onnx.MatMul"(%0, %arg1) : (tensor<150x20xf32>, tensor<20x1xf32>) -> tensor<150x1xf32>
%2 = "onnx.Reshape"(%1, %shape2) : (tensor<150x1xf32>, tensor<4xi64>) -> tensor<3x5x10x1xf32>
return %2 : tensor<3x5x10x1xf32>
// CHECK-NEXT: "onnx.MatMul"(%arg0, %arg1) : (tensor<3x5x10x20xf32>, tensor<20x1xf32>) -> tensor<3x5x10x1xf32>
// CHECK-NOT: "onnx.Reshape"
}

// -----

// Check that reshapes around a matmul are combined but NOT fully removed when
// the final output shape (15x10x1) differs from the matmul's natural 4-D
// result shape (3x5x10x1): the 1st reshape is moved down past the matmul and
// fused with the 2nd reshape into a single remaining reshape.

// CHECK-LABEL: func @test_reshape_should_not_remove(%arg0: tensor<3x5x10x20xf32>, %arg1: tensor<20x1xf32>) -> tensor<15x10x1xf32> {
func @test_reshape_should_not_remove(%arg0: tensor<3x5x10x20xf32>, %arg1: tensor<20x1xf32>) -> tensor<15x10x1xf32> {
%shape1 = "onnx.Constant"() { value = dense<[150, 20]> : tensor<2xi64> } : () -> tensor<2xi64>
%shape2 = "onnx.Constant"() { value = dense<[15, 10, 1]> : tensor<3xi64> } : () -> tensor<3xi64>
%0 = "onnx.Reshape"(%arg0, %shape1) : (tensor<3x5x10x20xf32>, tensor<2xi64>) -> tensor<150x20xf32>
%1 = "onnx.MatMul"(%0, %arg1) : (tensor<150x20xf32>, tensor<20x1xf32>) -> tensor<150x1xf32>
%2 = "onnx.Reshape"(%1, %shape2) : (tensor<150x1xf32>, tensor<3xi64>) -> tensor<15x10x1xf32>
return %2 : tensor<15x10x1xf32>
// CHECK: [[CST:%.+]] = "onnx.Constant"() {value = dense<[15, 10, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
// CHECK: [[MATMUL:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<3x5x10x20xf32>, tensor<20x1xf32>) -> tensor<3x5x10x1xf32>
// CHECK: [[RES:%.+]] = "onnx.Reshape"([[MATMUL]], [[CST]]) : (tensor<3x5x10x1xf32>, tensor<3xi64>) -> tensor<15x10x1xf32>
// CHECK: return [[RES]] : tensor<15x10x1xf32>
}

// -----

// Check the combining of transposes into a simple transpose.
// CHECK-LABEL: func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> {
func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> {
Expand Down

0 comments on commit cc2ddb1

Please sign in to comment.