diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp
index fb1ba29641..1472e9c2d4 100644
--- a/src/Compiler/CompilerPasses.cpp
+++ b/src/Compiler/CompilerPasses.cpp
@@ -28,7 +28,6 @@
 #include "src/Pass/Passes.hpp"
 
 using namespace mlir;
-using namespace onnx_mlir;
 
 namespace onnx_mlir {
 
diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index 5dac368219..fa2acf0865 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -4627,10 +4627,9 @@ mlir::Operation::result_range ONNXScanOp::scan_outputs() {
   return llvm::make_range(results.begin() + v_initial().size(), results.end());
 }
 
-LogicalResult ONNXScatterOp::inferShapes(
-    std::function<void(mlir::Region &)> doShapeInference) {
-  return emitError(NOT_IMPLEMENTED_MESSAGE);
-}
+//===----------------------------------------------------------------------===//
+// ONNXScatterElements
+//===----------------------------------------------------------------------===//
 
 LogicalResult ONNXScatterElementsOp::verify() {
   ONNXScatterElementsOpAdaptor operandAdaptor(*this);
@@ -4996,7 +4995,6 @@ LogicalResult ONNXZipMapOp::inferShapes(
   return emitError(NOT_IMPLEMENTED_MESSAGE);
 }
 
-// New Ops from onnx1.8.1
 #define NOT_IMPLEMENTED_INFERSHAPE(T)                                          \
   LogicalResult T::inferShapes(                                                \
       std::function<void(mlir::Region &)> doShapeInference) {                  \
     return emitError(NOT_IMPLEMENTED_MESSAGE);                                 \
@@ -5006,20 +5004,21 @@ LogicalResult ONNXZipMapOp::inferShapes(
 NOT_IMPLEMENTED_INFERSHAPE(ONNXAdagradOp)
 NOT_IMPLEMENTED_INFERSHAPE(ONNXAdamOp)
 NOT_IMPLEMENTED_INFERSHAPE(ONNXCeluOp)
+NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV6Op)
+NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV11Op)
+NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV12Op)
 NOT_IMPLEMENTED_INFERSHAPE(ONNXEinsumOp)
 NOT_IMPLEMENTED_INFERSHAPE(ONNXGradientOp)
 NOT_IMPLEMENTED_INFERSHAPE(ONNXMomentumOp)
 NOT_IMPLEMENTED_INFERSHAPE(ONNXNegativeLogLikelihoodLossOp)
-NOT_IMPLEMENTED_INFERSHAPE(ONNXSoftmaxCrossEntropyLossOp)
-NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV9Op)
-NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV7Op)
 NOT_IMPLEMENTED_INFERSHAPE(ONNXPadV2Op)
 NOT_IMPLEMENTED_INFERSHAPE(ONNXPadV11Op)
 NOT_IMPLEMENTED_INFERSHAPE(ONNXResizeV11Op)
 NOT_IMPLEMENTED_INFERSHAPE(ONNXResizeV10Op)
-NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV6Op)
-NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV11Op)
-NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV12Op)
+NOT_IMPLEMENTED_INFERSHAPE(ONNXScatterOp)
+NOT_IMPLEMENTED_INFERSHAPE(ONNXSoftmaxCrossEntropyLossOp)
+NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV9Op)
+NOT_IMPLEMENTED_INFERSHAPE(ONNXUpsampleV7Op)
 
 //===----------------------------------------------------------------------===//
 // Loop
diff --git a/src/Transform/ONNX/Decompose.cpp b/src/Transform/ONNX/Decompose.cpp
index 4684bc07bd..10a981da62 100644
--- a/src/Transform/ONNX/Decompose.cpp
+++ b/src/Transform/ONNX/Decompose.cpp
@@ -4,7 +4,7 @@
 
 //===----------- ONNXDecompose.cpp - ONNX High Level Rewriting ------------===//
 //
