Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lower MaxPoolSingleOutOp to Krnl dialect #1

Merged
merged 11 commits into from
Mar 4, 2020
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ add_library(onnf_lower_frontend
conversion/onnx_to_krnl/math/softmax.cpp
conversion/onnx_to_krnl/nn/conv.cpp
conversion/onnx_to_krnl/nn/normalization.cpp
conversion/onnx_to_krnl/nn/pooling.cpp
conversion/onnx_to_krnl/tensor/identity.cpp
conversion/onnx_to_krnl/tensor/reshape.cpp
conversion/onnx_to_krnl/tensor/transpose.cpp
Expand Down
1 change: 1 addition & 0 deletions src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
// Neural network
populateLoweringONNXConvOpPattern(patterns, &getContext());
populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
populateLoweringONNXPoolingOpPattern(patterns, &getContext());
// Entry point
patterns.insert<ONNXEntryPointLowering>(&getContext());

Expand Down
294 changes: 294 additions & 0 deletions src/conversion/onnx_to_krnl/nn/pooling.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
//===----- pooling.cpp - Lowering Pooling Ops -----------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Pooling Operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

// Identity values
// Identity value of the max reduction over floats: negative infinity,
// since every finite float compares greater than it. Used to seed the
// pooling output before the kernel-window loop runs.
template <>
float getIdentityValue<float, ONNXMaxPoolSingleOutOp>() {
  return -std::numeric_limits<float>::infinity();
}

// Identity value of the max reduction over ints: the smallest
// representable int (lowest() == min() for integral types). Used to seed
// the pooling output before the kernel-window loop runs.
template <>
int getIdentityValue<int, ONNXMaxPoolSingleOutOp>() {
  return std::numeric_limits<int>::lowest();
}

// Scalar combiner for max pooling: returns max(operands[0], operands[1]).
//
// The identity-value machinery above supports both float and integer
// element types, but the original implementation emitted CmpFOp
// unconditionally, which produces invalid IR for integer operands.
// Dispatch on the operand type so both element kinds lower correctly.
template <>
Value mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  Value isGreater;
  if (lhs.getType().isa<IntegerType>())
    // Signed comparison: pooled integer tensors are treated as signed here,
    // consistent with getIdentityValue<int> using numeric_limits<int>::min().
    isGreater = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
  else
    isGreater = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
  return rewriter.create<SelectOp>(loc, isGreater, lhs, rhs);
}

