Skip to content

Commit

Permalink
Rewrite ScatterOp to ScatterElementsOp (#1337)
Browse files Browse the repository at this point in the history
Rewrite the deprecated ONNX Scatter operation (https://github.com/onnx/onnx/blob/main/docs/Operators.md#Scatter) using the equivalent ScatterElements operation.
  • Loading branch information
Ettore Tiotto authored Apr 13, 2022
1 parent e393319 commit fa66317
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 88 deletions.
1 change: 0 additions & 1 deletion src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "src/Pass/Passes.hpp"

using namespace mlir;
using namespace onnx_mlir;

namespace onnx_mlir {

Expand Down
21 changes: 10 additions & 11 deletions src/Dialect/ONNX/ONNXOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4627,10 +4627,9 @@ mlir::Operation::result_range ONNXScanOp::scan_outputs() {
return llvm::make_range(results.begin() + v_initial().size(), results.end());
}

LogicalResult ONNXScatterOp::inferShapes(
std::function<void(mlir::Region &)> doShapeInference) {
return emitError(NOT_IMPLEMENTED_MESSAGE);
}
//===----------------------------------------------------------------------===//
// ONNXScatterElements
//===----------------------------------------------------------------------===//

LogicalResult ONNXScatterElementsOp::verify() {
ONNXScatterElementsOpAdaptor operandAdaptor(*this);
Expand Down Expand Up @@ -4996,7 +4995,6 @@ LogicalResult ONNXZipMapOp::inferShapes(
return emitError(NOT_IMPLEMENTED_MESSAGE);
}

// New Ops from onnx1.8.1
#define NOT_IMPLEMENTED_INFERSHAPE(T) \
LogicalResult T::inferShapes( \
std::function<void(mlir::Region &)> doShapeInference) { \
Expand All @@ -5006,20 +5004,21 @@ LogicalResult ONNXZipMapOp::inferShapes(
NOT_IMPLEMENTED_INFERSHAPE(ONNXAdagradOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXAdamOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXCeluOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV6Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV11Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV12Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXEinsumOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXGradientOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXMomentumOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXNegativeLogLikelihoodLossOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXSoftmaxCrossEntropyLossOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV9Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV7Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXPadV2Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXPadV11Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXResizeV11Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXResizeV10Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV6Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV11Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV12Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXScatterOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXSoftmaxCrossEntropyLossOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV9Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV7Op)

//===----------------------------------------------------------------------===//
// Loop
Expand Down
75 changes: 43 additions & 32 deletions src/Transform/ONNX/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===----------- ONNXDecompose.cpp - ONNX High Level Rewriting ------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
Expand All @@ -13,7 +13,7 @@
//
// This pass is applied before any other pass so that there is no need to
// implement shape inference for the decomposed operation. Hence, it is expected
// that there is no knowledge about tensor shape at this point
// that there is no knowledge about tensor shape at this point.
//
//===----------------------------------------------------------------------===//

Expand All @@ -28,35 +28,39 @@

using namespace mlir;

namespace onnx_mlir {

// Create an DenseElementsAttr of ArrayAttr.
// This function is used to get Value Type of an EXISTING ArrayAttr for Scaler
// function.
DenseElementsAttr createDenseArrayAttr(
PatternRewriter &rewriter, ArrayAttr origAttrs) {

assert(origAttrs && "handle EXISTING ArrayAttr only");

if (origAttrs.getValue()[0].dyn_cast<FloatAttr>()) {
mlir::Type elementType = rewriter.getF32Type();
Type elementType = rewriter.getF32Type();
int nElements = origAttrs.getValue().size();
SmallVector<float, 4> wrapper(nElements, 0);
for (int i = 0; i < nElements; ++i) {
for (int i = 0; i < nElements; ++i)
wrapper[i] = origAttrs.getValue()[i].cast<FloatAttr>().getValueAsDouble();
}

return DenseElementsAttr::get(
RankedTensorType::get(wrapper.size(), elementType),
llvm::makeArrayRef(wrapper));
}

if (origAttrs.getValue()[0].dyn_cast<IntegerAttr>()) {
mlir::Type elementType = rewriter.getIntegerType(64);
Type elementType = rewriter.getIntegerType(64);
int nElements = origAttrs.getValue().size();
SmallVector<int64_t, 4> wrapper(nElements, 0);
for (int i = 0; i < nElements; ++i) {
for (int i = 0; i < nElements; ++i)
wrapper[i] = origAttrs.getValue()[i].cast<IntegerAttr>().getInt();
}

return DenseElementsAttr::get(
RankedTensorType::get(wrapper.size(), elementType),
llvm::makeArrayRef(wrapper));
}

llvm_unreachable("unexpected attribute type");
}

Expand All @@ -65,19 +69,21 @@ DenseElementsAttr createDenseArrayAttr(
DenseElementsAttr createScalarDenseAttr(
PatternRewriter &rewriter, Attribute attr) {
if (attr.dyn_cast<FloatAttr>()) {
mlir::Type elementType = rewriter.getF32Type();
Type elementType = rewriter.getF32Type();
SmallVector<float, 1> wrapper;
wrapper.emplace_back(attr.cast<FloatAttr>().getValueAsDouble());
return DenseElementsAttr::get(
RankedTensorType::get({}, elementType), llvm::makeArrayRef(wrapper));
}

if (attr.dyn_cast<IntegerAttr>()) {
mlir::Type elementType = rewriter.getIntegerType(64);
Type elementType = rewriter.getIntegerType(64);
SmallVector<int64_t, 1> wrapper;
wrapper.emplace_back(attr.cast<IntegerAttr>().getInt());
return DenseElementsAttr::get(
RankedTensorType::get({}, elementType), llvm::makeArrayRef(wrapper));
}

llvm_unreachable("unexpected attribute type");
}

Expand All @@ -89,20 +95,18 @@ Value createUnitConstant(PatternRewriter &rewriter, Location loc) {
// When ArrayAttr is Null, an empty Integer DenseElementAttr is returned
DenseElementsAttr createDenseArrayAttrOrEmpty(
PatternRewriter &rewriter, ArrayAttr origAttrs) {

if (origAttrs) {
if (origAttrs)
return createDenseArrayAttr(rewriter, origAttrs);
} else {
mlir::Type elementType = rewriter.getIntegerType(64);
int nElements = 0;
SmallVector<int64_t, 4> wrapper(nElements, 0);
for (int i = 0; i < nElements; ++i) {
wrapper[i] = i;
}
return DenseElementsAttr::get(
RankedTensorType::get(wrapper.size(), elementType),
llvm::makeArrayRef(wrapper));
}

Type elementType = rewriter.getIntegerType(64);
int nElements = 0;
SmallVector<int64_t, 4> wrapper(nElements, 0);
for (int i = 0; i < nElements; ++i)
wrapper[i] = i;

return DenseElementsAttr::get(
RankedTensorType::get(wrapper.size(), elementType),
llvm::makeArrayRef(wrapper));
}

Value createSequenceConstructOp(
Expand All @@ -111,17 +115,18 @@ Value createSequenceConstructOp(
Location loc = seq.getLoc();
Value position = rewriter.create<ONNXNoneOp>(loc);

for (auto input : inputs) {
for (auto input : inputs)
seq = rewriter.create<ONNXSequenceInsertOp>(
loc, resType, seq, input, position);
}

return seq;
}

/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXDecompose.inc"
} // namespace onnx_mlir

namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXDecompose.inc"

struct DecomposeONNXToONNXPass
: public PassWrapper<DecomposeONNXToONNXPass, OperationPass<FuncOp>> {
Expand All @@ -135,10 +140,9 @@ struct DecomposeONNXToONNXPass

void runOnOperation() final;
};
} // end anonymous namespace.

void DecomposeONNXToONNXPass::runOnOperation() {
auto function = getOperation();
FuncOp function = getOperation();
MLIRContext *context = &getContext();

ConversionTarget target(getContext());
Expand All @@ -161,6 +165,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
target.addIllegalOp<ONNXResizeV11Op>();
target.addIllegalOp<ONNXResizeV10Op>();
target.addIllegalOp<ONNXScalerOp>();
target.addIllegalOp<ONNXScatterOp>();
target.addIllegalOp<ONNXSequenceConstructOp>();
target.addIllegalOp<ONNXUpsampleOp>();
target.addIllegalOp<ONNXUpsampleV9Op>();
Expand All @@ -171,11 +176,17 @@ void DecomposeONNXToONNXPass::runOnOperation() {

if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
} // end anonymous namespace
}

} // namespace

namespace onnx_mlir {

/*!
* Create a DecomposeONNX pass.
*/
std::unique_ptr<mlir::Pass> onnx_mlir::createDecomposeONNXToONNXPass() {
std::unique_ptr<mlir::Pass> createDecomposeONNXToONNXPass() {
return std::make_unique<DecomposeONNXToONNXPass>();
}

} // namespace onnx_mlir
31 changes: 17 additions & 14 deletions src/Transform/ONNX/Decompose.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,34 +43,31 @@ def IsNoneType : Constraint<CPred<"(($_self).getType().isa<NoneType>())">>;

def GetNullAttr : NativeCodeCall<"Attribute()">;

def GetNullFloatAttr :
NativeCodeCall<"FloatAttr()">;
def GetNullFloatAttr : NativeCodeCall<"FloatAttr()">;

def GetNullIntegerAttr :
NativeCodeCall<"IntegerAttr()">;
def GetNullIntegerAttr : NativeCodeCall<"IntegerAttr()">;

def GetNullStringAttr :
NativeCodeCall<"StringAttr()">;
def GetNullStringAttr : NativeCodeCall<"StringAttr()">;

// Create a unit constant that will be used as none input.
def CreateUnitConstant
: NativeCodeCall<"createUnitConstant($_builder, $_loc)">;
: NativeCodeCall<"::onnx_mlir::createUnitConstant($_builder, $_loc)">;

// Create a scalar DenseElementsAttr (rank 0) from a single attribute.
// E.g return type is tensor<f32> instead of tensor<0xf32> or tensor<1xf32>
def createScalarDenseAttrRank0
: NativeCodeCall<"createScalarDenseAttr($_builder, $0)">;
: NativeCodeCall<"::onnx_mlir::createScalarDenseAttr($_builder, $0)">;

// Create a DenseElementsAttr from a single attribute.
def createDenseArrayAttrFromSingleAttr
: NativeCodeCall<"createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">;
: NativeCodeCall<"::onnx_mlir::createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">;

// Create a DenseElementsAttr from an ArrayAttr.
def createDenseArrayAttr
: NativeCodeCall<"createDenseArrayAttr($_builder, $0)">;
: NativeCodeCall<"::onnx_mlir::createDenseArrayAttr($_builder, $0)">;

def createDenseArrayAttrOrEmpty
: NativeCodeCall<"createDenseArrayAttrOrEmpty($_builder, $0)">;
: NativeCodeCall<"::onnx_mlir::createDenseArrayAttrOrEmpty($_builder, $0)">;

def ScalerT : NativeCodeCall<"TypeAttr::get($_builder.getF32Type())">;

Expand All @@ -80,14 +77,14 @@ def noop_with_empty_axes

// Check if value is a dense constant
def IsDenseONNXConstant:
Constraint<CPred<"isDenseONNXConstant($_self)">, "not a dense constant">;
Constraint<CPred<"::onnx_mlir::isDenseONNXConstant($_self)">, "not a dense constant">;

// Create an ArrayAttr from a dense ConstantOp
def createArrayAttrFromConstantOp : NativeCodeCall<
"createArrayAttrFromConstantOp($_builder, $0)">;
"::onnx_mlir::createArrayAttrFromConstantOp($_builder, $0)">;

def createSequenceConstructOp : NativeCodeCall<
"createSequenceConstructOp($_builder, $0, $1)">;
"::onnx_mlir::createSequenceConstructOp($_builder, $0, $1)">;

//===----------------------------------------------------------------------===//
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
Expand Down Expand Up @@ -344,4 +341,10 @@ def ClipV12Pattern : Pat<
(ONNXClipOp $x, $min, $max)
>;

// Express Scatter (deprecated) using ScatterElements.
def ScatterPattern : Pat<
(ONNXScatterOp $data, $indices, $updates, $axis),
(ONNXScatterElementsOp $data, $indices, $updates, $axis)
>;

#endif // ONNX_DECOMPOSE
Loading

0 comments on commit fa66317

Please sign in to comment.