Skip to content

Commit

Permalink
backend: compiler: fix query dynamic outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
yifeizh2 authored and TaoLv committed Jan 9, 2023
1 parent 147d9bc commit 8dbca04
Showing 1 changed file with 8 additions and 19 deletions.
27 changes: 8 additions & 19 deletions src/backend/graph_compiler/compiler_partition_impl.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
* Copyright 2021-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -411,28 +411,17 @@ status_t compiler_compiled_partition_impl_t::query_dynamic_outputs(
trans_func(in_lts[i], ins[idx]);
}
for (size_t i = 0; i < out_lts.size(); ++i) {
size_t lt_id = out_lts[i]->id;
auto found = std::find_if(outputs_.begin(), outputs_.end(),
[&lt_id](const impl::logical_tensor_t &e) {
return e.id == lt_id;
});
if (found == outputs_.end()) continue;
size_t idx = found - outputs_.begin();
trans_func(out_lts[i], outs[idx]);
trans_func(out_lts[i], outs[i]);
}
sc::dynamic_infer_shape_by_graph(
sc_graph_, ins.data(), outs.data(), ins.size(), outs.size());
for (size_t i = 0; i < out_lts.size(); ++i) {
size_t lt_id = out_lts[i]->id;
auto found = std::find_if(outputs_.begin(), outputs_.end(),
[&lt_id](const impl::logical_tensor_t &e) {
return e.id == lt_id;
});
if (found == outputs_.end()) continue;
size_t idx = found - outputs_.begin();
out_lts[i]->ndims = outs[idx]->ndims_;
for (auto d = 0; d < outs[idx]->ndims_; ++d) {
out_lts[i]->dims[d] = outs[idx]->dims_[d];
out_lts[i]->id = outputs_[i].id;
out_lts[i]->data_type = outputs_[i].data_type;
out_lts[i]->property = outputs_[i].property;
out_lts[i]->ndims = outs[i]->ndims_;
for (auto d = 0; d < outs[i]->ndims_; ++d) {
out_lts[i]->dims[d] = outs[i]->dims_[d];
}
// set dense stride here
out_lts[i]->layout_type = impl::layout_type::strided;
Expand Down

0 comments on commit 8dbca04

Please sign in to comment.