-// Copyright 2019-2020 The IBM Research Authors.
+// Copyright 2019-2022 The IBM Research Authors.
 //
 // =============================================================================
 //
@@ -13,7 +13,7 @@
 //
 // This pass is applied before any other pass so that there is no need to
 // implement shape inference for the decomposed operation. Hence, it is expected
-// that there is no knowledge about tensor shape at this point
+// that there is no knowledge about tensor shape at this point.
 //
 //===----------------------------------------------------------------------===//
 
@@ -28,35 +28,39 @@
 
 using namespace mlir;
 
+namespace onnx_mlir {
+
 // Create an DenseElementsAttr of ArrayAttr.
 // This function is used to get Value Type of an EXISTING ArrayAttr for Scaler
 // function.
 DenseElementsAttr createDenseArrayAttr(
     PatternRewriter &rewriter, ArrayAttr origAttrs) {
-  assert(origAttrs && "handle EXISTING ArrayAttr only");
+
   if (origAttrs.getValue()[0].dyn_cast<FloatAttr>()) {
-    mlir::Type elementType = rewriter.getF32Type();
+    Type elementType = rewriter.getF32Type();
     int nElements = origAttrs.getValue().size();
     SmallVector<float> wrapper(nElements, 0);
-    for (int i = 0; i < nElements; ++i) {
+    for (int i = 0; i < nElements; ++i)
       wrapper[i] = origAttrs.getValue()[i].cast<FloatAttr>().getValueAsDouble();
-    }
+
     return DenseElementsAttr::get(
         RankedTensorType::get(wrapper.size(), elementType),
         llvm::makeArrayRef(wrapper));
   }
+
   if (origAttrs.getValue()[0].dyn_cast<IntegerAttr>()) {
-    mlir::Type elementType = rewriter.getIntegerType(64);
+    Type elementType = rewriter.getIntegerType(64);
     int nElements = origAttrs.getValue().size();
     SmallVector<int64_t> wrapper(nElements, 0);
-    for (int i = 0; i < nElements; ++i) {
+    for (int i = 0; i < nElements; ++i)
       wrapper[i] = origAttrs.getValue()[i].cast<IntegerAttr>().getInt();
-    }
+
    return DenseElementsAttr::get(
         RankedTensorType::get(wrapper.size(), elementType),
         llvm::makeArrayRef(wrapper));
   }
+
   llvm_unreachable("unexpected attribute type");
 }
 
@@ -65,19 +69,21 @@ DenseElementsAttr createDenseArrayAttr(
 DenseElementsAttr createScalarDenseAttr(
     PatternRewriter &rewriter, Attribute attr) {
   if (attr.dyn_cast<FloatAttr>()) {
-    mlir::Type elementType = rewriter.getF32Type();
+    Type elementType = rewriter.getF32Type();
     SmallVector<float> wrapper;
     wrapper.emplace_back(attr.cast<FloatAttr>().getValueAsDouble());
     return DenseElementsAttr::get(
         RankedTensorType::get({}, elementType), llvm::makeArrayRef(wrapper));
   }
+
   if (attr.dyn_cast<IntegerAttr>()) {
-    mlir::Type elementType = rewriter.getIntegerType(64);
+    Type elementType = rewriter.getIntegerType(64);
     SmallVector<int64_t> wrapper;
     wrapper.emplace_back(attr.cast<IntegerAttr>().getInt());
     return DenseElementsAttr::get(
         RankedTensorType::get({}, elementType), llvm::makeArrayRef(wrapper));
   }
+
   llvm_unreachable("unexpected attribute type");
 }
 
@@ -89,20 +95,18 @@ Value createUnitConstant(PatternRewriter &rewriter, Location loc) {
 // When ArrayAttr is Null, an empty Integer DenseElementAttr is returned
 DenseElementsAttr createDenseArrayAttrOrEmpty(
     PatternRewriter &rewriter, ArrayAttr origAttrs) {
-
-  if (origAttrs) {
+  if (origAttrs)
     return createDenseArrayAttr(rewriter, origAttrs);
-  } else {
-    mlir::Type elementType = rewriter.getIntegerType(64);
-    int nElements = 0;
-    SmallVector<int64_t> wrapper(nElements, 0);
-    for (int i = 0; i < nElements; ++i) {
-      wrapper[i] = i;
-    }
-    return DenseElementsAttr::get(
-        RankedTensorType::get(wrapper.size(), elementType),
-        llvm::makeArrayRef(wrapper));
-  }
+
+  Type elementType = rewriter.getIntegerType(64);
+  int nElements = 0;
+  SmallVector<int64_t> wrapper(nElements, 0);
+  for (int i = 0; i < nElements; ++i)
+    wrapper[i] = i;
+
+  return DenseElementsAttr::get(
+      RankedTensorType::get(wrapper.size(), elementType),
+      llvm::makeArrayRef(wrapper));
 }
 
 Value createSequenceConstructOp(
@@ -111,17 +115,18 @@
   Location loc = seq.getLoc();
   Value position = rewriter.create(loc);
 
-  for (auto input : inputs) {
+  for (auto input : inputs)
     seq = rewriter.create<ONNXSequenceInsertOp>(
         loc, resType, seq, input, position);
-  }
+
   return seq;
 }
 
-/// Include the patterns defined in the Declarative Rewrite framework.
-#include "src/Transform/ONNX/ONNXDecompose.inc"
+} // namespace onnx_mlir
 
 namespace {
+/// Include the patterns defined in the Declarative Rewrite framework.
+#include "src/Transform/ONNX/ONNXDecompose.inc"
 
 struct DecomposeONNXToONNXPass
     : public PassWrapper<DecomposeONNXToONNXPass, OperationPass<FuncOp>> {
@@ -135,10 +140,9 @@ struct DecomposeONNXToONNXPass
 
   void runOnOperation() final;
 };
-} // end anonymous namespace.
 
 void DecomposeONNXToONNXPass::runOnOperation() {
-  auto function = getOperation();
+  FuncOp function = getOperation();
   MLIRContext *context = &getContext();
 
   ConversionTarget target(getContext());
@@ -161,6 +165,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<ONNXScatterOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
@@ -171,11 +176,17 @@ void DecomposeONNXToONNXPass::runOnOperation() {
 
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
-} // end anonymous namespace
+}
+
+} // namespace
+
+namespace onnx_mlir {
 
 /*!
  * Create a DecomposeONNX pass.
  */
-std::unique_ptr<mlir::Pass> onnx_mlir::createDecomposeONNXToONNXPass() {
+std::unique_ptr<mlir::Pass> createDecomposeONNXToONNXPass() {
  return std::make_unique<DecomposeONNXToONNXPass>();
 }
+
+} // namespace onnx_mlir
diff --git a/src/Transform/ONNX/Decompose.td b/src/Transform/ONNX/Decompose.td
index 3bcaee8ea7..7d9c18e742 100644
--- a/src/Transform/ONNX/Decompose.td
+++ b/src/Transform/ONNX/Decompose.td
@@ -43,34 +43,31 @@
 def IsNoneType : Constraint<CPred<"(($_self).getType().isa<NoneType>())">>;
 
 def GetNullAttr : NativeCodeCall<"Attribute()">;
 
-def GetNullFloatAttr :
-    NativeCodeCall<"FloatAttr()">;
+def GetNullFloatAttr : NativeCodeCall<"FloatAttr()">;
 
-def GetNullIntegerAttr :
-    NativeCodeCall<"IntegerAttr()">;
+def GetNullIntegerAttr : NativeCodeCall<"IntegerAttr()">;
 
-def GetNullStringAttr :
-    NativeCodeCall<"StringAttr()">;
+def GetNullStringAttr : NativeCodeCall<"StringAttr()">;
 
 // Create a unit constant that will be used as none input.
 def CreateUnitConstant
-    : NativeCodeCall<"createUnitConstant($_builder, $_loc)">;
+    : NativeCodeCall<"::onnx_mlir::createUnitConstant($_builder, $_loc)">;
 
 // Create a scalar DenseElementsAttr (rank 0) from a single attribute.
 // E.g return type is tensor<f32> instead of tensor<0xf32> or tensor<1xf32>
 def createScalarDenseAttrRank0
-    : NativeCodeCall<"createScalarDenseAttr($_builder, $0)">;
+    : NativeCodeCall<"::onnx_mlir::createScalarDenseAttr($_builder, $0)">;
 
 // Create a DenseElementsAttr from a single attribute.
 def createDenseArrayAttrFromSingleAttr
-    : NativeCodeCall<"createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">;
+    : NativeCodeCall<"::onnx_mlir::createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">;
 
 // Create a DenseElementsAttr from an ArrayAttr.
 def createDenseArrayAttr
-    : NativeCodeCall<"createDenseArrayAttr($_builder, $0)">;
+    : NativeCodeCall<"::onnx_mlir::createDenseArrayAttr($_builder, $0)">;
 
 def createDenseArrayAttrOrEmpty
-    : NativeCodeCall<"createDenseArrayAttrOrEmpty($_builder, $0)">;
+    : NativeCodeCall<"::onnx_mlir::createDenseArrayAttrOrEmpty($_builder, $0)">;
 
 def ScalerT : NativeCodeCall<"TypeAttr::get($_builder.getF32Type())">;
@@ -80,14 +77,14 @@ def noop_with_empty_axes
 
 // Check if value is a dense constant
 def IsDenseONNXConstant:
-    Constraint<CPred<"isDenseONNXConstant($_self)">, "not a dense constant">;
+    Constraint<CPred<"::onnx_mlir::isDenseONNXConstant($_self)">, "not a dense constant">;
 
 // Create an ArrayAttr from a dense ConstantOp
 def createArrayAttrFromConstantOp : NativeCodeCall<
-    "createArrayAttrFromConstantOp($_builder, $0)">;
+    "::onnx_mlir::createArrayAttrFromConstantOp($_builder, $0)">;
 
 def createSequenceConstructOp : NativeCodeCall<
-    "createSequenceConstructOp($_builder, $0, $1)">;
+    "::onnx_mlir::createSequenceConstructOp($_builder, $0, $1)">;
 
 //===----------------------------------------------------------------------===//
 // ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
@@ -344,4 +341,10 @@ def ClipV12Pattern : Pat<
   (ONNXClipOp $x, $min, $max)
 >;
 
+// Express Scatter (deprecated) using ScatterElements.
+def ScatterPattern : Pat<
+  (ONNXScatterOp $data, $indices, $updates, $axis),
+  (ONNXScatterElementsOp $data, $indices, $updates, $axis)
+>;
+
 #endif // ONNX_DECOMPOSE
diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir
index 39c783bb77..9585cfecf0 100644
--- a/test/mlir/onnx/onnx_decompose.mlir
+++ b/test/mlir/onnx/onnx_decompose.mlir
@@ -157,10 +157,10 @@ func @test_scaler_no_scale2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
   %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
   return %0 : tensor<3xf32>
 
-  //CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = f32} : (tensor<3xi32>) -> tensor<*xf32>
-  //CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
-  //CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
-  //CHECK-NEXT: return %2 : tensor<3xf32>
+  // CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = f32} : (tensor<3xi32>) -> tensor<*xf32>
+  // CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %2 : tensor<3xf32>
 }
 
 // -----
@@ -276,12 +276,13 @@ func @test_padv2(%arg0: tensor<1x3x224x224xf32>) -> tensor<*xf32> {
 func @test_resizev10(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<4xf32>) -> tensor<*xf32> {
   %0 = "onnx.ResizeV10"(%arg0, %arg1) {mode = "nearest"} : (tensor<1x2x3x4xf32>, tensor<4xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
-// CHECK-LABEL: func @test_resizev10
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x3x4xf32>, [[PARAM_1_:%.+]]: tensor<4xf32>) -> tensor<*xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Constant"() {value = dense<> : tensor<0xf32>} : () -> tensor<0xf32>
-// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Constant"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
-// CHECK: [[VAR_2_:%.+]] = "onnx.Resize"([[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]], [[VAR_1_]]) {mode = "nearest"} : (tensor<1x2x3x4xf32>, tensor<0xf32>, tensor<4xf32>, tensor<0xi64>) -> tensor<*xf32>
-// CHECK: return [[VAR_2_]] : tensor<*xf32>
+
+  // CHECK-LABEL: func @test_resizev10
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x3x4xf32>, [[PARAM_1_:%.+]]: tensor<4xf32>) -> tensor<*xf32> {
+  // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Constant"() {value = dense<> : tensor<0xf32>} : () -> tensor<0xf32>
+  // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Constant"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
+  // CHECK: [[VAR_2_:%.+]] = "onnx.Resize"([[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]], [[VAR_1_]]) {mode = "nearest"} : (tensor<1x2x3x4xf32>, tensor<0xf32>, tensor<4xf32>, tensor<0xi64>) -> tensor<*xf32>
+  // CHECK: return [[VAR_2_]] : tensor<*xf32>
 }
 
 func @test_resizev11(%arg0: tensor<*xf32>, %arg1: tensor<*xi64>) -> tensor<*xf32> {
@@ -289,24 +290,26 @@ func @test_padv2(%arg0: tensor<1x3x224x224xf32>) -> tensor<*xf32> {
   %1 = "onnx.Constant"() {value = dense<> : tensor<0xf32>} : () -> tensor<0xf32>
   %2 = "onnx.Resize"(%arg0, %0, %0, %arg1) {coordinate_transformation_mode = "half_pixel", mode = "nearest", nearest_mode = "floor", onnx_node_name = "Resize__697"} : (tensor<*xf32>, tensor<0xf32>, tensor<0xf32>, tensor<*xi64>) -> tensor<*xf32>
   return %2 : tensor<*xf32>
-// CHECK-LABEL: func @test_resizev11
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xi64>) -> tensor<*xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Constant"() {value = dense<> : tensor<0xf32>} : () -> tensor<0xf32>
-// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Constant"() {value = dense<> : tensor<0xf32>} : () -> tensor<0xf32>
-// CHECK: [[VAR_2_:%.+]] = "onnx.Resize"([[PARAM_0_]], [[VAR_0_]], [[VAR_0_]], [[PARAM_1_]]) {coordinate_transformation_mode = "half_pixel", mode = "nearest", nearest_mode = "floor", onnx_node_name = "Resize__697"} : (tensor<*xf32>, tensor<0xf32>, tensor<0xf32>, tensor<*xi64>) -> tensor<*xf32>
-// CHECK: return [[VAR_2_]] : tensor<*xf32>
+
+  // CHECK-LABEL: func @test_resizev11
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xi64>) -> tensor<*xf32> {
+  // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Constant"() {value = dense<> : tensor<0xf32>} : () -> tensor<0xf32>
+  // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Constant"() {value = dense<> : tensor<0xf32>} : () -> tensor<0xf32>
+  // CHECK: [[VAR_2_:%.+]] = "onnx.Resize"([[PARAM_0_]], [[VAR_0_]], [[VAR_0_]], [[PARAM_1_]]) {coordinate_transformation_mode = "half_pixel", mode = "nearest", nearest_mode = "floor", onnx_node_name = "Resize__697"} : (tensor<*xf32>, tensor<0xf32>, tensor<0xf32>, tensor<*xi64>) -> tensor<*xf32>
+  // CHECK: return [[VAR_2_]] : tensor<*xf32>
 }
 
 func @test_seqence_construct_1(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> !onnx.Seq<tensor<*xf32>> {
   %0 = "onnx.SequenceConstruct"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> !onnx.Seq<tensor<*xf32>>
   return %0 : !onnx.Seq<tensor<*xf32>>
-// CHECK-LABEL: func @test_seqence_construct_1
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> !onnx.Seq<tensor<*xf32>> {
-// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.SequenceEmpty"() : () -> !onnx.Seq<tensor<*xf32>>
-// CHECK-DAG: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none
-// CHECK: [[VAR_1_:%.+]] = "onnx.SequenceInsert"([[VAR_0_]], [[PARAM_0_]], [[VAR_cst_]]) : (!onnx.Seq<tensor<*xf32>>, tensor<*xf32>, none) -> !onnx.Seq<tensor<*xf32>>
-// CHECK: [[VAR_2_:%.+]] = "onnx.SequenceInsert"([[VAR_1_]], [[PARAM_1_]], [[VAR_cst_]]) : (!onnx.Seq<tensor<*xf32>>, tensor<*xf32>, none) -> !onnx.Seq<tensor<*xf32>>
-// CHECK: return [[VAR_2_]] : !onnx.Seq<tensor<*xf32>>
+
+  // CHECK-LABEL: func @test_seqence_construct_1
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> !onnx.Seq<tensor<*xf32>> {
+  // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.SequenceEmpty"() : () -> !onnx.Seq<tensor<*xf32>>
+  // CHECK-DAG: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none
+  // CHECK: [[VAR_1_:%.+]] = "onnx.SequenceInsert"([[VAR_0_]], [[PARAM_0_]], [[VAR_cst_]]) : (!onnx.Seq<tensor<*xf32>>, tensor<*xf32>, none) -> !onnx.Seq<tensor<*xf32>>
+  // CHECK: [[VAR_2_:%.+]] = "onnx.SequenceInsert"([[VAR_1_]], [[PARAM_1_]], [[VAR_cst_]]) : (!onnx.Seq<tensor<*xf32>>, tensor<*xf32>, none) -> !onnx.Seq<tensor<*xf32>>
+  // CHECK: return [[VAR_2_]] : !onnx.Seq<tensor<*xf32>>
 }
 
 // -----
@@ -315,11 +318,23 @@
 func @test_clipv6(%arg0 : tensor<*xf32>) -> () {
   %0 = "onnx.ClipV6"(%arg0) {max = 6.000000e+00 : f32, min = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
   return
-// CHECK-LABEL: func @test_clipv6
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) {
-// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Constant"() {value = dense<6.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK: [[VAR_2_:%.+]] = "onnx.Clip"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]]) : (tensor<*xf32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
-// CHECK: return
-// CHECK: }
+  // CHECK-LABEL: func @test_clipv6
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) {
+  // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Constant"() {value = dense<6.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: [[VAR_2_:%.+]] = "onnx.Clip"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]]) : (tensor<*xf32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: return
 }
+
+// -----
+
+func @test_scatter(%arg0: tensor<64x25600xf32>, %arg1: tensor<64x100xi64>, %arg2: tensor<64x100xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Scatter"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<64x25600xf32>, tensor<64x100xi64>, tensor<64x100xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+
+  // CHECK-LABEL: func @test_scatter
+  // CHECK-SAME: ([[PARAM_0:%.+]]: tensor<64x25600xf32>, [[PARAM_1:%.+]]: tensor<64x100xi64>, [[PARAM_2:%.+]]: tensor<64x100xf32>) -> tensor<*xf32> {
+  // CHECK-NEXT: [[RES:%.+]] = "onnx.ScatterElements"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<64x25600xf32>, tensor<64x100xi64>, tensor<64x100xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: return [[RES]] : tensor<*xf32>
+}
+