//===----------------------------------------------------------------------===//
// Lowers ONNXMaxPoolSingleOutOp to an explicit Krnl loop nest:
// an outer nest over the output tensor that seeds each element with the
// max-identity, and an inner nest over the kernel window that folds the
// max of the corresponding input elements into it.
//===----------------------------------------------------------------------===//
struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
  ONNXMaxPoolSingleOutOpLowering(MLIRContext *ctx)
      : ConversionPattern(
            mlir::ONNXMaxPoolSingleOutOp::getOperationName(), 1, ctx) {}

  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();

    // Match. Guard the cast: the original dereferenced the dyn_cast result
    // unchecked, which would crash on a mismatched op.
    ONNXMaxPoolSingleOutOp poolOp = llvm::dyn_cast<ONNXMaxPoolSingleOutOp>(op);
    if (!poolOp)
      return matchFailure();

    // Read kernel_shape attribute (one entry per spatial dimension).
    SmallVector<int, 4> kernelShape;
    auto kernelShapeAttribute = poolOp.kernel_shapeAttr();
    for (auto dim : kernelShapeAttribute.getValue())
      kernelShape.emplace_back(dim.cast<IntegerAttr>().getInt());

    // Read strides attribute.
    SmallVector<int, 4> strides;
    auto stridesAttribute = poolOp.stridesAttr();
    for (auto stride : stridesAttribute.getValue())
      strides.emplace_back(stride.cast<IntegerAttr>().getInt());

    // Read ceil_mode attribute (controls rounding of dynamic output dims
    // and clamping of window accesses past the input edge).
    auto ceilMode = poolOp.ceil_mode().getSExtValue();

    // Read pads attribute (begin pads followed by end pads, per spatial dim).
    SmallVector<int, 4> pads;
    auto padsAttribute = poolOp.padsAttr();
    for (auto pad : padsAttribute.getValue())
      pads.emplace_back(pad.cast<IntegerAttr>().getInt());

    // Read dilations attribute.
    SmallVector<int, 4> dilations;
    auto dilationsAttribute = poolOp.dilationsAttr();
    for (auto dilation : dilationsAttribute.getValue())
      dilations.emplace_back(dilation.cast<IntegerAttr>().getInt());

    // Type information about the input and result of this operation.
    auto &inputOperand = operands[0];
    auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
    auto memRefType = convertToMemRefType(*op->result_type_begin());
    auto resultShape = memRefType.getShape();
    auto resultElementType = memRefType.getElementType();

    // Batch indices: N and C dimensions precede the spatial dimensions.
    int batchRank = 2;

    // Insert an allocation and deallocation for the result of this operation.
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else {
      // Compute dimensions of the result of this operation at runtime.
      SmallVector<Value, 2> allocOperands;
      // Dynamic batch dims mirror the input dims directly.
      for (int i = 0; i < batchRank; ++i) {
        if (resultShape[i] < 0) {
          auto dim = rewriter.create<DimOp>(loc, inputOperand, i);
          allocOperands.emplace_back(dim);
        }
      }

      Value zero, one;
      if (ceilMode) {
        zero = rewriter.create<ConstantOp>(
            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      }
      one = rewriter.create<ConstantOp>(
          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));

      int spatialRank = static_cast<int>(resultShape.size()) - batchRank;
      for (int i = batchRank; i < static_cast<int>(resultShape.size()); ++i) {
        if (resultShape[i] < 0) {
          // dim =
          //   let numerator = (input + pad - (kernel - 1) * dilation + 1)
          //   in let denominator = stride
          //   in
          //     if (ceilMode)
          //       ceil(numerator / denominator) + 1
          //     else
          //       floor(numerator / denominator) + 1
          int spatialIndex = i - batchRank;

          // numerator = (input + pad - (kernel - 1) * dilation + 1)
          auto inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
          auto inputVal = rewriter.create<IndexCastOp>(
              loc, inputDim, rewriter.getIntegerType(64));
          int64_t padKernelDilation =
              (pads[spatialIndex] + pads[spatialIndex + spatialRank]) -
              (kernelShape[spatialIndex] - 1) * dilations[spatialIndex] + 1;
          auto padKernelDilationVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(
                       rewriter.getIntegerType(64), padKernelDilation));
          auto numeratorVal =
              rewriter.create<AddIOp>(loc, inputVal, padKernelDilationVal);
          // denominator
          auto denominatorVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(
                       rewriter.getIntegerType(64), strides[spatialIndex]));

          // numerator / denominator (floor division for non-negative dims).
          Value dimVal =
              rewriter.create<SignedDivIOp>(loc, numeratorVal, denominatorVal);

          if (ceilMode) {
            // ceil(a/b) = floor(a/b) + (a % b != 0 ? 1 : 0)
            auto remainder = rewriter.create<SignedRemIOp>(
                loc, numeratorVal, denominatorVal);
            auto isZero = rewriter.create<CmpIOp>(
                loc, CmpIPredicate::eq, remainder, zero);
            auto dimPlusOne = rewriter.create<AddIOp>(loc, dimVal, one);
            dimVal = rewriter.create<SelectOp>(loc, isZero, dimVal, dimPlusOne);
          }

          dimVal = rewriter.create<AddIOp>(loc, dimVal, one);
          allocOperands.emplace_back(rewriter.create<IndexCastOp>(
              loc, dimVal, rewriter.getIndexType()));
        }
      }
      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
      if (insertDealloc) {
        // Free the result buffer just before the enclosing block terminates.
        auto *parentBlock = alloc.getDefiningOp()->getBlock();
        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }
    }

    // R = MaxPool(D)
    //
    // The input/output shapes will look like this:
    //
    // D (NxCxHxW) -> R (NxCxRHxRW)
    //
    // The loop nest will look as follows:
    //
    // strides = [s1, s2]
    //
    // for n = 0 .. N:
    //   for c = 0 .. C:
    //     for r1 = 0 .. RH:
    //       for r2 = 0 .. RW:
    //         R[n][c][r1][r2] = negative_infinity;
    //         for k1 = 0 .. KH:
    //           for k2 = 0 .. KW:
    //             t = D[n][c][s1 * r1 + k1][s2 * r2 + k2];
    //             R[n][c][r1][r2] = max(R[n][c][r1][r2], t);
    //
    // Naming:
    //   n, c, r1, r2: outer loop nest indices
    //   k1, k2: inner loop nest indices
    //
    // TODO: handle padding.
    //

    // 1. Define outer loops and emit empty optimization block.
    int nOuterLoops = static_cast<int>(resultShape.size());
    BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops);
    outerLoops.createDefineOptimizeAndIterateOp(alloc);

    rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
    {
      // 2. Emit the body of the outer loop nest.
      SmallVector<Value, 4> resultIndices;
      for (int i = 0; i < nOuterLoops; ++i)
        resultIndices.emplace_back(outerLoops.getInductionVar(i));

      // 2.1 Emit: R[n][c][r1][r2] = negative_infinity;
      Value identity;
      if (resultElementType.isa<FloatType>()) {
        identity = rewriter.create<ConstantOp>(
            loc, FloatAttr::get(resultElementType,
                     getIdentityValue<float, ONNXMaxPoolSingleOutOp>()));
      } else if (resultElementType.isa<IntegerType>()) {
        identity = rewriter.create<ConstantOp>(
            loc, IntegerAttr::get(resultElementType,
                     getIdentityValue<int, ONNXMaxPoolSingleOutOp>()));
      } else {
        // Fail the pattern instead of falling through with a null `identity`,
        // which would crash when stored below.
        emitError(loc, "unsupported element type");
        return matchFailure();
      }
      rewriter.create<StoreOp>(loc, identity, alloc, resultIndices);

      // 2.2 Define inner loops over the kernel window.
      int nInnerLoops = static_cast<int>(kernelShape.size());
      BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops);
      innerLoops.createDefineAndOptimizeOp();
      // for Kx = 0 .. KX
      for (int i = 0; i < nInnerLoops; ++i)
        innerLoops.pushBounds(0, kernelShape[i]);

      // 2.3 Emit inner loop nest.
      innerLoops.createIterateOp();
      rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
      {
        // 3. Emit inner loop body:
        //    t = D[n][c][s1 * r1 + k1][s2 * r2 + k2];
        //    R[n][c][r1][r2] = max(R[n][c][r1][r2], t);

        // 3.1 Prepare indices for accessing the data tensor.
        SmallVector<Value, 4> dataIndices;
        // Batch indices: n, c
        for (int i = 0; i < batchRank; ++i)
          dataIndices.emplace_back(outerLoops.getInductionVar(i));
        // Spatial indices: sX * rX + kX
        for (int i = batchRank; i < nOuterLoops; ++i) {
          Value spatialIndex = outerLoops.getInductionVar(i);
          // If strides are present then emit the correct access index.
          if (stridesAttribute && strides[i - batchRank] > 1) {
            spatialIndex = rewriter.create<MulIOp>(loc,
                rewriter.create<ConstantIndexOp>(loc, strides[i - batchRank]),
                outerLoops.getInductionVar(i));
          }
          spatialIndex = rewriter.create<AddIOp>(
              loc, spatialIndex, innerLoops.getInductionVar(i - batchRank));
          // If ceil mode is enabled, then the calculated access index may
          // exceed its dimension. In such a case, we will use the maximum
          // index, which causes multiple visits to the element of the
          // maximum index.
          // TODO: Avoid multiple visits.
          if (ceilMode) {
            Value inputIndex;
            if (inputShape[i] < 0) {
              // Dynamic input dim: clamp index to (dim - 1) at runtime.
              Value inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
              Value one = rewriter.create<ConstantIndexOp>(loc, 1);
              inputIndex = rewriter.create<SubIOp>(loc, inputDim, one);
            } else {
              inputIndex =
                  rewriter.create<ConstantIndexOp>(loc, inputShape[i] - 1);
            }
            auto greaterCondition = rewriter.create<CmpIOp>(
                loc, CmpIPredicate::sgt, spatialIndex, inputIndex);
            spatialIndex = rewriter.create<SelectOp>(
                loc, greaterCondition, inputIndex, spatialIndex);
          }
          dataIndices.emplace_back(spatialIndex);
        }

        // 3.2 Do pooling: fold max of the loaded element into the result.
        auto loadData = rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
        auto loadPartialResult =
            rewriter.create<LoadOp>(loc, alloc, resultIndices);
        Value result = mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>(
            op, resultElementType, {loadPartialResult, loadData}, rewriter);
        rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
      }
    }
    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

/// Registers the pooling-related ONNX-to-Krnl lowering patterns.
/// Currently covers ONNXMaxPoolSingleOutOp only.
void populateLoweringONNXPoolingOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXMaxPoolSingleOutOpLowering>(ctx);
}
3 changes: 3 additions & 0 deletions src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ void populateLoweringONNXConvOpPattern(
void populateLoweringONNXNormalizationOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXPoolingOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);

// `tensor` directory methods:

void populateLoweringONNXUnsqueezeOpPattern(
Expand Down
7 changes: 7 additions & 0 deletions test/backend/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,13 @@ def supports_device(cls, device):
"test_batchnorm_epsilon_cpu",
"test_batchnorm_example_cpu",

# Pooling
"test_maxpool_1d_default_cpu",
"test_maxpool_2d_ceil_cpu",
"test_maxpool_2d_default_cpu",
"test_maxpool_2d_strides_cpu",
"test_maxpool_3d_default_cpu",

]

# Extract name of all test cases.
Expand Down
Loading