diff --git a/tests/test_backend.py b/tests/test_backend.py index 144f7ca76..c0e822a77 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -4698,6 +4698,17 @@ def func(x, y, z): return tf.identity(x_, name=_TFOUTPUT) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val}) + @check_opset_min_version(16, "ScatterND") + def test_scatternd_add(self): + x_val = np.array([10, 20, 30, 40], dtype=np.int32).reshape((4)) + y_val = np.array([0, 2], dtype=np.int64).reshape((2, 1)) + z_val = np.array([20, 30], dtype=np.int32).reshape((2)) + + def func(x, y, z): + x_ = tf.tensor_scatter_nd_add(x, y, z) + return tf.identity(x_, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val}) + @check_opset_min_version(11, "ScatterND") def test_scatternd_1d(self): x_val = np.array([4, 3, 1, 7], dtype=np.int32).reshape((4, 1)) diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 49564223e..9ad2b513e 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -655,6 +655,16 @@ def version_11(cls, ctx, node, **kwargs): ctx.replace_inputs(node, [node.input[2], node.input[0], node.input[1]]) +@tf_op("TensorScatterAdd", onnx_op="ScatterND") +class TensorScatterAdd: + @classmethod + def version_16(cls, ctx, node, **kwargs): + # indicies input must be int64 in ONNX. + if ctx.get_dtype(node.input[1]) != TensorProto.INT64: + ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64) + node.set_attr("reduction", 'add') + + @tf_op("TensorScatterUpdate", onnx_op="ScatterND") class TensorScatterUpdate: @classmethod