From aeedf2c22b35234dff81aa0e873fe1b9eaa8f571 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Tue, 25 Feb 2020 18:08:57 +0900 Subject: [PATCH 1/7] Lower MaxPoolSingleOutOp to Krnl dialect --- .../onnx_to_krnl/convert_onnx_to_krnl.cpp | 2 + .../rewrite_patterns/nn/pooling.inc | 208 ++++++++++++++++++ test/backend/test.py | 6 + 3 files changed, 216 insertions(+) create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc diff --git a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp index 84d4be8a05..8610398cc1 100644 --- a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp +++ b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp @@ -405,6 +405,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // Neural network #include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc" #include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc" +#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc" //===----------------------------------------------------------------------===// // EntryPoint Op lowering to Krnl Entry Point. @@ -525,6 +526,7 @@ void FrontendToKrnlLoweringPass::runOnModule() { // Neural network populateLoweringONNXConvOpPattern(patterns, &getContext()); populateLoweringONNXNormalizationOpPattern(patterns, &getContext()); + populateLoweringONNXPoolingOpPattern(patterns, &getContext()); // Entry point patterns.insert(&getContext()); diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc new file mode 100644 index 0000000000..c87da8abfd --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc @@ -0,0 +1,208 @@ +//===----- pool.inc - Lowering Pooling Ops --------------------------------===// +// +// Copyright 2019 The IBM Research Authors. 
+// +// ============================================================================= +// +// This file lowers the ONNX Pooling Operators to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +// Identity values +template <> +float getIdentityValue(){ + return (float)-std::numeric_limits::infinity(); +} + +template <> +int getIdentityValue(){ + return std::numeric_limits::min(); +} + +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); + auto result = rewriter.create(loc, max, lhs, rhs); + return result; +} + +struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { + ONNXMaxPoolSingleOutOpLowering(MLIRContext *ctx) + : ConversionPattern( + mlir::ONNXMaxPoolSingleOutOp::getOperationName(), 1, ctx) {} + + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + + // Match + ONNXMaxPoolSingleOutOp poolOp = llvm::dyn_cast(op); + + // Read kernel_shape attribute + SmallVector kernelShape; + auto kernelShapeAttribute = poolOp.kernel_shapeAttr(); + if (kernelShapeAttribute) + for (auto dim : kernelShapeAttribute.getValue()) + kernelShape.emplace_back(dim.cast().getInt()); + + // Read strides attribute + SmallVector strides; + auto stridesAttribute = poolOp.stridesAttr(); + if (stridesAttribute) + for (auto stride : stridesAttribute.getValue()) + strides.emplace_back(stride.cast().getInt()); + + // Insert an allocation and deallocation for the result of this operation. 
+ auto memRefType = convertToMemRefType(*op->result_type_begin()); + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc( + memRefType, loc, rewriter, insertDealloc, {operands[0]}); + + auto resultShape = memRefType.getShape(); + auto resultElementType = memRefType.getElementType(); + auto &inputOperand = operands[0]; + auto inputShape = inputOperand.getType().cast().getShape(); + + // R = MaxPool(D) + // + // The input/output shapes will look like this: + // + // D (NxCxHxW) -> R (NxCxRHxRW) + // + // The loop nest will look as follows: + // + // strides = [s1, s2] + // + // for n = 0 .. N: + // for c = 0 .. C: + // for r1 = 0 .. RH: + // for r2 = 0 .. RW: + // R[n][c][r1][r2] = 0; + // for k1 = 0 .. KH: + // for k2 = 0 .. KW: + // t = D[n][c][s1 * r1 + k1][s2 * r2 + k2]; + // R[n][c][r1][r2] = max(R[n][c][r1][r2], t); + // + // Naming: + // n, c: outer loop nest indices + // r1, r2: spatial loop nest indices + // k1, k2: inner loop nest indices + // + // TODO: handle padding. + // + + // 1. Define outer loops and emit empty optimization block: + int64_t nOuterLoops = 2; + BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops); + outerLoops.createDefineAndOptimizeOp(); + int nIndex = outerLoops.pushBounds(0, inputOperand, 0); + int cIndex = outerLoops.pushBounds(0, inputOperand, 1); + // Outer loop iteration + outerLoops.createIterateOp(); + rewriter.setInsertionPointToStart(outerLoops.getIterateBlock()); + { + // 2. Emit the body of the outer loop nest. + + // 2.1 Define spatial loops + int64_t nSpatialLoops = resultShape.size() - 2; + BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops); + spatialLoops.createDefineAndOptimizeOp(); + for (int i = 2; i < resultShape.size(); ++i) + spatialLoops.pushBounds(0, alloc, i); + + // 2.2 Emit loop nest over output spatial dimensions. 
+ spatialLoops.createIterateOp(); + rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock()); + { + // 3. Emit the body of the spatial loop nest. + // 3.1 Emit: R[n][c][r1][r2] = 0; + SmallVector resultIndices; + // n + resultIndices.emplace_back(outerLoops.getInductionVar(nIndex)); + // c + resultIndices.emplace_back(outerLoops.getInductionVar(cIndex)); + // rX + for (auto arg : spatialLoops.getIterateBlock()->getArguments()) + resultIndices.emplace_back(arg); + // Store initializer value into output location. + Value identity; + if (resultElementType.isa()) { + identity = rewriter.create( + loc, FloatAttr::get(resultElementType, + getIdentityValue())); + } else if (resultElementType.isa()) { + identity = rewriter.create( + loc, IntegerAttr::get(resultElementType, + getIdentityValue())); + } else { + emitError(loc, "unsupported element type"); + } + rewriter.create(loc, identity, alloc, resultIndices); + + // 3.2 Define inner loops. + int64_t nInnerLoops = kernelShape.size(); + BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops); + innerLoops.createDefineAndOptimizeOp(); + // for Kx = 0 .. KX + for (int i = 0; i < nInnerLoops; ++i) + innerLoops.pushBounds(0, kernelShape[i]); + + // 3.4 Emit inner loop nest. + innerLoops.createIterateOp(); + rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); + + { + // 4. Emit inner loop body + // t = D[n][c][s1 * r1 + k1][s2 * r2 + k2]; + // R[n][c][r1][r2] = max(R[n][c][r1][r2], t); + + // 4.1 Prepare indices for accesing the data tensor. + SmallVector dataIndices; + // n + dataIndices.emplace_back(outerLoops.getInductionVar(nIndex)); + // c + dataIndices.emplace_back(outerLoops.getInductionVar(cIndex)); + // sX * rX + kX + for (int i = 0; i < kernelShape.size(); ++i) { + Value spatialIndex = spatialLoops.getInductionVar(i); + // If strides are present then emit the correct access index. 
+ if (stridesAttribute && strides[i] > 1) + spatialIndex = rewriter.create(loc, + rewriter.create(loc, strides[i]), + spatialLoops.getInductionVar(i)); + dataIndices.emplace_back(rewriter.create( + loc, spatialIndex, innerLoops.getInductionVar(i))); + } + + // 4.2 Do pooling. + auto loadData = + rewriter.create(loc, inputOperand, dataIndices); + auto loadPartialResult = + rewriter.create(loc, alloc, resultIndices); + Value result = mapToLowerScalarOp( + op, resultElementType, {loadPartialResult, loadData}, rewriter); + // 4.4 Store computed value into output location. + rewriter.create(loc, result, alloc, resultIndices); + } + } + } + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +void populateLoweringONNXPoolingOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} diff --git a/test/backend/test.py b/test/backend/test.py index abd27db8de..8ca67bb17c 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -305,6 +305,12 @@ def supports_device(cls, device): "test_batchnorm_epsilon_cpu", "test_batchnorm_example_cpu", + # Pooling + "test_maxpool_1d_default_cpu", + "test_maxpool_2d_default_cpu", + "test_maxpool_2d_strides_cpu", + "test_maxpool_3d_default_cpu", + ] # Extract name of all test cases. From 480333bb0017eb73b6857affd16143b19a453253 Mon Sep 17 00:00:00 2001 From: "Tung D. 
Le" Date: Tue, 25 Feb 2020 18:35:28 +0900 Subject: [PATCH 2/7] Edit comments --- .../onnx_to_krnl/rewrite_patterns/nn/pooling.inc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc index c87da8abfd..9bc3fda9e8 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc @@ -49,6 +49,9 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { if (kernelShapeAttribute) for (auto dim : kernelShapeAttribute.getValue()) kernelShape.emplace_back(dim.cast().getInt()); + else + emitError(loc, "kernel_shape is a mandatory attribute for which there is " + "no default."); // Read strides attribute SmallVector strides; @@ -87,7 +90,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { // for c = 0 .. C: // for r1 = 0 .. RH: // for r2 = 0 .. RW: - // R[n][c][r1][r2] = 0; + // R[n][c][r1][r2] = negative_infinity; // for k1 = 0 .. KH: // for k2 = 0 .. KW: // t = D[n][c][s1 * r1 + k1][s2 * r2 + k2]; @@ -125,7 +128,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock()); { // 3. Emit the body of the spatial loop nest. - // 3.1 Emit: R[n][c][r1][r2] = 0; + // 3.1 Emit: R[n][c][r1][r2] = negative_infinity; SmallVector resultIndices; // n resultIndices.emplace_back(outerLoops.getInductionVar(nIndex)); @@ -157,7 +160,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { for (int i = 0; i < nInnerLoops; ++i) innerLoops.pushBounds(0, kernelShape[i]); - // 3.4 Emit inner loop nest. + // 3.3 Emit inner loop nest. 
innerLoops.createIterateOp(); rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); @@ -191,7 +194,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { rewriter.create(loc, alloc, resultIndices); Value result = mapToLowerScalarOp( op, resultElementType, {loadPartialResult, loadData}, rewriter); - // 4.4 Store computed value into output location. + // 4.3 Store computed value into output location. rewriter.create(loc, result, alloc, resultIndices); } } From 667be1e669f4c984e2cb31d333c90c88ec0f8c06 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Wed, 26 Feb 2020 09:59:51 +0900 Subject: [PATCH 3/7] Update changes according to the new folder structure --- src/CMakeLists.txt | 1 + .../{rewrite_patterns/nn/pooling.inc => nn/pooling.cpp} | 6 +++++- src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp | 3 +++ 3 files changed, 9 insertions(+), 1 deletion(-) rename src/conversion/onnx_to_krnl/{rewrite_patterns/nn/pooling.inc => nn/pooling.cpp} (98%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b210275a40..78afdc529d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -72,6 +72,7 @@ add_library(onnf_lower_frontend conversion/onnx_to_krnl/math/softmax.cpp conversion/onnx_to_krnl/nn/conv.cpp conversion/onnx_to_krnl/nn/normalization.cpp + conversion/onnx_to_krnl/nn/pooling.cpp conversion/onnx_to_krnl/tensor/identity.cpp conversion/onnx_to_krnl/tensor/reshape.cpp conversion/onnx_to_krnl/tensor/transpose.cpp diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc b/src/conversion/onnx_to_krnl/nn/pooling.cpp similarity index 98% rename from src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc rename to src/conversion/onnx_to_krnl/nn/pooling.cpp index 9bc3fda9e8..fd39a3a509 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/pooling.inc +++ b/src/conversion/onnx_to_krnl/nn/pooling.cpp @@ -1,4 +1,4 @@ -//===----- pool.inc - Lowering Pooling Ops --------------------------------===// +//===----- 
pool.cpp - Lowering Pooling Ops --------------------------------===// // // Copyright 2019 The IBM Research Authors. // @@ -8,6 +8,10 @@ // //===----------------------------------------------------------------------===// +#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp" + +using namespace mlir; + // Identity values template <> float getIdentityValue(){ diff --git a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp index bd22d95019..df4b65d65e 100644 --- a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp +++ b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp @@ -202,6 +202,9 @@ void populateLoweringONNXConvOpPattern( void populateLoweringONNXNormalizationOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx); +void populateLoweringONNXPoolingOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx); + // `tensor` directory methods: void populateLoweringONNXUnsqueezeOpPattern( From e32fe13669e6e715b5845dc52e1e0008f69af587 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Wed, 26 Feb 2020 11:14:18 +0900 Subject: [PATCH 4/7] Add MLIR tests --- src/conversion/onnx_to_krnl/nn/pooling.cpp | 2 +- test/mlir/onnx/onnx_lowering.mlir | 76 ++++++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/src/conversion/onnx_to_krnl/nn/pooling.cpp b/src/conversion/onnx_to_krnl/nn/pooling.cpp index fd39a3a509..39728b76b9 100644 --- a/src/conversion/onnx_to_krnl/nn/pooling.cpp +++ b/src/conversion/onnx_to_krnl/nn/pooling.cpp @@ -1,4 +1,4 @@ -//===----- pool.cpp - Lowering Pooling Ops --------------------------------===// +//===----- pooling.cpp - Lowering Pooling Ops -----------------------------===// // // Copyright 2019 The IBM Research Authors. 
// diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index c35536d12f..2c4713a195 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1344,3 +1344,79 @@ func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %a // CHECK: return [[RES]] : memref<10xf32> } +func @test_maxpooling_singleout_no_pad(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { + %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_maxpooling_singleout_no_pad + // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32> + // CHECK: [[DEF_LOOPS_0:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS_0:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3) { + // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg3 = 0 to 31, [[DEF_LOOPS_1]]#1 -> %arg4 = 0 to 31) { + // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 + // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: [[DEF_LOOPS_2:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS_2:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_2]]#0, [[DEF_LOOPS_2]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_2]]#0, [[OPT_LOOPS_2]]#1) with ([[DEF_LOOPS_2]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_2]]#1 -> %arg6 = 0 
to 2) { + // CHECK: [[H:%.+]] = addi %arg3, %arg5 : index + // CHECK: [[W:%.+]] = addi %arg4, %arg6 : index + // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> + // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: } + // CHECK: } + // CHECK: } + // CHECK: return [[RES]] : memref<1x3x31x31xf32> +} + +func @test_maxpooling_singleout_no_pad_w_strides(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { + %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2], strides = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides + // CHECK: [[RES:%.+]] = alloc() : memref<1x3x16x16xf32> + // CHECK: [[DEF_LOOPS_0:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS_0:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3) { + // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg3 = 0 to 16, [[DEF_LOOPS_1]]#1 -> %arg4 = 0 to 16) { + // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 + // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> + // CHECK: [[DEF_LOOPS_2:%.+]]:2 = 
krnl.define_loops 2 + // CHECK: [[OPT_LOOPS_2:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_2]]#0, [[DEF_LOOPS_2]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_2]]#0, [[OPT_LOOPS_2]]#1) with ([[DEF_LOOPS_2]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_2]]#1 -> %arg6 = 0 to 2) { + // CHECK: [[STRIDE_0:%.+]] = constant 2 : index + // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index + // CHECK: [[H:%.+]] = addi [[MUL_0]], %arg5 : index + // CHECK: [[STRIDE_1:%.+]] = constant 2 : index + // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index + // CHECK: [[W:%.+]] = addi [[MUL_1]], %arg6 : index + // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> + // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> + // CHECK: [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> + // CHECK: } + // CHECK: } + // CHECK: } + // CHECK: return [[RES]] : memref<1x3x16x16xf32> +} + From 88477bbab848be7c0d690e26f92a19a3095bfa52 Mon Sep 17 00:00:00 2001 From: "Tung D. 
Le" Date: Wed, 26 Feb 2020 13:50:06 +0900 Subject: [PATCH 5/7] Support ceil_mode --- src/conversion/onnx_to_krnl/nn/pooling.cpp | 34 ++++++++++++---- test/backend/test.py | 1 + test/mlir/onnx/onnx_lowering.mlir | 46 ++++++++++++++++++++++ 3 files changed, 73 insertions(+), 8 deletions(-) diff --git a/src/conversion/onnx_to_krnl/nn/pooling.cpp b/src/conversion/onnx_to_krnl/nn/pooling.cpp index 39728b76b9..3090a6d310 100644 --- a/src/conversion/onnx_to_krnl/nn/pooling.cpp +++ b/src/conversion/onnx_to_krnl/nn/pooling.cpp @@ -64,6 +64,9 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { for (auto stride : stridesAttribute.getValue()) strides.emplace_back(stride.cast().getInt()); + // Read ceil_mode attribute + auto ceilMode = poolOp.ceil_mode().getSExtValue(); + // Insert an allocation and deallocation for the result of this operation. auto memRefType = convertToMemRefType(*op->result_type_begin()); Value alloc; @@ -109,8 +112,8 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { // // 1. Define outer loops and emit empty optimization block: - int64_t nOuterLoops = 2; - BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops); + int batchRank = 2; // N and C dimensions + BuildKrnlLoop outerLoops(rewriter, loc, batchRank); outerLoops.createDefineAndOptimizeOp(); int nIndex = outerLoops.pushBounds(0, inputOperand, 0); int cIndex = outerLoops.pushBounds(0, inputOperand, 1); @@ -121,10 +124,10 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { // 2. Emit the body of the outer loop nest. // 2.1 Define spatial loops - int64_t nSpatialLoops = resultShape.size() - 2; + int nSpatialLoops = resultShape.size() - batchRank; BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops); spatialLoops.createDefineAndOptimizeOp(); - for (int i = 2; i < resultShape.size(); ++i) + for (int i = batchRank; i < resultShape.size(); ++i) spatialLoops.pushBounds(0, alloc, i); // 2.2 Emit loop nest over output spatial dimensions. 
@@ -157,7 +160,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { rewriter.create(loc, identity, alloc, resultIndices); // 3.2 Define inner loops. - int64_t nInnerLoops = kernelShape.size(); + int nInnerLoops = kernelShape.size(); BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops); innerLoops.createDefineAndOptimizeOp(); // for Kx = 0 .. KX @@ -183,12 +186,27 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { for (int i = 0; i < kernelShape.size(); ++i) { Value spatialIndex = spatialLoops.getInductionVar(i); // If strides are present then emit the correct access index. - if (stridesAttribute && strides[i] > 1) + if (stridesAttribute && strides[i] > 1) { spatialIndex = rewriter.create(loc, rewriter.create(loc, strides[i]), spatialLoops.getInductionVar(i)); - dataIndices.emplace_back(rewriter.create( - loc, spatialIndex, innerLoops.getInductionVar(i))); + } + spatialIndex = rewriter.create( + loc, spatialIndex, innerLoops.getInductionVar(i)); + // If ceil mode is enabled, then the calculated access index may + // exceed its dimension. In such a case, we will use the maximum + // index, which causes multiple visits to the element of the + // maximum index. + // TODO: Avoid multiple visits. + if (ceilMode) { + auto inputIndex = rewriter.create( + loc, inputShape[batchRank + i] - 1); + auto greaterCondition = rewriter.create( + loc, CmpIPredicate::sgt, spatialIndex, inputIndex); + spatialIndex = rewriter.create( + loc, greaterCondition, inputIndex, spatialIndex); + } + dataIndices.emplace_back(spatialIndex); } // 4.2 Do pooling. 
diff --git a/test/backend/test.py b/test/backend/test.py index 8ca67bb17c..818332f426 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -307,6 +307,7 @@ def supports_device(cls, device): # Pooling "test_maxpool_1d_default_cpu", + "test_maxpool_2d_ceil_cpu", "test_maxpool_2d_default_cpu", "test_maxpool_2d_strides_cpu", "test_maxpool_3d_default_cpu", diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 2c4713a195..c8476328a6 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1420,3 +1420,49 @@ func @test_maxpooling_singleout_no_pad_w_strides(%arg0 : tensor<1x3x32x32xf32>) // CHECK: return [[RES]] : memref<1x3x16x16xf32> } +func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { + %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode + // CHECK: [[RES:%.+]] = alloc() : memref<1x3x16x16xf32> + // CHECK: [[DEF_LOOPS_0:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS_0:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3) { + // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg3 = 0 to 16, [[DEF_LOOPS_1]]#1 -> %arg4 = 0 to 16) { + // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 + // CHECK: store 
[[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> + // CHECK: [[DEF_LOOPS_2:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS_2:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_2]]#0, [[DEF_LOOPS_2]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_2]]#0, [[OPT_LOOPS_2]]#1) with ([[DEF_LOOPS_2]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_2]]#1 -> %arg6 = 0 to 3) { + // CHECK: [[STRIDE_0:%.+]] = constant 2 : index + // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index + // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index + // CHECK: [[INPUT_INDEX_0:%.+]] = constant 31 : index + // CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index + // CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index + // CHECK: [[STRIDE_1:%.+]] = constant 2 : index + // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index + // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index + // CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index + // CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index + // CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index + // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> + // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> + // CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> + // CHECK: } + // CHECK: } + // CHECK: } + // CHECK: return [[RES]] : memref<1x3x16x16xf32> +} + From fc793134c57993fece0cde7fc9b4625eaabef613 Mon Sep 17 00:00:00 2001 From: "Tung D. 
Le" Date: Thu, 27 Feb 2020 10:41:45 +0900 Subject: [PATCH 6/7] Merge the first two krnl loops into one krnl loop; remove attribute checks --- src/conversion/onnx_to_krnl/nn/pooling.cpp | 207 +++++++++------------ test/mlir/onnx/onnx_lowering.mlir | 136 ++++++-------- 2 files changed, 147 insertions(+), 196 deletions(-) diff --git a/src/conversion/onnx_to_krnl/nn/pooling.cpp b/src/conversion/onnx_to_krnl/nn/pooling.cpp index 3090a6d310..fa56bd3144 100644 --- a/src/conversion/onnx_to_krnl/nn/pooling.cpp +++ b/src/conversion/onnx_to_krnl/nn/pooling.cpp @@ -14,12 +14,12 @@ using namespace mlir; // Identity values template <> -float getIdentityValue(){ +float getIdentityValue() { return (float)-std::numeric_limits::infinity(); } template <> -int getIdentityValue(){ +int getIdentityValue() { return std::numeric_limits::min(); } @@ -50,19 +50,14 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { // Read kernel_shape attribute SmallVector kernelShape; auto kernelShapeAttribute = poolOp.kernel_shapeAttr(); - if (kernelShapeAttribute) - for (auto dim : kernelShapeAttribute.getValue()) - kernelShape.emplace_back(dim.cast().getInt()); - else - emitError(loc, "kernel_shape is a mandatory attribute for which there is " - "no default."); + for (auto dim : kernelShapeAttribute.getValue()) + kernelShape.emplace_back(dim.cast().getInt()); // Read strides attribute SmallVector strides; auto stridesAttribute = poolOp.stridesAttr(); - if (stridesAttribute) - for (auto stride : stridesAttribute.getValue()) - strides.emplace_back(stride.cast().getInt()); + for (auto stride : stridesAttribute.getValue()) + strides.emplace_back(stride.cast().getInt()); // Read ceil_mode attribute auto ceilMode = poolOp.ceil_mode().getSExtValue(); @@ -98,127 +93,101 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { // for r1 = 0 .. RH: // for r2 = 0 .. RW: // R[n][c][r1][r2] = negative_infinity; - // for k1 = 0 .. KH: - // for k2 = 0 .. 
KW: - // t = D[n][c][s1 * r1 + k1][s2 * r2 + k2]; - // R[n][c][r1][r2] = max(R[n][c][r1][r2], t); + // for k1 = 0 .. KH: + // for k2 = 0 .. KW: + // t = D[n][c][s1 * r1 + k1][s2 * r2 + k2]; + // R[n][c][r1][r2] = max(R[n][c][r1][r2], t); // // Naming: - // n, c: outer loop nest indices - // r1, r2: spatial loop nest indices + // n, c, r1, r2: outer loop nest indices // k1, k2: inner loop nest indices // // TODO: handle padding. // - // 1. Define outer loops and emit empty optimization block: - int batchRank = 2; // N and C dimensions - BuildKrnlLoop outerLoops(rewriter, loc, batchRank); - outerLoops.createDefineAndOptimizeOp(); - int nIndex = outerLoops.pushBounds(0, inputOperand, 0); - int cIndex = outerLoops.pushBounds(0, inputOperand, 1); - // Outer loop iteration - outerLoops.createIterateOp(); + // 1. Define outer loops and emit empty optimization block. + auto nOuterLoops = resultShape.size(); + BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops); + outerLoops.createDefineOptimizeAndIterateOp(alloc); + rewriter.setInsertionPointToStart(outerLoops.getIterateBlock()); { // 2. Emit the body of the outer loop nest. - - // 2.1 Define spatial loops - int nSpatialLoops = resultShape.size() - batchRank; - BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops); - spatialLoops.createDefineAndOptimizeOp(); - for (int i = batchRank; i < resultShape.size(); ++i) - spatialLoops.pushBounds(0, alloc, i); - - // 2.2 Emit loop nest over output spatial dimensions. 
- spatialLoops.createIterateOp(); - rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock()); + SmallVector resultIndices; + for (int i = 0; i < nOuterLoops; ++i) + resultIndices.emplace_back(outerLoops.getInductionVar(i)); + + // 2.1 Emit: R[n][c][r1][r2] = negative_infinity; + Value identity; + if (resultElementType.isa()) { + identity = rewriter.create( + loc, FloatAttr::get(resultElementType, + getIdentityValue())); + } else if (resultElementType.isa()) { + identity = rewriter.create( + loc, IntegerAttr::get(resultElementType, + getIdentityValue())); + } else { + emitError(loc, "unsupported element type"); + } + rewriter.create(loc, identity, alloc, resultIndices); + + // 2.2 Define inner loops. + int nInnerLoops = kernelShape.size(); + BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops); + innerLoops.createDefineAndOptimizeOp(); + // for Kx = 0 .. KX + for (int i = 0; i < nInnerLoops; ++i) + innerLoops.pushBounds(0, kernelShape[i]); + + // 2.3 Emit inner loop nest. + innerLoops.createIterateOp(); + rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); { - // 3. Emit the body of the spatial loop nest. - // 3.1 Emit: R[n][c][r1][r2] = negative_infinity; - SmallVector resultIndices; - // n - resultIndices.emplace_back(outerLoops.getInductionVar(nIndex)); - // c - resultIndices.emplace_back(outerLoops.getInductionVar(cIndex)); - // rX - for (auto arg : spatialLoops.getIterateBlock()->getArguments()) - resultIndices.emplace_back(arg); - // Store initializer value into output location. - Value identity; - if (resultElementType.isa()) { - identity = rewriter.create( - loc, FloatAttr::get(resultElementType, - getIdentityValue())); - } else if (resultElementType.isa()) { - identity = rewriter.create( - loc, IntegerAttr::get(resultElementType, - getIdentityValue())); - } else { - emitError(loc, "unsupported element type"); - } - rewriter.create(loc, identity, alloc, resultIndices); - - // 3.2 Define inner loops. 
- int nInnerLoops = kernelShape.size(); - BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops); - innerLoops.createDefineAndOptimizeOp(); - // for Kx = 0 .. KX - for (int i = 0; i < nInnerLoops; ++i) - innerLoops.pushBounds(0, kernelShape[i]); - - // 3.3 Emit inner loop nest. - innerLoops.createIterateOp(); - rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); - - { - // 4. Emit inner loop body - // t = D[n][c][s1 * r1 + k1][s2 * r2 + k2]; - // R[n][c][r1][r2] = max(R[n][c][r1][r2], t); - - // 4.1 Prepare indices for accesing the data tensor. - SmallVector dataIndices; - // n - dataIndices.emplace_back(outerLoops.getInductionVar(nIndex)); - // c - dataIndices.emplace_back(outerLoops.getInductionVar(cIndex)); - // sX * rX + kX - for (int i = 0; i < kernelShape.size(); ++i) { - Value spatialIndex = spatialLoops.getInductionVar(i); - // If strides are present then emit the correct access index. - if (stridesAttribute && strides[i] > 1) { - spatialIndex = rewriter.create(loc, - rewriter.create(loc, strides[i]), - spatialLoops.getInductionVar(i)); - } - spatialIndex = rewriter.create( - loc, spatialIndex, innerLoops.getInductionVar(i)); - // If ceil mode is enabled, then the calculated access index may - // exceed its dimension. In such a case, we will use the maximum - // index, which causes multiple visits to the element of the - // maximum index. - // TODO: Avoid multiple visits. - if (ceilMode) { - auto inputIndex = rewriter.create( - loc, inputShape[batchRank + i] - 1); - auto greaterCondition = rewriter.create( - loc, CmpIPredicate::sgt, spatialIndex, inputIndex); - spatialIndex = rewriter.create( - loc, greaterCondition, inputIndex, spatialIndex); - } - dataIndices.emplace_back(spatialIndex); + // 3. Emit inner loop body + // t = D[n][c][s1 * r1 + k1][s2 * r2 + k2]; + // R[n][c][r1][r2] = max(R[n][c][r1][r2], t); + + // 3.1 Prepare indices for accesing the data tensor. 
+ SmallVector dataIndices; + // Batch indices: n, c + int batchRank = 2; + for (int i = 0; i < batchRank; ++i) + dataIndices.emplace_back(outerLoops.getInductionVar(i)); + // Spatial indices: sX * rX + kX + for (int i = batchRank; i < nOuterLoops; ++i) { + Value spatialIndex = outerLoops.getInductionVar(i); + // If strides are present then emit the correct access index. + if (stridesAttribute && strides[i - batchRank] > 1) { + spatialIndex = rewriter.create(loc, + rewriter.create(loc, strides[i - batchRank]), + outerLoops.getInductionVar(i)); } - - // 4.2 Do pooling. - auto loadData = - rewriter.create(loc, inputOperand, dataIndices); - auto loadPartialResult = - rewriter.create(loc, alloc, resultIndices); - Value result = mapToLowerScalarOp( - op, resultElementType, {loadPartialResult, loadData}, rewriter); - // 4.3 Store computed value into output location. - rewriter.create(loc, result, alloc, resultIndices); + spatialIndex = rewriter.create( + loc, spatialIndex, innerLoops.getInductionVar(i - batchRank)); + // If ceil mode is enabled, then the calculated access index may + // exceed its dimension. In such a case, we will use the maximum + // index, which causes multiple visits to the element of the + // maximum index. + // TODO: Avoid multiple visits. + if (ceilMode) { + auto inputIndex = + rewriter.create(loc, inputShape[i] - 1); + auto greaterCondition = rewriter.create( + loc, CmpIPredicate::sgt, spatialIndex, inputIndex); + spatialIndex = rewriter.create( + loc, greaterCondition, inputIndex, spatialIndex); + } + dataIndices.emplace_back(spatialIndex); } + + // 3.2 Do pooling. 
+ auto loadData = rewriter.create(loc, inputOperand, dataIndices); + auto loadPartialResult = + rewriter.create(loc, alloc, resultIndices); + Value result = mapToLowerScalarOp( + op, resultElementType, {loadPartialResult, loadData}, rewriter); + rewriter.create(loc, result, alloc, resultIndices); } } rewriter.replaceOp(op, alloc); diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index e009801194..209c8ff9ce 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1350,31 +1350,25 @@ func @test_maxpooling_singleout_no_pad(%arg0 : tensor<1x3x32x32xf32>) -> tensor< // CHECK-LABEL: test_maxpooling_singleout_no_pad // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32> - // CHECK: [[DEF_LOOPS_0:%.+]]:2 = krnl.define_loops 2 - // CHECK: [[OPT_LOOPS_0:%.+]]:2 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1 - // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3) { + // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 31, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 31) { + // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 + // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, 
[[DEF_LOOPS_1]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg3 = 0 to 31, [[DEF_LOOPS_1]]#1 -> %arg4 = 0 to 31) { - // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 - // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> - // CHECK: [[DEF_LOOPS_2:%.+]]:2 = krnl.define_loops 2 - // CHECK: [[OPT_LOOPS_2:%.+]]:2 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_2]]#0, [[DEF_LOOPS_2]]#1 - // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_2]]#0, [[OPT_LOOPS_2]]#1) with ([[DEF_LOOPS_2]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_2]]#1 -> %arg6 = 0 to 2) { - // CHECK: [[H:%.+]] = addi %arg3, %arg5 : index - // CHECK: [[W:%.+]] = addi %arg4, %arg6 : index - // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> - // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> - // CHECK: [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> - // CHECK: } + // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 2) { + // CHECK: [[H:%.+]] = addi %arg3, %arg5 : index + // CHECK: [[W:%.+]] = addi %arg4, %arg6 : index + // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> + // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> // CHECK: } // CHECK: } // CHECK: return 
[[RES]] : memref<1x3x31x31xf32> @@ -1386,35 +1380,29 @@ func @test_maxpooling_singleout_no_pad_w_strides(%arg0 : tensor<1x3x32x32xf32>) // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides // CHECK: [[RES:%.+]] = alloc() : memref<1x3x16x16xf32> - // CHECK: [[DEF_LOOPS_0:%.+]]:2 = krnl.define_loops 2 - // CHECK: [[OPT_LOOPS_0:%.+]]:2 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1 - // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3) { + // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 16, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) { + // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 + // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg3 = 0 to 16, [[DEF_LOOPS_1]]#1 -> %arg4 = 0 to 16) { - // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 - // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: [[DEF_LOOPS_2:%.+]]:2 = krnl.define_loops 2 - // CHECK: [[OPT_LOOPS_2:%.+]]:2 = krnl.optimize_loops { - // CHECK: 
krnl.return_loops [[DEF_LOOPS_2]]#0, [[DEF_LOOPS_2]]#1 - // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_2]]#0, [[OPT_LOOPS_2]]#1) with ([[DEF_LOOPS_2]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_2]]#1 -> %arg6 = 0 to 2) { - // CHECK: [[STRIDE_0:%.+]] = constant 2 : index - // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index - // CHECK: [[H:%.+]] = addi [[MUL_0]], %arg5 : index - // CHECK: [[STRIDE_1:%.+]] = constant 2 : index - // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index - // CHECK: [[W:%.+]] = addi [[MUL_1]], %arg6 : index - // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> - // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: } + // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 2) { + // CHECK: [[STRIDE_0:%.+]] = constant 2 : index + // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index + // CHECK: [[H:%.+]] = addi [[MUL_0]], %arg5 : index + // CHECK: [[STRIDE_1:%.+]] = constant 2 : index + // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index + // CHECK: [[W:%.+]] = addi [[MUL_1]], %arg6 : index + // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> + // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> + // CHECK: [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> // CHECK: } // CHECK: } // CHECK: return [[RES]] : memref<1x3x16x16xf32> @@ 
-1426,41 +1414,35 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode // CHECK: [[RES:%.+]] = alloc() : memref<1x3x16x16xf32> - // CHECK: [[DEF_LOOPS_0:%.+]]:2 = krnl.define_loops 2 - // CHECK: [[OPT_LOOPS_0:%.+]]:2 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1 - // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3) { + // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 16, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) { + // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 + // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg3 = 0 to 16, [[DEF_LOOPS_1]]#1 -> %arg4 = 0 to 16) { - // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 - // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: [[DEF_LOOPS_2:%.+]]:2 = krnl.define_loops 2 - // CHECK: [[OPT_LOOPS_2:%.+]]:2 = krnl.optimize_loops { - // CHECK: krnl.return_loops 
[[DEF_LOOPS_2]]#0, [[DEF_LOOPS_2]]#1 - // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_2]]#0, [[OPT_LOOPS_2]]#1) with ([[DEF_LOOPS_2]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_2]]#1 -> %arg6 = 0 to 3) { - // CHECK: [[STRIDE_0:%.+]] = constant 2 : index - // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index - // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index - // CHECK: [[INPUT_INDEX_0:%.+]] = constant 31 : index - // CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index - // CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index - // CHECK: [[STRIDE_1:%.+]] = constant 2 : index - // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index - // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index - // CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index - // CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index - // CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index - // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> - // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: } + // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) { + // CHECK: [[STRIDE_0:%.+]] = constant 2 : index + // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index + // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index + // CHECK: [[INPUT_INDEX_0:%.+]] = constant 31 : index + // CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index + // CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index + // CHECK: 
[[STRIDE_1:%.+]] = constant 2 : index + // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index + // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index + // CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index + // CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index + // CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index + // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> + // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> + // CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> // CHECK: } // CHECK: } // CHECK: return [[RES]] : memref<1x3x16x16xf32> From ffd8d1fff78ca3e36aa0f6ae5b489f7f41821506 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Thu, 27 Feb 2020 14:53:10 +0900 Subject: [PATCH 7/7] Dynamically allocate memory for the result if the result has unknown dimensions --- src/conversion/onnx_to_krnl/nn/pooling.cpp | 114 +++++++++++++++++++-- test/mlir/onnx/onnx_lowering.mlir | 62 +++++++++++ 2 files changed, 165 insertions(+), 11 deletions(-) diff --git a/src/conversion/onnx_to_krnl/nn/pooling.cpp b/src/conversion/onnx_to_krnl/nn/pooling.cpp index fa56bd3144..0015b8ad9b 100644 --- a/src/conversion/onnx_to_krnl/nn/pooling.cpp +++ b/src/conversion/onnx_to_krnl/nn/pooling.cpp @@ -62,21 +62,107 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { // Read ceil_mode attribute auto ceilMode = poolOp.ceil_mode().getSExtValue(); - // Insert an allocation and deallocation for the result of this operation. 
+ // Read pads attribute + SmallVector pads; + auto padsAttribute = poolOp.padsAttr(); + for (auto pad : padsAttribute.getValue()) + pads.emplace_back(pad.cast().getInt()); + + // Read dilations attribute + SmallVector dilations; + auto dilationsAttribute = poolOp.dilationsAttr(); + for (auto dilation : dilationsAttribute.getValue()) + dilations.emplace_back(dilation.cast().getInt()); + + // Type information about the input and result of this operation. + auto &inputOperand = operands[0]; + auto inputShape = inputOperand.getType().cast().getShape(); auto memRefType = convertToMemRefType(*op->result_type_begin()); + auto resultShape = memRefType.getShape(); + auto resultElementType = memRefType.getElementType(); + + // Batch indices: N and C dimensions + int batchRank = 2; + + // Insert an allocation and deallocation for the result of this operation. Value alloc; bool insertDealloc = checkInsertDealloc(op); if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); - else - alloc = insertAllocAndDealloc( - memRefType, loc, rewriter, insertDealloc, {operands[0]}); + else { + // Compute dimensions of the result of this operation. 
+ SmallVector allocOperands; + for (int i = 0; i < batchRank; ++i) { + if (resultShape[i] < 0) { + auto dim = rewriter.create(loc, inputOperand, i); + allocOperands.emplace_back(dim); + } + } - auto resultShape = memRefType.getShape(); - auto resultElementType = memRefType.getElementType(); - auto &inputOperand = operands[0]; - auto inputShape = inputOperand.getType().cast().getShape(); + Value zero, one; + if (ceilMode) { + zero = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + } + one = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + + int spatialRank = resultShape.size() - batchRank; + for (int i = batchRank; i < resultShape.size(); ++i) { + if (resultShape[i] < 0) { + // dim = + // let numerator = (input + pad - (kernel - 1) * dilation + 1) + // in let denominator = stride + // in + // if (ceilMode) + // ceil(numerator / denominator) + 1 + // else + // floor(numerator / denominator) + 1 + int spatialIndex = i - batchRank; + + // numerator = (input + pad - (kernel - 1) * dilation + 1) + auto inputDim = rewriter.create(loc, inputOperand, i); + auto inputVal = rewriter.create( + loc, inputDim, rewriter.getIntegerType(64)); + int64_t padKernelDilation = + (pads[spatialIndex] + pads[spatialIndex + spatialRank]) - + (kernelShape[spatialIndex] - 1) * dilations[spatialIndex] + 1; + auto padKernelDilationVal = rewriter.create( + loc, rewriter.getIntegerAttr( + rewriter.getIntegerType(64), padKernelDilation)); + auto numeratorVal = + rewriter.create(loc, inputVal, padKernelDilationVal); + // denominator + auto denominatorVal = rewriter.create( + loc, rewriter.getIntegerAttr( + rewriter.getIntegerType(64), strides[spatialIndex])); + + // numerator / denominator + Value dimVal = + rewriter.create(loc, numeratorVal, denominatorVal); + + if (ceilMode) { + auto remainder = rewriter.create( + loc, numeratorVal, denominatorVal); + auto isZero = rewriter.create( + loc, CmpIPredicate::eq, remainder, zero); +
auto dimPlusOne = rewriter.create(loc, dimVal, one); + dimVal = rewriter.create(loc, isZero, dimVal, dimPlusOne); + } + + dimVal = rewriter.create(loc, dimVal, one); + allocOperands.emplace_back(rewriter.create( + loc, dimVal, rewriter.getIndexType())); + } + } + alloc = rewriter.create(loc, memRefType, allocOperands); + if (insertDealloc) { + auto *parentBlock = alloc.getDefiningOp()->getBlock(); + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + } // R = MaxPool(D) // @@ -151,7 +237,6 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { // 3.1 Prepare indices for accesing the data tensor. SmallVector dataIndices; // Batch indices: n, c - int batchRank = 2; for (int i = 0; i < batchRank; ++i) dataIndices.emplace_back(outerLoops.getInductionVar(i)); // Spatial indices: sX * rX + kX @@ -171,8 +256,15 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { // maximum index. // TODO: Avoid multiple visits. 
if (ceilMode) { - auto inputIndex = - rewriter.create(loc, inputShape[i] - 1); + Value inputIndex; + if (inputShape[i] < 0) { + Value inputDim = rewriter.create(loc, inputOperand, i); + Value one = rewriter.create(loc, 1); + inputIndex = rewriter.create(loc, inputDim, one); + } else { + inputIndex = + rewriter.create(loc, inputShape[i] - 1); + } auto greaterCondition = rewriter.create( loc, CmpIPredicate::sgt, spatialIndex, inputIndex); spatialIndex = rewriter.create( diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 209c8ff9ce..149d72465b 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1448,3 +1448,65 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x // CHECK: return [[RES]] : memref<1x3x16x16xf32> } +func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims + + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[ZERO:%.+]] = constant 0 : i64 + // CHECK: [[ONE:%.+]] = constant 1 : i64 + // CHECK: [[DIM_1:%.+]] = dim %arg0, 2 : memref + // CHECK: [[DIM_1_i64:%.+]] = index_cast [[DIM_1]] : index to i64 + // CHECK: [[KERNEL_PAD_DILATION:%.+]] = constant -1 : i64 + // CHECK: [[NUMERATOR:%.+]] = addi [[DIM_1_i64]], [[KERNEL_PAD_DILATION]] : i64 + // CHECK: [[DENOMINATOR:%.+]] = constant 2 : i64 + // CHECK: [[DIV:%.+]] = divi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64 + // CHECK: [[REMAINDER:%.+]] = remi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64 + // CHECK: [[IS_ZERO:%.+]] = cmpi "eq", [[REMAINDER]], [[ZERO]] : i64 + // CHECK: [[DIV_PLUS_ONE:%.+]] = addi [[DIV]], [[ONE]] : i64 + // CHECK: [[SELECT:%.+]] = select 
[[IS_ZERO]], [[DIV]], [[DIV_PLUS_ONE]] : i64 + // CHECK: [[SELECT_PLUS_ONE:%.+]] = addi [[SELECT]], [[ONE]] : i64 + // CHECK: [[DIM_1_FINAL:%.+]] = index_cast [[SELECT_PLUS_ONE]] : i64 to index + // CHECK: [[RES:%.+]] = alloc([[DIM_0]], [[DIM_1_FINAL]]) : memref + + // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref + // CHECK: [[DIM_3:%.+]] = dim [[RES]], 2 : memref + // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to [[DIM_3]], [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) { + // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 + // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref + // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) { + // CHECK: [[STRIDE_0:%.+]] = constant 2 : index + // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index + // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index + // CHECK: [[DIM_0_0:%.+]] = dim %arg0, 2 : memref + // CHECK: [[ONE_INDEX:%.+]] = constant 1 : index + // CHECK: [[INPUT_INDEX_0:%.+]] = subi [[DIM_0_0]], [[ONE_INDEX]] : index + // CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index + // CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index + + // CHECK: 
[[STRIDE_1:%.+]] = constant 2 : index + // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index + // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index + // CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index + // CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index + // CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index + + // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref + // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref + // CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref + // CHECK: } + // CHECK: } + // CHECK: return [[RES]] : memref +}