From ae4c39ed3bdab7edf487d73d5892a573684d1d6a Mon Sep 17 00:00:00 2001 From: Gerstenberger Date: Thu, 16 Nov 2023 08:53:30 +0100 Subject: [PATCH] Fix unsupported ops TF 2.14.0: OnesLike (#2270) * add OnesLike handler * add tests for OnesLike --------- Signed-off-by: Alexander Gerstenberger Co-authored-by: Alexander Gerstenberger Co-authored-by: Jay Zhang <36183870+fatcat-z@users.noreply.github.com> --- tests/test_backend.py | 10 ++++++ tf2onnx/onnx_opset/generator.py | 55 ++++++++++++++++++++++----------- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/tests/test_backend.py b/tests/test_backend.py index 020416971..740d6c584 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -4113,6 +4113,16 @@ def func(x, y): self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True, graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False)) + def test_ones_like(self): + input_x = np.random.random_sample([3, 16, 16]).astype(np.float32) + input_y = np.array([16, 16, 3]).astype(np.int64) + + def func(x, y): + z = tf.reshape(x, y) + return tf.ones_like(z, name=_TFOUTPUT) + + self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y}) + @check_opset_min_version(9, "is_nan") def test_isnan(self): # only compatible with dtype `float32` diff --git a/tf2onnx/onnx_opset/generator.py b/tf2onnx/onnx_opset/generator.py index f6e0409cb..01cfea523 100644 --- a/tf2onnx/onnx_opset/generator.py +++ b/tf2onnx/onnx_opset/generator.py @@ -227,31 +227,50 @@ def version_7(cls, ctx, node, **kwargs): ctx.remove_input(node, node.input[1], 1) +def _const_like_version_1(ctx, node, value): + shapes = node.output_shapes + dtypes = node.output_dtypes + ctx.remove_node(node.name) + casted_input = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64}) + const_value = ctx.make_const(utils.make_name("value"), np.array(value).astype(np.int64)) + mul_node = ctx.make_node('Mul', inputs=[casted_input.output[0], const_value.output[0]]) + ctx.make_node("Cast", inputs=[mul_node.output[0]], + attr={'to': dtypes[0]}, + name=node.name, outputs=node.output, + shapes=shapes, dtypes=dtypes) + + +def _const_like_version_9(ctx, node, value): + dtypes = node.output_dtypes + ctx.remove_node(node.name) + shape = ctx.make_node("Shape", node.input).output[0] + value_tensor = helper.make_tensor("value", dtypes[0], [1], vals=[value]) + ctx.make_node("ConstantOfShape", inputs=[shape], + attr={'value': value_tensor}, + name=node.name, outputs=node.output, + dtypes=dtypes) + + @tf_op("ZerosLike") class ZerosLike: @classmethod def version_1(cls, ctx, node, **kwargs): - shapes = node.output_shapes - dtypes = node.output_dtypes - ctx.remove_node(node.name) - casted_input = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64}) - const_zero = ctx.make_const(utils.make_name("zero"), np.array(0).astype(np.int64)) - mul_node = ctx.make_node('Mul', inputs=[casted_input.output[0], const_zero.output[0]]) - ctx.make_node("Cast", inputs=[mul_node.output[0]], - attr={'to': dtypes[0]}, - name=node.name, outputs=node.output, - shapes=shapes, dtypes=dtypes) + _const_like_version_1(ctx, node, 0) @classmethod def version_9(cls, ctx, node, **kwargs): - dtypes = node.output_dtypes - ctx.remove_node(node.name) - shape = ctx.make_node("Shape", node.input).output[0] - zero_tensor = helper.make_tensor("value", dtypes[0], [1], vals=[0]) - ctx.make_node("ConstantOfShape", inputs=[shape], - attr={'value': zero_tensor}, - name=node.name, outputs=node.output, - dtypes=dtypes) + _const_like_version_9(ctx, node, 0) + + +@tf_op("OnesLike") +class OnesLike: + @classmethod + def version_1(cls, ctx, node, **kwargs): + _const_like_version_1(ctx, node, 1) + + @classmethod + def version_9(cls, ctx, node, **kwargs): + _const_like_version_9(ctx, node, 1) @tf_op(["IteratorV2", "FIFOQueueV2"])