From 96c8d2f89ea74401a03ce782c4a2e47004c7358c Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Fri, 22 Nov 2024 14:30:59 -0500 Subject: [PATCH 1/3] Fix output size computation for MaxPool2D for ceil_mode = true. --- lib/Conversion/TorchToLinalg/Utils.cpp | 33 +++++++++++++++++-- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 8 +++-- projects/pt1/e2e_testing/xfail_sets.py | 4 +++ .../torch_mlir_e2e_test/test_suite/pooling.py | 29 ++++++++++++++++ 4 files changed, 69 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index cf41bbcd711b..363b846aae04 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -116,6 +117,34 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, else division = b.createOrFold(loc, dividend, strideInt); Value out = b.createOrFold(loc, division, c1); + + if (ceilMode) { + Value outMinusOneTimesStride = + b.createOrFold(loc, division, strideInt); + Value inAddLeftPadding = b.createOrFold( + loc, castIndexToInt64(b, loc, in), paddingInt); + + auto reduceOutputDim = + b.createOrFold(loc, arith::CmpIPredicate::uge, + outMinusOneTimesStride, inAddLeftPadding); + + // Emit 'then' region of 'scf.if' + auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) { + opBuilder.create(loc, division); + }; + + // Emit 'else' region of 'scf.if' + auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) { + opBuilder.create(loc, out); + }; + + // Emit 'scf.if' op + auto ifOp = b.create(loc, reduceOutputDim, emitThenRegion, + emitElseRegion); + + return castIntToIndex(b, 
loc, ifOp.getResult(0)); + } + return castIntToIndex(b, loc, out); } @@ -527,8 +556,8 @@ Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc, Type elementType) { auto dtypePromoteBody = [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) { - Value elem = - convertScalarToDtype(builder, loc, payloadArgs[0], elementType); + Value elem = mlir::torch::Torch::convertScalarToDtype( + builder, loc, payloadArgs[0], elementType); builder.create(loc, elem); }; return torch_to_linalg::createElementwiseLinalgGeneric( diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index be51712a35de..1c2f7d6f2a11 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5398,9 +5398,11 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { } else { int64_t dimSize = inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1; - if (ceilMode && (dimSize % stride != 0)) - return dimSize / stride + 2; - return dimSize / stride + 1; + int64_t outputDim = dimSize / stride + 1; + if (ceilMode && (dimSize % stride != 0) && + (outputDim * stride < inputDim + padBefore)) + outputDim++; + return outputDim; } } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index bb8f3a029b1d..a0997dc0be10 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -735,6 +735,7 @@ "LenStrModule_basic", "MaxPool2dCeilModeTrueModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", "MaxPool2dWithIndicesBackwardStatic3DModule_basic", @@ -2255,6 +2256,7 @@ "MatmulStaticBroadcast_basic", "MaxPool2dEmptyStrideStaticModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", 
"MaxPool2dStaticModule_basic", "MeanModule_basic", "MmDagModule_basic", @@ -3009,6 +3011,7 @@ "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dModule_basic", "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dModule_basic", "MaxPool2dWithIndicesAllOnesModule_basic", "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", @@ -4492,6 +4495,7 @@ "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dModule_basic", "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dModule_basic", "MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesAllOnesModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 84e0e2eb9cf5..e2eaa4cfd0fe 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -420,6 +420,35 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0)) +class MaxPool2dStaticCeilModeTrueReduceOutputModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp2d = torch.nn.MaxPool2d( + kernel_size=6, + stride=6, + padding=3, + dilation=1, + ceil_mode=True, + ) + + @export + @annotate_args( + [ + None, + ([2, 6, 20, 10], torch.float32, True), + ] + ) + def forward(self, x): + return self.mp2d(x) + + +@register_test_case( + module_factory=lambda: MaxPool2dStaticCeilModeTrueReduceOutputModule() +) +def MaxPool2dStaticCeilModeTrueReduceOutputModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 6, 20, 10, low=0.5, high=1.0)) + + # ============================================================================== From 36fe2e2db5c61d9c8b31d47601110ece1ee4b3c5 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Mon, 2 Dec 2024 07:58:27 -0500 Subject: [PATCH 2/3] Update xfail_sets.py --- 
projects/pt1/e2e_testing/xfail_sets.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a0997dc0be10..1dce55f06158 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3011,7 +3011,6 @@ "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dModule_basic", "MaxPool2dCeilModeTrueModule_basic", - "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dModule_basic", "MaxPool2dWithIndicesAllOnesModule_basic", "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", @@ -3383,6 +3382,13 @@ "ScaledDotProductAttentionBoolMaskModule_basic", } +if torch_version_for_comparison() > version.parse("2.5.1"): + ONNX_XFAIL_SET = ONNX_XFAIL_SET | { + # error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast incompatible + # torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3 + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", + } + if torch_version_for_comparison() < version.parse("2.4.0.dev"): STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - { "AtenIntMM_basic", @@ -4495,7 +4501,6 @@ "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dModule_basic", "MaxPool2dCeilModeTrueModule_basic", - "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dModule_basic", "MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesAllOnesModule_basic", @@ -4936,3 +4941,10 @@ "_LogSoftmaxModule_basic", "_SoftmaxModule_basic", } + +if torch_version_for_comparison() > version.parse("2.5.1"): + ONNX_TOSA_XFAIL_SET = ONNX_TOSA_XFAIL_SET | { + # error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast incompatible + # torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3 + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", + } 
From 6efc4a8bb4149576e385c65cc86c7efc2b5d6774 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Thu, 5 Dec 2024 12:10:31 -0500 Subject: [PATCH 3/3] Use arith.select instead of scf.if --- lib/Conversion/TorchToLinalg/Utils.cpp | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 363b846aae04..98dbc1957892 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -13,7 +13,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -124,25 +123,13 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, Value inAddLeftPadding = b.createOrFold( loc, castIndexToInt64(b, loc, in), paddingInt); - auto reduceOutputDim = + auto reduceOutputDimCond = b.createOrFold(loc, arith::CmpIPredicate::uge, outMinusOneTimesStride, inAddLeftPadding); - // Emit 'then' region of 'scf.if' - auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) { - opBuilder.create(loc, division); - }; - - // Emit 'else' region of 'scf.if' - auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) { - opBuilder.create(loc, out); - }; - - // Emit 'scf.if' op - auto ifOp = b.create(loc, reduceOutputDim, emitThenRegion, - emitElseRegion); - - return castIntToIndex(b, loc, ifOp.getResult(0)); + auto reducedDim = b.createOrFold(loc, reduceOutputDimCond, + division, out); + return castIntToIndex(b, loc, reducedDim); } return castIntToIndex(b, loc, out); @@ -556,8 +543,8 @@ Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc, Type elementType) { auto dtypePromoteBody = [&](OpBuilder &builder, Location loc, ValueRange 
payloadArgs) { - Value elem = mlir::torch::Torch::convertScalarToDtype( - builder, loc, payloadArgs[0], elementType); + Value elem = + convertScalarToDtype(builder, loc, payloadArgs[0], elementType); builder.create(loc, elem); }; return torch_to_linalg::createElementwiseLinalgGeneric(