diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index ec2fbe41b89..ecd8cce5377 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -27,7 +27,7 @@ op_dialect_version_map_["CastMap"] = {1}; op_dialect_version_map_["CategoryMapper"] = {1}; op_dialect_version_map_["Ceil"] = {13}; op_dialect_version_map_["Celu"] = {12}; -op_dialect_version_map_["Clip"] = {13}; +op_dialect_version_map_["Clip"] = {13, 12, 11, 6}; op_dialect_version_map_["Compress"] = {11}; op_dialect_version_map_["Concat"] = {13}; op_dialect_version_map_["ConcatFromSequence"] = {11}; @@ -224,6 +224,12 @@ import_handler_map_["Celu"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["Clip"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ClipV12"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ClipV11"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ClipV6"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["Compress"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["Concat"] = diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 00ce5fed197..aa944d64a1b 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -4688,6 +4688,9 @@ NOT_IMPLEMENTED_INFERSHAPE(ONNXPadV2Op); NOT_IMPLEMENTED_INFERSHAPE(ONNXPadV11Op); NOT_IMPLEMENTED_INFERSHAPE(ONNXResizeV11Op); NOT_IMPLEMENTED_INFERSHAPE(ONNXResizeV10Op); +NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV6Op); +NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV11Op); +NOT_IMPLEMENTED_INFERSHAPE(ONNXClipV12Op); //===----------------------------------------------------------------------===// // Loop diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 033be97644e..ef8f44f5948 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -586,6 +586,81 @@ def ONNXClipOp:ONNX_Op<"Clip", }]; } +def ONNXClipV12Op:ONNX_Op<"ClipV12", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "ONNX Clip operation"; + let description = [{ + "Clip operator limits the given input within an interval. The interval is" + "specified by the inputs 'min' and 'max'. They default to" + "numeric_limits::lowest() and numeric_limits::max(), respectively." + }]; + let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$input, + AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType]>:$min, + AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType, NoneType]>:$max); + let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType, NoneType]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; +} + +def ONNXClipV11Op:ONNX_Op<"ClipV11", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "ONNX Clip operation"; + let description = [{ + "Clip operator limits the given input within an interval. The interval is" + "specified by the inputs 'min' and 'max'. They default to" + "numeric_limits::lowest() and numeric_limits::max(), respectively." + }]; + let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$input, + AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType]>:$min, + AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType, NoneType]>:$max); + let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef, NoneType, NoneType]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; +} + +def ONNXClipV6Op:ONNX_Op<"ClipV6", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "ONNX Clip operation"; + let description = [{ + "Clip operator limits the given input within an interval. The interval is" + "specified with arguments 'min' and 'max'. They default to" + "numeric_limits::lowest() and numeric_limits::max() respectively." + }]; + let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$input, + DefaultValuedAttr:$max, + DefaultValuedAttr:$min); + let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; +} + def ONNXCompressOp:ONNX_Op<"Compress", [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Compress operation"; diff --git a/src/Transform/ONNX/Decompose.cpp b/src/Transform/ONNX/Decompose.cpp index 9bc95c91911..95ddebdb854 100644 --- a/src/Transform/ONNX/Decompose.cpp +++ b/src/Transform/ONNX/Decompose.cpp @@ -60,6 +60,27 @@ DenseElementsAttr createDenseArrayAttr( llvm_unreachable("unexpected attribute type"); } +/// Create an Scalar DenseElementsAttr from FloatAttr or IntergerAttr. +/// This is used to create an ONNXConstant of rank 0, e.g. tensor. +DenseElementsAttr createScalarDenseAttr( + PatternRewriter &rewriter, Attribute attr) { + if (attr.dyn_cast()) { + mlir::Type elementType = rewriter.getF32Type(); + SmallVector wrapper; + wrapper.emplace_back(attr.cast().getValueAsDouble()); + return DenseElementsAttr::get( + RankedTensorType::get({}, elementType), llvm::makeArrayRef(wrapper)); + } + if (attr.dyn_cast()) { + mlir::Type elementType = rewriter.getIntegerType(64); + SmallVector wrapper; + wrapper.emplace_back(attr.cast().getInt()); + return DenseElementsAttr::get( + RankedTensorType::get({}, elementType), llvm::makeArrayRef(wrapper)); + } + llvm_unreachable("unexpected attribute type"); +} + ConstantOp createUnitConstant(PatternRewriter &rewriter, Location loc) { return rewriter.create(loc, rewriter.getUnitAttr()); } @@ -124,21 +145,24 @@ void DecomposeONNXToONNXPass::runOnFunction() { // These ops will be decomposed into other ONNX ops. Hence, they will not be // available after this pass. + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); RewritePatternSet patterns(context); populateWithGenerated(patterns); diff --git a/src/Transform/ONNX/Decompose.td b/src/Transform/ONNX/Decompose.td index 41c30c2764b..07b88eb36ed 100644 --- a/src/Transform/ONNX/Decompose.td +++ b/src/Transform/ONNX/Decompose.td @@ -56,6 +56,11 @@ def GetNullStringAttr : def CreateUnitConstant : NativeCodeCall<"createUnitConstant($_builder, $_loc)">; +// Create a scalar DenseElementsAttr (rank 0) from a single attribute. +// E.g return type is tensor instead of tensor<0xf32> or tensor<1xf32> +def createScalarDenseAttrRank0 + : NativeCodeCall<"createScalarDenseAttr($_builder, $0)">; + // Create a DenseElementsAttr from a single attribute. def createDenseArrayAttrFromSingleAttr : NativeCodeCall<"createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">; @@ -317,4 +322,23 @@ def SequenceConstructPattern1: Pat< $x1) >; +// Express Clip V6 using Clip V11. +def ClipV6Pattern : Pat< + (ONNXClipV6Op $x, $maxAttr, $minAttr), + (ONNXClipV11Op $x, (ONNXConstantOpFromDenseAttr(createScalarDenseAttrRank0 $minAttr)), + (ONNXConstantOpFromDenseAttr(createScalarDenseAttrRank0 $maxAttr))) +>; + +// Express Clip V11 using Clip V12. +def ClipV11Pattern : Pat< + (ONNXClipV11Op $x, $min, $max), + (ONNXClipV12Op $x, $min, $max) +>; + +// Express Clip V12 using Clip V13 (the lastest). +def ClipV12Pattern : Pat< + (ONNXClipV12Op $x, $min, $max), + (ONNXClipOp $x, $min, $max) +>; + #endif // ONNX_DECOMPOSE diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir index 4d9f50c408f..871ec424e29 100644 --- a/test/mlir/onnx/onnx_decompose.mlir +++ b/test/mlir/onnx/onnx_decompose.mlir @@ -309,3 +309,17 @@ func @test_seqence_construct_1(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> !o // CHECK: return [[VAR_2_]] : !onnx.Seq> } +// ----- + +func @test_clipv6(%arg0 : tensor<*xf32>) -> () { + %0 = "onnx.ClipV6"(%arg0) {max = 6.000000e+00 : f32, min = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> + return + +// CHECK-LABEL: func @test_clipv6 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Constant"() {value = dense<0.000000e+00> : tensor} : () -> tensor +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Constant"() {value = dense<6.000000e+00> : tensor} : () -> tensor +// CHECK: [[VAR_2_:%.+]] = "onnx.Clip"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]]) : (tensor<*xf32>, tensor, tensor) -> tensor<*xf32> +// CHECK: return +// CHECK: } +} diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 256fa707c41..b2f0d491429 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -73,7 +73,7 @@ 'CategoryMapper': [1], 'Ceil': [13], 'Celu': [12], - 'Clip': [13], + 'Clip': [13, 12, 11, 6], 'Compress': [11], 'Concat': [13], 'ConcatFromSequence': [11],