diff --git a/src/frontends/tensorflow_common/src/op/scatter_nd.cpp b/src/frontends/tensorflow_common/src/op/scatter_nd.cpp index 894722be163545..4cbb5982f44691 100644 --- a/src/frontends/tensorflow_common/src/op/scatter_nd.cpp +++ b/src/frontends/tensorflow_common/src/op/scatter_nd.cpp @@ -3,7 +3,9 @@ // #include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" #include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" #include "openvino/op/scatter_nd_update.hpp" #include "utils.hpp" @@ -15,12 +17,33 @@ namespace frontend { namespace tensorflow { namespace op { OutputVector translate_scatter_nd_op(const NodeContext& node) { - default_op_checks(node, 3, {"ScatterNd", "SCATTER_ND"}); + default_op_checks(node, 3, {"ScatterNd", "SCATTER_ND"}, true); auto input_indices = node.get_input(0); auto updates = node.get_input(1); auto shape = node.get_input(2); + auto complex_type_mark_updates = as_type_ptr(updates.get_node_shared_ptr()); + auto zero_scalar = create_same_type_const(updates, 0); - auto input_data = create_same_type_const(updates, vector{0}, Shape{1}); + if (complex_type_mark_updates) { + updates = complex_type_mark_updates->input_value(0); + // Add two auxiliary dimensions to the shape tensor + auto shape_of_op = make_shared(updates); + auto shape_dims = make_shared(shape_of_op); + auto aux_shape = create_same_type_const(shape, std::vector{2}, Shape{1}); + auto updated_shape = make_shared(OutputVector{aux_shape, shape_dims}, 0); + + auto input_data = zero_scalar; + auto broadcast = make_shared(input_data, updated_shape); + + auto scatter_nd = make_shared(broadcast, input_indices, updates); + set_node_name(node.get_name(), scatter_nd); + + auto complex_scatter_nd = + make_shared(scatter_nd, complex_type_mark_updates->get_complex_part_type()); + return {complex_scatter_nd}; + } + + auto input_data = zero_scalar; auto broadcast = make_shared(input_data, shape); auto scatter_nd = make_shared(broadcast, input_indices, updates); set_node_name(node.get_name(), scatter_nd); diff --git a/tests/layer_tests/tensorflow_tests/test_tf_ScatterND.py b/tests/layer_tests/tensorflow_tests/test_tf_ScatterND.py index 0401c074229e7b..cb2f78e0b6970a 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_ScatterND.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_ScatterND.py @@ -5,7 +5,7 @@ import pytest from common.tf_layer_test_class import CommonTFLayerTest - +import numpy as np class TestTFScatterND(CommonTFLayerTest): def create_tf_scatternd_placeholder_const_net(self, x_shape, indices, updates, ir_version, @@ -73,3 +73,65 @@ def test_tf_scatter_nd(self, params, ie_device, precision, ir_version, temp_dir, use_legacy_frontend=use_legacy_frontend), ie_device, precision, temp_dir=temp_dir, ir_version=ir_version, use_legacy_frontend=use_legacy_frontend, **params) + +class TestComplexScatterND(CommonTFLayerTest): + def create_complex_scatter_nd_net(self, x_shape, indices, updates, ir_version, + use_legacy_frontend): + import tensorflow as tf + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + x = tf.compat.v1.placeholder(tf.float32, x_shape, 'Input') + tf_indices = tf.compat.v1.placeholder(np.int32, [None], 'indices') + updates_real = tf.compat.v1.placeholder(np.float32, [None], 'updates_real') + updates_imag = tf.compat.v1.placeholder(np.float32, [None], 'updates_imag') + updates = tf.raw_ops.Complex(real=updates_real,imag=updates_imag) + + scatter_nd = tf.raw_ops.ScatterNd(indices, updates, tf.shape(x), name="Operation") + res = tf.add(x_shape, scatter_nd, name="Operation") + real = tf.raw_ops.Real(input=res) + img = tf.raw_ops.Imag(input=res) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + def prepare_input(self, inputs_info): + rng = np.random.default_rng() + + assert 'indices' in inputs_info + assert 'updates_real' in inputs_info + assert 'updates_imag' in inputs_info + assert 'x_shape' in inputs_info + + indices_shape = inputs_info['indices'] + updates_real_shape = inputs_info['updates_real'] + updates_imag_shape = inputs_info['updates_imag'] + x_shape = inputs_info['x_shape'] + + inputs_data = {} + inputs_data['indices'] = rng.integers(0, 10, indices_shape).astype(np.int32) # Example range (0, 10), adjust as needed + inputs_data['updates_real'] = 4 * rng.random(updates_real_shape).astype(np.float32) - 2 + inputs_data['updates_imag'] = 4 * rng.random(updates_imag_shape).astype(np.float32) - 2 + inputs_data['x_shape'] = rng.integers(0, 10, indices_shape).astype(np.int32) + + return inputs_data + + test_data_basic = [ + dict(input_shape=[]), + dict(input_shape=[2]), + dict(input_shape=[1, 3]), + dict(input_shape=[2, 3, 4]), + dict(input_shape=[3, 4, 5, 6]), + ] + + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_tf_scatter_nd(self, params, ie_device, precision, ir_version, temp_dir, + use_legacy_frontend): + self._test(*self.create_complex_scatter_nd_net(**params, ir_version=ir_version, + use_legacy_frontend=use_legacy_frontend), + ie_device, precision, temp_dir=temp_dir, ir_version=ir_version, + use_legacy_frontend=use_legacy_frontend, **params) \ No newline at end of file