Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite ScatterOp to ScatterElementsOp #1337

Merged
merged 2 commits into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "src/Pass/Passes.hpp"

using namespace mlir;
using namespace onnx_mlir;

namespace onnx_mlir {

Expand Down
21 changes: 10 additions & 11 deletions src/Dialect/ONNX/ONNXOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4627,10 +4627,9 @@ mlir::Operation::result_range ONNXScanOp::scan_outputs() {
return llvm::make_range(results.begin() + v_initial().size(), results.end());
}

LogicalResult ONNXScatterOp::inferShapes(
std::function<void(mlir::Region &)> doShapeInference) {
return emitError(NOT_IMPLEMENTED_MESSAGE);
}
//===----------------------------------------------------------------------===//
// ONNXScatterElements
//===----------------------------------------------------------------------===//

LogicalResult ONNXScatterElementsOp::verify() {
ONNXScatterElementsOpAdaptor operandAdaptor(*this);
Expand Down Expand Up @@ -4996,7 +4995,6 @@ LogicalResult ONNXZipMapOp::inferShapes(
return emitError(NOT_IMPLEMENTED_MESSAGE);
}

// New Ops from onnx1.8.1
#define NOT_IMPLEMENTED_INFERSHAPE(T) \
LogicalResult T::inferShapes( \
std::function<void(mlir::Region &)> doShapeInference) { \
Expand All @@ -5006,20 +5004,21 @@ LogicalResult ONNXZipMapOp::inferShapes(
NOT_IMPLEMENTED_INFERSHAPE(ONNXAdagradOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXAdamOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXCeluOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV6Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV11Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV12Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXEinsumOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXGradientOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXMomentumOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXNegativeLogLikelihoodLossOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXSoftmaxCrossEntropyLossOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV9Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV7Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXPadV2Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXPadV11Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXResizeV11Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXResizeV10Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV6Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV11Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV12Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXScatterOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXSoftmaxCrossEntropyLossOp)
NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV9Op)
NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV7Op)

//===----------------------------------------------------------------------===//
// Loop
Expand Down
75 changes: 43 additions & 32 deletions src/Transform/ONNX/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===----------- ONNXDecompose.cpp - ONNX High Level Rewriting ------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
Expand All @@ -13,7 +13,7 @@
//
// This pass is applied before any other pass so that there is no need to
// implement shape inference for the decomposed operation. Hence, it is expected
// that there is no knowledge about tensor shape at this point
// that there is no knowledge about tensor shape at this point.
//
//===----------------------------------------------------------------------===//

Expand All @@ -28,35 +28,39 @@

using namespace mlir;

namespace onnx_mlir {

// Create a DenseElementsAttr from an ArrayAttr.
// This function is used to get the value type of an EXISTING ArrayAttr for the
// Scaler function.
DenseElementsAttr createDenseArrayAttr(
PatternRewriter &rewriter, ArrayAttr origAttrs) {

assert(origAttrs && "handle EXISTING ArrayAttr only");

if (origAttrs.getValue()[0].dyn_cast<FloatAttr>()) {
mlir::Type elementType = rewriter.getF32Type();
Type elementType = rewriter.getF32Type();
int nElements = origAttrs.getValue().size();
SmallVector<float, 4> wrapper(nElements, 0);
for (int i = 0; i < nElements; ++i) {
for (int i = 0; i < nElements; ++i)
wrapper[i] = origAttrs.getValue()[i].cast<FloatAttr>().getValueAsDouble();
}

return DenseElementsAttr::get(
RankedTensorType::get(wrapper.size(), elementType),
llvm::makeArrayRef(wrapper));
}

if (origAttrs.getValue()[0].dyn_cast<IntegerAttr>()) {
mlir::Type elementType = rewriter.getIntegerType(64);
Type elementType = rewriter.getIntegerType(64);
int nElements = origAttrs.getValue().size();
SmallVector<int64_t, 4> wrapper(nElements, 0);
for (int i = 0; i < nElements; ++i) {
for (int i = 0; i < nElements; ++i)
wrapper[i] = origAttrs.getValue()[i].cast<IntegerAttr>().getInt();
}

return DenseElementsAttr::get(
RankedTensorType::get(wrapper.size(), elementType),
llvm::makeArrayRef(wrapper));
}

llvm_unreachable("unexpected attribute type");
}

Expand All @@ -65,19 +69,21 @@ DenseElementsAttr createDenseArrayAttr(
DenseElementsAttr createScalarDenseAttr(
PatternRewriter &rewriter, Attribute attr) {
if (attr.dyn_cast<FloatAttr>()) {
mlir::Type elementType = rewriter.getF32Type();
Type elementType = rewriter.getF32Type();
SmallVector<float, 1> wrapper;
wrapper.emplace_back(attr.cast<FloatAttr>().getValueAsDouble());
return DenseElementsAttr::get(
RankedTensorType::get({}, elementType), llvm::makeArrayRef(wrapper));
}

if (attr.dyn_cast<IntegerAttr>()) {
mlir::Type elementType = rewriter.getIntegerType(64);
Type elementType = rewriter.getIntegerType(64);
SmallVector<int64_t, 1> wrapper;
wrapper.emplace_back(attr.cast<IntegerAttr>().getInt());
return DenseElementsAttr::get(
RankedTensorType::get({}, elementType), llvm::makeArrayRef(wrapper));
}

llvm_unreachable("unexpected attribute type");
}

Expand All @@ -89,20 +95,18 @@ Value createUnitConstant(PatternRewriter &rewriter, Location loc) {
// When the ArrayAttr is null, an empty Integer DenseElementsAttr is returned.
DenseElementsAttr createDenseArrayAttrOrEmpty(
PatternRewriter &rewriter, ArrayAttr origAttrs) {

if (origAttrs) {
if (origAttrs)
return createDenseArrayAttr(rewriter, origAttrs);
} else {
mlir::Type elementType = rewriter.getIntegerType(64);
int nElements = 0;
SmallVector<int64_t, 4> wrapper(nElements, 0);
for (int i = 0; i < nElements; ++i) {
wrapper[i] = i;
}
return DenseElementsAttr::get(
RankedTensorType::get(wrapper.size(), elementType),
llvm::makeArrayRef(wrapper));
}

Type elementType = rewriter.getIntegerType(64);
int nElements = 0;
SmallVector<int64_t, 4> wrapper(nElements, 0);
for (int i = 0; i < nElements; ++i)
wrapper[i] = i;

return DenseElementsAttr::get(
RankedTensorType::get(wrapper.size(), elementType),
llvm::makeArrayRef(wrapper));
}

Value createSequenceConstructOp(
Expand All @@ -111,17 +115,18 @@ Value createSequenceConstructOp(
Location loc = seq.getLoc();
Value position = rewriter.create<ONNXNoneOp>(loc);

for (auto input : inputs) {
for (auto input : inputs)
seq = rewriter.create<ONNXSequenceInsertOp>(
loc, resType, seq, input, position);
}

return seq;
}

/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXDecompose.inc"
} // namespace onnx_mlir

namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXDecompose.inc"

struct DecomposeONNXToONNXPass
: public PassWrapper<DecomposeONNXToONNXPass, OperationPass<FuncOp>> {
Expand All @@ -135,10 +140,9 @@ struct DecomposeONNXToONNXPass

void runOnOperation() final;
};
} // end anonymous namespace.

void DecomposeONNXToONNXPass::runOnOperation() {
auto function = getOperation();
FuncOp function = getOperation();
MLIRContext *context = &getContext();

ConversionTarget target(getContext());
Expand All @@ -161,6 +165,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
target.addIllegalOp<ONNXResizeV11Op>();
target.addIllegalOp<ONNXResizeV10Op>();
target.addIllegalOp<ONNXScalerOp>();
target.addIllegalOp<ONNXScatterOp>();
target.addIllegalOp<ONNXSequenceConstructOp>();
target.addIllegalOp<ONNXUpsampleOp>();
target.addIllegalOp<ONNXUpsampleV9Op>();
Expand All @@ -171,11 +176,17 @@ void DecomposeONNXToONNXPass::runOnOperation() {

if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
} // end anonymous namespace
}

} // namespace

namespace onnx_mlir {

/*!
* Create a DecomposeONNX pass.
*/
std::unique_ptr<mlir::Pass> onnx_mlir::createDecomposeONNXToONNXPass() {
std::unique_ptr<mlir::Pass> createDecomposeONNXToONNXPass() {
return std::make_unique<DecomposeONNXToONNXPass>();
}

} // namespace onnx_mlir
31 changes: 17 additions & 14 deletions src/Transform/ONNX/Decompose.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,34 +43,31 @@ def IsNoneType : Constraint<CPred<"(($_self).getType().isa<NoneType>())">>;

def GetNullAttr : NativeCodeCall<"Attribute()">;

def GetNullFloatAttr :
NativeCodeCall<"FloatAttr()">;
def GetNullFloatAttr : NativeCodeCall<"FloatAttr()">;

def GetNullIntegerAttr :
NativeCodeCall<"IntegerAttr()">;
def GetNullIntegerAttr : NativeCodeCall<"IntegerAttr()">;

def GetNullStringAttr :
NativeCodeCall<"StringAttr()">;
def GetNullStringAttr : NativeCodeCall<"StringAttr()">;

// Create a unit constant that will be used as none input.
def CreateUnitConstant
: NativeCodeCall<"createUnitConstant($_builder, $_loc)">;
: NativeCodeCall<"::onnx_mlir::createUnitConstant($_builder, $_loc)">;

// Create a scalar DenseElementsAttr (rank 0) from a single attribute.
// E.g. the return type is tensor<f32> instead of tensor<0xf32> or tensor<1xf32>.
def createScalarDenseAttrRank0
: NativeCodeCall<"createScalarDenseAttr($_builder, $0)">;
: NativeCodeCall<"::onnx_mlir::createScalarDenseAttr($_builder, $0)">;

// Create a DenseElementsAttr from a single attribute.
def createDenseArrayAttrFromSingleAttr
: NativeCodeCall<"createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">;
: NativeCodeCall<"::onnx_mlir::createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">;

// Create a DenseElementsAttr from an ArrayAttr.
def createDenseArrayAttr
: NativeCodeCall<"createDenseArrayAttr($_builder, $0)">;
: NativeCodeCall<"::onnx_mlir::createDenseArrayAttr($_builder, $0)">;

def createDenseArrayAttrOrEmpty
: NativeCodeCall<"createDenseArrayAttrOrEmpty($_builder, $0)">;
: NativeCodeCall<"::onnx_mlir::createDenseArrayAttrOrEmpty($_builder, $0)">;

def ScalerT : NativeCodeCall<"TypeAttr::get($_builder.getF32Type())">;

Expand All @@ -80,14 +77,14 @@ def noop_with_empty_axes

// Check if value is a dense constant
def IsDenseONNXConstant:
Constraint<CPred<"isDenseONNXConstant($_self)">, "not a dense constant">;
Constraint<CPred<"::onnx_mlir::isDenseONNXConstant($_self)">, "not a dense constant">;

// Create an ArrayAttr from a dense ConstantOp
def createArrayAttrFromConstantOp : NativeCodeCall<
"createArrayAttrFromConstantOp($_builder, $0)">;
"::onnx_mlir::createArrayAttrFromConstantOp($_builder, $0)">;

def createSequenceConstructOp : NativeCodeCall<
"createSequenceConstructOp($_builder, $0, $1)">;
"::onnx_mlir::createSequenceConstructOp($_builder, $0, $1)">;

//===----------------------------------------------------------------------===//
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
Expand Down Expand Up @@ -344,4 +341,10 @@ def ClipV12Pattern : Pat<
(ONNXClipOp $x, $min, $max)
>;

// Express Scatter (deprecated) using ScatterElements.
def ScatterPattern : Pat<
(ONNXScatterOp $data, $indices, $updates, $axis),
(ONNXScatterElementsOp $data, $indices, $updates, $axis)
>;

#endif // ONNX_DECOMPOSE
Loading