diff --git a/support_status.md b/support_status.md index bf3baeb7b..37fc78cfe 100644 --- a/support_status.md +++ b/support_status.md @@ -254,6 +254,9 @@ | TensorListSetItem | 7 ~ 17 | | TensorListStack | 7 ~ 17 | | TensorScatterAdd | 16 ~ 17 | +| TensorScatterMax | 16 ~ 17 | +| TensorScatterMin | 16 ~ 17 | +| TensorScatterSub | 16 ~ 17 | | TensorScatterUpdate | 11 ~ 17 | | Tile | 1 ~ 17 | | TopKV2 | 1 ~ 17 | diff --git a/tests/test_backend.py b/tests/test_backend.py index c1bc9c052..b9edda714 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -6102,5 +6102,50 @@ def func(x): x_val = make_xval([2, 3]) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_tf_min_version("2.3.0") + @check_opset_min_version(16, "ScatterND") + @skip_tfjs("not supported in tfjs") + def test_tensor_scatter_max(self): + def func(tensor, indices, updates): + op = tf.tensor_scatter_nd_max(tensor, indices, updates) + return tf.identity(op, name=_TFOUTPUT) + + tensor_val = make_xval([3, 4, 5]) + indices_val = np.array([[2, 3], [0, 1]], np.int32) + indices64_val = indices_val.astype(np.int64) + updates_val = make_xval([2, 5]) + 3 + self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices_val, _INPUT2: updates_val}) + self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices64_val, _INPUT2: updates_val}) + + @check_tf_min_version("2.3.0") + @check_opset_min_version(16, "ScatterND") + @skip_tfjs("not supported in tfjs") + def test_tensor_scatter_min(self): + def func(tensor, indices, updates): + op = tf.tensor_scatter_nd_min(tensor, indices, updates) + return tf.identity(op, name=_TFOUTPUT) + + tensor_val = make_xval([3, 4, 5]) + indices_val = np.array([[2, 3], [0, 1]], np.int32) + indices64_val = indices_val.astype(np.int64) + updates_val = make_xval([2, 5]) + 3 + self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices_val, _INPUT2: updates_val}) + self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices64_val, _INPUT2: updates_val}) + + @check_tf_min_version("1.12.1") + @check_opset_min_version(16, "ScatterND") + @skip_tfjs("not supported in tfjs") + def test_tensor_scatter_sub(self): + def func(tensor, indices, updates): + op = tf.tensor_scatter_nd_sub(tensor, indices, updates) + return tf.identity(op, name=_TFOUTPUT) + + tensor_val = make_xval([3, 4, 5]) + indices_val = np.array([[2, 3], [0, 1]], np.int32) + indices64_val = indices_val.astype(np.int64) + updates_val = make_xval([2, 5]) + 3 + self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices_val, _INPUT2: updates_val}) + self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices64_val, _INPUT2: updates_val}) + if __name__ == '__main__': unittest_main() diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index bb9b0cb52..df6ac9420 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -670,6 +670,37 @@ def version_16(cls, ctx, node, **kwargs): node.set_attr("reduction", 'add') +@tf_op("TensorScatterMax", onnx_op="ScatterND") +class TensorScatterMax: + @classmethod + def version_16(cls, ctx, node, **kwargs): + # indices input must be int64 in ONNX. + if ctx.get_dtype(node.input[1]) != TensorProto.INT64: + ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64) + node.set_attr("reduction", 'max') + + +@tf_op("TensorScatterMin", onnx_op="ScatterND") +class TensorScatterMin: + @classmethod + def version_16(cls, ctx, node, **kwargs): + # indices input must be int64 in ONNX. + if ctx.get_dtype(node.input[1]) != TensorProto.INT64: + ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64) + node.set_attr("reduction", 'min') + + +@tf_op("TensorScatterSub", onnx_op="ScatterND") +class TensorScatterSub: + @classmethod + def version_16(cls, ctx, node, **kwargs): + # indices input must be int64 in ONNX. + if ctx.get_dtype(node.input[1]) != TensorProto.INT64: + ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64) + ctx.insert_new_node_on_input(node, "Neg", node.input[2]) + node.set_attr("reduction", 'add') + + @tf_op("TensorScatterUpdate", onnx_op="ScatterND") class TensorScatterUpdate: @classmethod