Skip to content

Commit

Permalink
【Hackathon 5th No.96】add paddle unstack op (openvinotoolkit#20080)
Browse files Browse the repository at this point in the history
* add paddle mapping

* develop test sampes

* remove redundant code

* revoke code change

* add type mapping

* add test sample

* format code

---------

Co-authored-by: Your Name <you@example.com>
  • Loading branch information
2 people authored and allnes committed Nov 23, 2023
1 parent c0f3d67 commit be22b7d
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 0 deletions.
36 changes: 36 additions & 0 deletions src/frontends/paddle/src/op/unstack.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"

namespace ov {
namespace frontend {
namespace paddle {
namespace op {
NamedOutputs unstack(const NodeContext& node) {
auto data = node.get_input("X");
auto input_shape = data.get_partial_shape();
PADDLE_OP_CHECK(node, input_shape.rank().is_static(), "rank of input data should be static");
auto dim = node.get_attribute<int32_t>("axis", 0);
if (dim < 0) {
dim = dim + static_cast<int32_t>(input_shape.rank().get_length());
}
auto axis = default_opset::Constant::create(element::i32, {}, {dim});
auto shape = input_shape.get_shape();
auto splits = std::make_shared<default_opset::Split>(data, axis, shape.at(dim));
auto split_outputs = splits->outputs();
NamedOutputs named_outputs;
auto out_names = node.get_output_names();
auto it = std::find(out_names.begin(), out_names.end(), "Y");
PADDLE_OP_CHECK(node, it != out_names.end(), "Expected output not found");
for (const auto& split_output : split_outputs) {
named_outputs[*it].push_back(std::make_shared<default_opset::Squeeze>(split_output, axis));
}
return named_outputs;
}
} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/paddle/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ OP_CONVERTER(top_k_v2);
OP_CONVERTER(transpose2);
OP_CONVERTER(trilinear_interp_v2);
OP_CONVERTER(unsqueeze);
OP_CONVERTER(unstack);
OP_CONVERTER(where);
OP_CONVERTER(while_);
OP_CONVERTER(write_to_array);
Expand Down Expand Up @@ -249,6 +250,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"transpose2", op::transpose2},
{"trilinear_interp_v2", op::trilinear_interp_v2},
{"unsqueeze2", op::unsqueeze},
{"unstack", op::unstack},
{"where", op::where},
{"while", op::while_},
{"write_to_array", op::write_to_array},
Expand Down
5 changes: 5 additions & 0 deletions src/frontends/paddle/tests/op_fuzzy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,11 @@ static const std::vector<std::string> models{
std::string("trilinear_upsample_scales2/trilinear_upsample_scales2.pdmodel"),
std::string("trilinear_upsample_true_0/trilinear_upsample_true_0.pdmodel"),
std::string("unsqueeze"),
std::string("unstack_1"),
std::string("unstack_2"),
std::string("unstack_3"),
std::string("unstack_4"),
std::string("unstack_5"),
std::string("where_1"),
std::string("where_2"),
std::string("where_3"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

#
# unstack paddle model generator
#
import paddle
import numpy as np
from save_model import saveModel
import sys


def unstack(name: str, x, axis):
paddle.enable_static()

with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
x_node = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype)
out = paddle.unstack(x_node, axis) if axis is not None else paddle.unstack(x_node)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
outs = exe.run(feed={"x": x}, fetch_list=[out])
saveModel(name, exe, feedkeys=['x'], fetchlist=out, inputs=[x], outputs=outs, target_dir=sys.argv[1])

return outs


def main():
dtype = np.float32
x = np.random.randn(2, 3, 4).astype(dtype)
unstack(name='unstack_1', x=x, axis=0)

dtype = np.int32
x = np.random.randn(2, 3, 4).astype(dtype)
unstack(name='unstack_2', x=x, axis=1)

dtype = np.int64
x = np.random.randn(3, 4).astype(dtype)
unstack(name='unstack_3', x=x, axis=-1)
unstack(name='unstack_4', x=x, axis=None)

x = np.random.randn(2, 1, 4).astype(dtype)
unstack(name='unstack_5', x=x, axis=0)

if __name__ == "__main__":
main()

0 comments on commit be22b7d

Please sign in to comment.