support folding onnx>2GB with onnxruntime (#528)
* add round

* avoid creating large constants when converting ONNX; some other fixes

* port the nofuse flag; fix the weight path of the optimized ONNX model

* fix
mzmssg authored Oct 30, 2023
1 parent 35e1a76 commit 56f3ab5
Showing 17 changed files with 225 additions and 41 deletions.
4 changes: 4 additions & 0 deletions models/pytorch2onnx/ort_run_frozen.py
@@ -93,6 +93,10 @@ def check_shape(shape):
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
if args.optimized_model_filepath != '':
    sess_options.optimized_model_filepath = args.optimized_model_filepath
    sess_options.add_session_config_entry(
        "session.optimized_model_external_initializers_file_name", os.path.basename(args.optimized_model_filepath) + ".data"
    )
    sess_options.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100")

for k, v in args.symbolic_dims.items():
    sess_options.add_free_dimension_override_by_name(k, int(v))
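For context, a minimal self-contained sketch (not part of the diff; model paths are hypothetical) of how these two session config entries let onnxruntime fold and save a model whose weights exceed protobuf's 2 GB limit: initializers above the size threshold are written to a side .data file instead of being embedded in the optimized .onnx.

import os
import onnxruntime as ort

model_path = "large_model.onnx"          # hypothetical input with > 2 GB of weights
optimized_path = "large_model.opt.onnx"  # hypothetical output path

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
sess_options.optimized_model_filepath = optimized_path
# Initializers larger than the threshold go into this external file,
# referenced from the optimized .onnx by file name.
sess_options.add_session_config_entry(
    "session.optimized_model_external_initializers_file_name",
    os.path.basename(optimized_path) + ".data")
sess_options.add_session_config_entry(
    "session.optimized_model_external_initializers_min_size_in_bytes", "100")

# Creating the session runs the graph optimizations (constant folding included)
# and writes the optimized model plus its .data file next to optimized_path.
sess = ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])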
1 change: 1 addition & 0 deletions src/nnfusion/common/common.hpp
@@ -84,6 +84,7 @@
#include "nnfusion/core/operators/op_define/result.hpp"
#include "nnfusion/core/operators/op_define/reverse.hpp"
#include "nnfusion/core/operators/op_define/reverse_sequence.hpp"
#include "nnfusion/core/operators/op_define/round.hpp"
#include "nnfusion/core/operators/op_define/rsqrt.hpp"
#include "nnfusion/core/operators/op_define/select.hpp"
#include "nnfusion/core/operators/op_define/select_and_scatter.hpp"
4 changes: 2 additions & 2 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/gather_1d.cpp
@@ -81,7 +81,7 @@ LanguageUnit_p cuda::Gather1D::emit_function_body()
}

lu << "int64_t gather_i = __ldg(indices + indices_i);\n";
lu << "if (gather_i < 0) gather_i += " << gather_dim_size <<";\n";
lu << "if (gather_i < 0) gather_i += " << gather_dim_size << ";\n";
lu << "if (gather_i >= " << gather_dim_size << ")\n"
<< " out[i] = 0;\n"
<< "else\n";
@@ -194,7 +194,7 @@ LanguageUnit_p cuda::Gather1DGrad::emit_function_body()
}

lu << "int64_t gather_i = __ldg(indices + indices_i);\n";
lu << "if (gather_i < 0) gather_i += " << gather_dim_size <<";\n";
lu << "if (gather_i < 0) gather_i += " << gather_dim_size << ";\n";
lu << "if (gather_i < " << gather_dim_size << ")\n";
lu.block_begin();
{
1 change: 1 addition & 0 deletions src/nnfusion/core/operators/CMakeLists.txt
@@ -68,6 +68,7 @@ set(SRC
op_define/result.cpp
op_define/reverse_sequence.cpp
op_define/reverse.cpp
op_define/round.cpp
op_define/rsqrt.cpp
op_define/select_and_scatter.cpp
op_define/select.cpp
7 changes: 6 additions & 1 deletion src/nnfusion/core/operators/generic_op/generic_op.hpp
@@ -313,10 +313,15 @@ namespace nnfusion
{
config[alias_name + "_dtype"] = "int64";
}
else if (d_type == element::u8)
{
// hack!!!
config[alias_name + "_dtype"] = "int8";
}
else
{
NNFUSION_CHECK_FAIL()
<< "Unhandled type: " << d_type
<< "Unhandled type for " << input_name << ": " << d_type
<< ", antares currently supports int8/16/32/64, float16/32/64";
}
auto shape = tensor->get_shape();
@@ -31,6 +31,7 @@ static const std::unordered_map<std::string, element_op> ElementOpMap = {
{"Sin", element_op("sin", "")},
{"Sinh", element_op("sinh", "")},
{"Sqrt", element_op("sqrt", "")},
{"Round", element_op("round", "x0.call(`round`)")},
{"Rsqrt", element_op("rsqrt", "")},
{"Tan", element_op("tan", "")},
{"Tanh", element_op("tanh", "")},
@@ -196,6 +197,7 @@ REGISTER_ELEM_OP(Relu)
REGISTER_ELEM_OP(Relu6)
REGISTER_ELEM_OP(ReluBackprop)
REGISTER_ELEM_OP(Relu6Backprop)
REGISTER_ELEM_OP(Round)
REGISTER_ELEM_OP(Sigmoid)
REGISTER_ELEM_OP(SigmoidBackprop)
REGISTER_ELEM_OP(Equal)
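As a quick reference for the new mapping, the sketch below (not from the commit) shows the rounding the ONNX Round op specifies: halves go to the nearest even integer, as numpy's rint does. Whether the Antares `round` builtin uses the same tie-breaking is an assumption this diff does not confirm.

import numpy as np

x = np.array([0.5, 1.5, 2.5, -2.5, 2.3], dtype=np.float32)
print(np.rint(x))  # [ 0.  2.  2. -2.  2.] -- round-half-to-even, per the ONNX Round spec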
21 changes: 10 additions & 11 deletions src/nnfusion/core/operators/generic_op/generic_op_define/Trilu.cpp
@@ -8,17 +8,18 @@

REGISTER_OP(Trilu)
.infershape([](std::shared_ptr<graph::GNode> curr) -> void {
curr->set_output_type_and_shape(0, curr->get_input_element_type(0), curr->get_input_shape(0));
})
curr->set_output_type_and_shape(
0, curr->get_input_element_type(0), curr->get_input_shape(0));
})
.translate_v2([](std::shared_ptr<graph::GNode> curr) -> std::string {
auto input_shape_0 = curr->get_input_shape(0);
assert(input_shape_0.size() >= 2);
std::string k_str = "";
if(curr->get_input_size() == 2)
k_str = "+ input1[0]";
if (curr->get_input_size() == 2)
k_str = "+ input1[0]";
auto op = static_pointer_cast<nnfusion::op::GenericOp>(curr->get_op_ptr());
auto& cfg = op->localOpConfig.getRoot();
bool upper = cfg["upper"].is_null()?true:int64_t(cfg["upper"])!=0;
bool upper = cfg["upper"].is_null() ? true : int64_t(cfg["upper"]) != 0;
auto input_layout = op::create_layout_from_dims(input_shape_0);
auto dim_a = input_layout[input_layout.size() - 2];
auto dim_b = input_layout[input_layout.size() - 1];
@@ -28,13 +29,11 @@ REGISTER_OP(Trilu)
element::Type::nnfusion_element_type_to_dtype_string(curr->get_element_type(), dtype);
NNFUSION_CHECK(ret);

std::string condition = upper?dim_b+">="+dim_a+k_str:dim_a+k_str+">="+dim_b;
std::string condition = upper ? dim_b + ">=" + dim_a + k_str : dim_a + k_str + ">=" + dim_b;

auto expression = op::create_code_from_template(
"@output0@[@input_layout@] = @input0@[@input_layout@].when(@condition@, const(0).cast(`@dtype@`));", {
{"input_layout", join(input_layout)},
{"condition", condition},
{"dtype", dtype}
});
"@output0@[@input_layout@] = @input0@[@input_layout@].when(@condition@, "
"const(0).cast(`@dtype@`));",
{{"input_layout", join(input_layout)}, {"condition", condition}, {"dtype", dtype}});
return expression;
});
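The reformatted translation above encodes output[..., a, b] = input[..., a, b] when b >= a + k (upper) or a + k >= b (lower), and 0 otherwise. A small numpy sketch of that reference behavior (illustrative only, not from the commit):

import numpy as np

def trilu_reference(x, k=0, upper=True):
    # Keep the upper (or lower) triangle of the last two dims, offset by k; zero the rest.
    rows = np.arange(x.shape[-2]).reshape(-1, 1)
    cols = np.arange(x.shape[-1]).reshape(1, -1)
    mask = (cols >= rows + k) if upper else (rows + k >= cols)
    return np.where(mask, x, np.zeros_like(x))

x = np.arange(16, dtype=np.float32).reshape(4, 4)
print(trilu_reference(x, k=1, upper=True))  # strictly-upper triangle when k=1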
4 changes: 2 additions & 2 deletions src/nnfusion/core/operators/op_define/fused.hpp
@@ -32,7 +32,7 @@ namespace nnfusion
std::shared_ptr<graph::GNode> fused_node);
std::string get_fused_ir2() { return fused_op_ir2; };
std::string get_plan_rule();

bool get_is_memcpy() { return is_memcpy; }
protected:
void assemble_inputs_and_outputs();

@@ -41,4 +41,4 @@ namespace nnfusion
bool is_memcpy;
};
}
}
}
26 changes: 26 additions & 0 deletions src/nnfusion/core/operators/op_define/round.cpp
@@ -0,0 +1,26 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

// Microsoft (c) 2020, NNFusion Team

#include "round.hpp"

using namespace nnfusion::op;

Round::Round()
: ElementwiseArithmetic("Round")
{
}
35 changes: 35 additions & 0 deletions src/nnfusion/core/operators/op_define/round.hpp
@@ -0,0 +1,35 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

// Microsoft (c) 2020, NNFusion Team

#pragma once

#include "nnfusion/core/operators/util/elementwise_arithmetic.hpp"

namespace nnfusion
{
namespace op
{
/// \brief Elementwise round operation.
class Round : public ElementwiseArithmetic
{
public:
/// \brief Constructs a round operation.
Round();
};
}
}
4 changes: 2 additions & 2 deletions src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp
@@ -879,8 +879,8 @@ nnfusion::LanguageUnit_p CudaCodegenPass::func_call_codegen(nnfusion::ir::Instru
lu << "Debug(\"" << node_name << ", " << out_name << member_name << "_f32\", "
<< "fp32tensors, \"" << join(kernel->m_context->input_names) << "\", "
<< kernel->m_context->outputs[i]->size(false) << ");\n";
lu << "CUDA_SAFE_CALL(cudaMemset((void*)fp32tensors, 0, "
<< max_tensor_size <<"));\n";
lu << "CUDA_SAFE_CALL(cudaMemset((void*)fp32tensors, 0, " << max_tensor_size
<< "));\n";
}
else if (element::get_backend_cstring(
kernel->m_context->outputs[i]->get_element_type()) == "float")
112 changes: 102 additions & 10 deletions src/nnfusion/engine/pass/graph/register_fusion_pass.cpp
@@ -19,6 +19,7 @@ using namespace nnfusion::kernels;

DEFINE_string(ftune_output_file, "", "the output json file path");
DEFINE_string(ftune_input_file, "", "the input json file path");
DEFINE_bool(fnofuse, false, "Disable element-wise fusion");
DEFINE_string(ffusion_skiplist, "", "List of op types that skips in fusion");
DECLARE_string(fdefault_device);

@@ -84,6 +85,14 @@ namespace
});
return nodes;
}

string ir_add_tag(const string& ir, const string& tag)
{
if (ir.find("## @:") != string::npos)
return ir + "|" + tag;
else
return ir + "## @: " + tag;
}
}

class RegisterFusionOptimizer
@@ -138,11 +147,13 @@ class RegisterFusionOptimizer
fuse_from_node(tnode, true);
}
}
inline_lightweighted_ops();
auto groups = extract_fusion_group();
for (auto group : groups)
{
insert_fuse_group(group);
}
if (!FLAGS_fnofuse)
for (auto group : groups)
{
insert_fuse_group(group);
}
auto nodes = nlohmann::json().array();
for (auto& node : find_topo_sort_priority(m_graph))
{
@@ -151,10 +162,12 @@ class RegisterFusionOptimizer
auto str = nnfusion::op::get_translation_v2(node);
if (skip_ops.count(node->get_op_type()))
{
if (str.find("## @:") != string::npos)
str += "|skip";
else
str += "## @: skip";
str = ir_add_tag(str, "skip");
}
if (node->get_op_type() == "Fused" &&
std::dynamic_pointer_cast<op::Fused>(node->get_op_ptr())->get_is_memcpy())
{
str = ir_add_tag(str, "memcpy");
}
auto edge = nlohmann::json().array();
for (auto& e : node->get_in_edges())
@@ -173,7 +186,7 @@
}

private:
vector<shared_ptr<FuseGroup>> extract_fusion_group()
vector<shared_ptr<FuseGroup>> extract_fusion_group() const
{
unordered_map<int, shared_ptr<FuseGroup>> groups;
vector<shared_ptr<FuseGroup>> result;
@@ -195,6 +208,85 @@ class RegisterFusionOptimizer
return result;
}

bool is_lightweighted_op(const shared_ptr<GNode>& node)
{
auto type = node->get_op_type();
if (type == "Slice" || type == "Broadcast")
return true;
if (type == "Reshape")
{
auto op = std::dynamic_pointer_cast<op::Reshape>(node->get_op_ptr());
auto order = op->get_input_order();
if (order.empty())
return true;

bool is_lower_dim_kept = order.back() == order.size() - 1;
return is_lower_dim_kept;
}
return false;
}

void inline_lightweighted_ops()
{
// Iterate over all independent groups
// inline first group into second if:
// 1. first group has one output
// 2. first group are all light weighted ops
// 3. all ops not in skip lists
unordered_map<int, shared_ptr<FuseGroup>> map;
vector<shared_ptr<FuseGroup>> groups;
for (auto& tnode : node_list_)
{
if (tnode->node_->get_op_ptr()->is_tensor_op())
continue;
if (tnode->group_id_ < 0)
{
auto f = make_shared<FuseGroup>();
f->nodes.insert(tnode->node_);
groups.push_back(f);
}
else
{
if (!map.count(tnode->group_id_))
{
map[tnode->group_id_] = make_shared<FuseGroup>();
}
map[tnode->group_id_]->nodes.insert(tnode->node_);
}
}
for (auto& kv : map)
groups.push_back(kv.second);

for (auto& group : groups)
{
bool group_is_lightweighted = true;
unordered_set<shared_ptr<GNode>> group_outputs;
for (auto& node : group->nodes)
{
group_is_lightweighted &= is_lightweighted_op(node);
for (auto& edge : node->get_out_edges())
{
if (!group->nodes.count(edge->get_dst()))
group_outputs.insert(edge->get_dst());
}
}
if (group_outputs.size() == 0)
continue;
auto& output_node = *group_outputs.begin();
auto& tag_output_node = node_map_[output_node];
bool op_skip = skip_ops.count(output_node->get_op_type());
for (auto& node : group->nodes)
op_skip |= skip_ops.count(node->get_op_type());
if (group_is_lightweighted && !op_skip && group_outputs.size() == 1)
{
if (tag_output_node->group_id_ < 0)
tag_output_node->group_id_ = cur_group_++;
for (auto& node : group->nodes)
node_map_[node]->group_id_ = tag_output_node->group_id_;
}
}
}

void insert_fuse_group(shared_ptr<FuseGroup> group)
{
// get a meaningful name
@@ -453,4 +545,4 @@ bool RegisterFusionPass::run_on_graph(std::shared_ptr<Graph>& graph)
applier.apply(FLAGS_ftune_input_file);
NNFUSION_LOG(INFO) << "RegisterFusionPass Done";
return true;
}
}
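For reference, a tiny Python mirror (illustrative, not part of the commit) of how ir_add_tag accumulates tags on an exported IR string: the first tag starts a "## @: " suffix and later tags are joined with "|", so a skipped memcpy node ends up tagged "## @: skip|memcpy".

def ir_add_tag(ir: str, tag: str) -> str:
    # Mirrors the C++ helper above: start a "## @: " tag list or extend it with "|".
    return ir + "|" + tag if "## @:" in ir else ir + "## @: " + tag

ir = "output0[N] = input0[N] + input1[N]; "  # illustrative IR line
ir = ir_add_tag(ir, "skip")
ir = ir_add_tag(ir, "memcpy")
print(ir)  # ends with "## @: skip|memcpy"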