Skip to content

Commit

Permalink
Add TensorScatterAdd op for opset 16 (#1949)
Browse files Browse the repository at this point in the history
* add TensorScatterAdd op conversion in opset 16

Signed-off-by: Deyu Huang <deyhuang@microsoft.com>

* add unit test

Signed-off-by: Deyu Huang <deyhuang@microsoft.com>

* fix pylint name

Signed-off-by: Deyu Huang <deyhuang@microsoft.com>

* fix typo and test op check

Signed-off-by: Deyu Huang <deyhuang@microsoft.com>
Co-authored-by: Jay Zhang <jiz@microsoft.com>

Co-authored-by: Jay Zhang <jiz@microsoft.com>
  • Loading branch information
hwangdeyu and fatcat-z authored May 27, 2022
1 parent 6f5a673 commit 29b76df
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4698,6 +4698,17 @@ def func(x, y, z):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})

@check_opset_min_version(16, "ScatterND")
def test_scatternd_add(self):
x_val = np.array([10, 20, 30, 40], dtype=np.int32).reshape((4))
y_val = np.array([0, 2], dtype=np.int64).reshape((2, 1))
z_val = np.array([20, 30], dtype=np.int32).reshape((2))

def func(x, y, z):
x_ = tf.tensor_scatter_nd_add(x, y, z)
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})

@check_opset_min_version(11, "ScatterND")
def test_scatternd_1d(self):
x_val = np.array([4, 3, 1, 7], dtype=np.int32).reshape((4, 1))
Expand Down
10 changes: 10 additions & 0 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,16 @@ def version_11(cls, ctx, node, **kwargs):
ctx.replace_inputs(node, [node.input[2], node.input[0], node.input[1]])


@tf_op("TensorScatterAdd", onnx_op="ScatterND")
class TensorScatterAdd:
@classmethod
def version_16(cls, ctx, node, **kwargs):
# indicies input must be int64 in ONNX.
if ctx.get_dtype(node.input[1]) != TensorProto.INT64:
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
node.set_attr("reduction", 'add')


@tf_op("TensorScatterUpdate", onnx_op="ScatterND")
class TensorScatterUpdate:
@classmethod
Expand Down

0 comments on commit 29b76df

Please sign in to comment.