forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TF FE] Support Complex Tensors (openvinotoolkit#20860)
* [TF FE] Support complex tensors Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Align output type for Real and Imag operations Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Update decoding complex types * Add support for ComplexAbs, FFT and IFFT operations Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Correct axes based on a number of inner-most dimensions * Add layer tests Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Update supported ops documentation Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Add a comment for ComplexTypeMark Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
- Loading branch information
Showing
17 changed files
with
804 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions
51
src/frontends/tensorflow_common/include/helper_ops/complex_type_mark.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/core/type/element_type.hpp" | ||
#include "openvino/op/util/framework_node.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace tensorflow { | ||
|
||
// ComplexTypeMark serves to mark places that require complex type propagation | ||
// that means to represent native complex type with simulating floating-point tensor | ||
// that has one extra dimension to concatenate real and imaginary parts of complex tensor. | ||
// For example, a tensor of complex type with shape [N1, N2, ..., Nk] will be transformed | ||
// into a floating-point tensor [N1, N2, ..., Nk, 2] | ||
// where a slice with index [..., 0] represents a real part and | ||
// a slice with index [..., 1] represents a imaginary part. | ||
class ComplexTypeMark : public ov::op::util::FrameworkNode { | ||
public: | ||
OPENVINO_OP("ComplexTypeMark", "util", ov::op::util::FrameworkNode); | ||
|
||
ComplexTypeMark(const ov::Output<ov::Node>& input, const ov::element::Type& complex_part_type) | ||
: ov::op::util::FrameworkNode(ov::OutputVector{input}, 1), | ||
m_complex_part_type(complex_part_type) { | ||
validate_and_infer_types(); | ||
} | ||
|
||
void validate_and_infer_types() override { | ||
set_output_type(0, ov::element::dynamic, PartialShape::dynamic()); | ||
} | ||
|
||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override { | ||
auto complex_type_mark = std::make_shared<ComplexTypeMark>(inputs[0], m_complex_part_type); | ||
complex_type_mark->set_attrs(get_attrs()); | ||
return complex_type_mark; | ||
} | ||
|
||
ov::element::Type get_complex_part_type() const { | ||
return m_complex_part_type; | ||
} | ||
|
||
private: | ||
ov::element::Type m_complex_part_type; | ||
}; | ||
|
||
} // namespace tensorflow | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.