diff --git a/models/pytorch2onnx/ort_run_frozen.py b/models/pytorch2onnx/ort_run_frozen.py index ed6d8332e..23263b1dd 100644 --- a/models/pytorch2onnx/ort_run_frozen.py +++ b/models/pytorch2onnx/ort_run_frozen.py @@ -93,6 +93,10 @@ def check_shape(shape): sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED if args.optimized_model_filepath != '': sess_options.optimized_model_filepath = args.optimized_model_filepath + sess_options.add_session_config_entry( + "session.optimized_model_external_initializers_file_name", os.path.basename(args.optimized_model_filepath) + ".data" + ) + sess_options.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100") for k, v in args.symbolic_dims.items(): sess_options.add_free_dimension_override_by_name(k, int(v)) diff --git a/src/nnfusion/common/common.hpp b/src/nnfusion/common/common.hpp index 8c8ec2b3d..b3a3e3319 100644 --- a/src/nnfusion/common/common.hpp +++ b/src/nnfusion/common/common.hpp @@ -84,6 +84,7 @@ #include "nnfusion/core/operators/op_define/result.hpp" #include "nnfusion/core/operators/op_define/reverse.hpp" #include "nnfusion/core/operators/op_define/reverse_sequence.hpp" +#include "nnfusion/core/operators/op_define/round.hpp" #include "nnfusion/core/operators/op_define/rsqrt.hpp" #include "nnfusion/core/operators/op_define/select.hpp" #include "nnfusion/core/operators/op_define/select_and_scatter.hpp" diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/gather_1d.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/gather_1d.cpp index 9925c7253..0d090c5b1 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/gather_1d.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/gather_1d.cpp @@ -81,7 +81,7 @@ LanguageUnit_p cuda::Gather1D::emit_function_body() } lu << "int64_t gather_i = __ldg(indices + indices_i);\n"; - lu << "if (gather_i < 0) gather_i += " << gather_dim_size <<";\n"; + lu << "if (gather_i < 0) gather_i += " << gather_dim_size << ";\n"; lu << "if (gather_i >= " << gather_dim_size << ")\n" << " out[i] = 0;\n" << "else\n"; @@ -194,7 +194,7 @@ LanguageUnit_p cuda::Gather1DGrad::emit_function_body() } lu << "int64_t gather_i = __ldg(indices + indices_i);\n"; - lu << "if (gather_i < 0) gather_i += " << gather_dim_size <<";\n"; + lu << "if (gather_i < 0) gather_i += " << gather_dim_size << ";\n"; lu << "if (gather_i < " << gather_dim_size << ")\n"; lu.block_begin(); { diff --git a/src/nnfusion/core/operators/CMakeLists.txt b/src/nnfusion/core/operators/CMakeLists.txt index f445946fc..564c27edb 100644 --- a/src/nnfusion/core/operators/CMakeLists.txt +++ b/src/nnfusion/core/operators/CMakeLists.txt @@ -68,6 +68,7 @@ set(SRC op_define/result.cpp op_define/reverse_sequence.cpp op_define/reverse.cpp + op_define/round.cpp op_define/rsqrt.cpp op_define/select_and_scatter.cpp op_define/select.cpp diff --git a/src/nnfusion/core/operators/generic_op/generic_op.hpp b/src/nnfusion/core/operators/generic_op/generic_op.hpp index 17d1c8ca0..765c574d3 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op.hpp +++ b/src/nnfusion/core/operators/generic_op/generic_op.hpp @@ -313,10 +313,15 @@ namespace nnfusion { config[alias_name + "_dtype"] = "int64"; } + else if (d_type == element::u8) + { + // hack!!! + config[alias_name + "_dtype"] = "int8"; + } else { NNFUSION_CHECK_FAIL() - << "Unhandled type: " << d_type + << "Unhandled type for " << input_name << ": " << d_type << ", antares currently supports int8/16/32/64, float16/32/64"; } auto shape = tensor->get_shape(); diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/Elementwise.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/Elementwise.cpp index f051fccfc..2c7e50301 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/Elementwise.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/Elementwise.cpp @@ -31,6 +31,7 @@ static const std::unordered_map ElementOpMap = { {"Sin", element_op("sin", "")}, {"Sinh", element_op("sinh", "")}, {"Sqrt", element_op("sqrt", "")}, + {"Round", element_op("round", "x0.call(`round`)")}, {"Rsqrt", element_op("rsqrt", "")}, {"Tan", element_op("tan", "")}, {"Tanh", element_op("tanh", "")}, @@ -196,6 +197,7 @@ REGISTER_ELEM_OP(Relu) REGISTER_ELEM_OP(Relu6) REGISTER_ELEM_OP(ReluBackprop) REGISTER_ELEM_OP(Relu6Backprop) +REGISTER_ELEM_OP(Round) REGISTER_ELEM_OP(Sigmoid) REGISTER_ELEM_OP(SigmoidBackprop) REGISTER_ELEM_OP(Equal) diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/Trilu.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/Trilu.cpp index 81d9ba358..dce6f4bff 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/Trilu.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/Trilu.cpp @@ -8,17 +8,18 @@ REGISTER_OP(Trilu) .infershape([](std::shared_ptr curr) -> void { - curr->set_output_type_and_shape(0, curr->get_input_element_type(0), curr->get_input_shape(0)); - }) + curr->set_output_type_and_shape( + 0, curr->get_input_element_type(0), curr->get_input_shape(0)); + }) .translate_v2([](std::shared_ptr curr) -> std::string { auto input_shape_0 = curr->get_input_shape(0); assert(input_shape_0.size() >= 2); std::string k_str = ""; - if(curr->get_input_size() == 2) - k_str = "+ input1[0]"; + if (curr->get_input_size() == 2) + k_str = "+ input1[0]"; auto op = static_pointer_cast(curr->get_op_ptr()); auto& cfg = op->localOpConfig.getRoot(); - bool upper = cfg["upper"].is_null()?true:int64_t(cfg["upper"])!=0; + bool upper = cfg["upper"].is_null() ? true : int64_t(cfg["upper"]) != 0; auto input_layout = op::create_layout_from_dims(input_shape_0); auto dim_a = input_layout[input_layout.size() - 2]; auto dim_b = input_layout[input_layout.size() - 1]; @@ -28,13 +29,11 @@ REGISTER_OP(Trilu) element::Type::nnfusion_element_type_to_dtype_string(curr->get_element_type(), dtype); NNFUSION_CHECK(ret); - std::string condition = upper?dim_b+">="+dim_a+k_str:dim_a+k_str+">="+dim_b; + std::string condition = upper ? dim_b + ">=" + dim_a + k_str : dim_a + k_str + ">=" + dim_b; auto expression = op::create_code_from_template( - "@output0@[@input_layout@] = @input0@[@input_layout@].when(@condition@, const(0).cast(`@dtype@`));", { - {"input_layout", join(input_layout)}, - {"condition", condition}, - {"dtype", dtype} - }); + "@output0@[@input_layout@] = @input0@[@input_layout@].when(@condition@, " + "const(0).cast(`@dtype@`));", + {{"input_layout", join(input_layout)}, {"condition", condition}, {"dtype", dtype}}); return expression; }); \ No newline at end of file diff --git a/src/nnfusion/core/operators/op_define/fused.hpp b/src/nnfusion/core/operators/op_define/fused.hpp index 433443a70..24a6b9ae4 100644 --- a/src/nnfusion/core/operators/op_define/fused.hpp +++ b/src/nnfusion/core/operators/op_define/fused.hpp @@ -32,7 +32,7 @@ namespace nnfusion std::shared_ptr fused_node); std::string get_fused_ir2() { return fused_op_ir2; }; std::string get_plan_rule(); - + bool get_is_memcpy() { return is_memcpy; } protected: void assemble_inputs_and_outputs(); @@ -41,4 +41,4 @@ namespace nnfusion bool is_memcpy; }; } -} +} \ No newline at end of file diff --git a/src/nnfusion/core/operators/op_define/round.cpp b/src/nnfusion/core/operators/op_define/round.cpp new file mode 100644 index 000000000..c5d78c352 --- /dev/null +++ b/src/nnfusion/core/operators/op_define/round.cpp @@ -0,0 +1,26 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +// Microsoft (c) 2020, NNFusion Team + +#include "round.hpp" + +using namespace nnfusion::op; + +Round::Round() + : ElementwiseArithmetic("Round") +{ +} \ No newline at end of file diff --git a/src/nnfusion/core/operators/op_define/round.hpp b/src/nnfusion/core/operators/op_define/round.hpp new file mode 100644 index 000000000..6cb6e4361 --- /dev/null +++ b/src/nnfusion/core/operators/op_define/round.hpp @@ -0,0 +1,35 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +// Microsoft (c) 2020, NNFusion Team + +#pragma once + +#include "nnfusion/core/operators/util/elementwise_arithmetic.hpp" + +namespace nnfusion +{ + namespace op + { + /// \brief Elementwise cosine operation. + class Round : public ElementwiseArithmetic + { + public: + /// \brief Constructs a round operation. + Round(); + }; + } +} diff --git a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp index 8b5539f44..6d1e94764 100644 --- a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp +++ b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp @@ -879,8 +879,8 @@ nnfusion::LanguageUnit_p CudaCodegenPass::func_call_codegen(nnfusion::ir::Instru lu << "Debug(\"" << node_name << ", " << out_name << member_name << "_f32\", " << "fp32tensors, \"" << join(kernel->m_context->input_names) << "\", " << kernel->m_context->outputs[i]->size(false) << ");\n"; - lu << "CUDA_SAFE_CALL(cudaMemset((void*)fp32tensors, 0, " - << max_tensor_size <<"));\n"; + lu << "CUDA_SAFE_CALL(cudaMemset((void*)fp32tensors, 0, " << max_tensor_size + << "));\n"; } else if (element::get_backend_cstring( kernel->m_context->outputs[i]->get_element_type()) == "float") diff --git a/src/nnfusion/engine/pass/graph/register_fusion_pass.cpp b/src/nnfusion/engine/pass/graph/register_fusion_pass.cpp index 6781d5702..b4f879e8d 100644 --- a/src/nnfusion/engine/pass/graph/register_fusion_pass.cpp +++ b/src/nnfusion/engine/pass/graph/register_fusion_pass.cpp @@ -19,6 +19,7 @@ using namespace nnfusion::kernels; DEFINE_string(ftune_output_file, "", "the output json file path"); DEFINE_string(ftune_input_file, "", "the input json file path"); +DEFINE_bool(fnofuse, false, "Disable element-wise fusion"); DEFINE_string(ffusion_skiplist, "", "List of op types that skips in fusion"); DECLARE_string(fdefault_device); @@ -84,6 +85,14 @@ namespace }); return nodes; } + + string ir_add_tag(const string& ir, const string& tag) + { + if (ir.find("## @:") != string::npos) + return ir + "|" + tag; + else + return ir + "## @: " + tag; + } } class RegisterFusionOptimizer @@ -138,11 +147,13 @@ class RegisterFusionOptimizer fuse_from_node(tnode, true); } } + inline_lightweighted_ops(); auto groups = extract_fusion_group(); - for (auto group : groups) - { - insert_fuse_group(group); - } + if (!FLAGS_fnofuse) + for (auto group : groups) + { + insert_fuse_group(group); + } auto nodes = nlohmann::json().array(); for (auto& node : find_topo_sort_priority(m_graph)) { @@ -151,10 +162,12 @@ class RegisterFusionOptimizer auto str = nnfusion::op::get_translation_v2(node); if (skip_ops.count(node->get_op_type())) { - if (str.find("## @:") != string::npos) - str += "|skip"; - else - str += "## @: skip"; + str = ir_add_tag(str, "skip"); + } + if (node->get_op_type() == "Fused" && + std::dynamic_pointer_cast(node->get_op_ptr())->get_is_memcpy()) + { + str = ir_add_tag(str, "memcpy"); } auto edge = nlohmann::json().array(); for (auto& e : node->get_in_edges()) @@ -173,7 +186,7 @@ class RegisterFusionOptimizer } private: - vector> extract_fusion_group() + vector> extract_fusion_group() const { unordered_map> groups; vector> result; @@ -195,6 +208,85 @@ class RegisterFusionOptimizer return result; } + bool is_lightweighted_op(const shared_ptr& node) + { + auto type = node->get_op_type(); + if (type == "Slice" || type == "Broadcast") + return true; + if (type == "Reshape") + { + auto op = std::dynamic_pointer_cast(node->get_op_ptr()); + auto order = op->get_input_order(); + if (order.empty()) + return true; + + bool is_lower_dim_kept = order.back() == order.size() - 1; + return is_lower_dim_kept; + } + return false; + } + + void inline_lightweighted_ops() + { + // Iterate over all independent groups + // inline first group into second if: + // 1. first group has one output + // 2. first group are all light weighted ops + // 3. all ops not in skip lists + unordered_map> map; + vector> groups; + for (auto& tnode : node_list_) + { + if (tnode->node_->get_op_ptr()->is_tensor_op()) + continue; + if (tnode->group_id_ < 0) + { + auto f = make_shared(); + f->nodes.insert(tnode->node_); + groups.push_back(f); + } + else + { + if (!map.count(tnode->group_id_)) + { + map[tnode->group_id_] = make_shared(); + } + map[tnode->group_id_]->nodes.insert(tnode->node_); + } + } + for (auto& kv : map) + groups.push_back(kv.second); + + for (auto& group : groups) + { + bool group_is_lightweighted = true; + unordered_set> group_outputs; + for (auto& node : group->nodes) + { + group_is_lightweighted &= is_lightweighted_op(node); + for (auto& edge : node->get_out_edges()) + { + if (!group->nodes.count(edge->get_dst())) + group_outputs.insert(edge->get_dst()); + } + } + if (group_outputs.size() == 0) + continue; + auto& output_node = *group_outputs.begin(); + auto& tag_output_node = node_map_[output_node]; + bool op_skip = skip_ops.count(output_node->get_op_type()); + for (auto& node : group->nodes) + op_skip |= skip_ops.count(node->get_op_type()); + if (group_is_lightweighted && !op_skip && group_outputs.size() == 1) + { + if (tag_output_node->group_id_ < 0) + tag_output_node->group_id_ = cur_group_++; + for (auto& node : group->nodes) + node_map_[node]->group_id_ = tag_output_node->group_id_; + } + } + } + void insert_fuse_group(shared_ptr group) { // get a meaningful name @@ -453,4 +545,4 @@ bool RegisterFusionPass::run_on_graph(std::shared_ptr& graph) applier.apply(FLAGS_ftune_input_file); NNFUSION_LOG(INFO) << "RegisterFusionPass Done"; return true; -} +} \ No newline at end of file diff --git a/src/nnfusion/frontend/onnx_import/onnx.cpp b/src/nnfusion/frontend/onnx_import/onnx.cpp index 6bfdffa7b..d3a96b97c 100644 --- a/src/nnfusion/frontend/onnx_import/onnx.cpp +++ b/src/nnfusion/frontend/onnx_import/onnx.cpp @@ -68,10 +68,9 @@ namespace nnfusion "(models/pytorch2onnx/increase_precision.py)"; string script_path = nnfusion::codegen::get_file_from_templates("onnx/increase_precision.py"); - string cmd = "python3 " + script_path + - " --file " + - m_path + " --mp_file " + mp_filename; - + string cmd = + "python3 " + script_path + " --file " + m_path + " --mp_file " + mp_filename; + int sys_ret = system(cmd.c_str()); // NNFUSION_LOG(INFO) << "mix precision model path: " << mp_filename; opt_fin = std::ifstream(mp_filename.c_str()); @@ -86,7 +85,7 @@ namespace nnfusion "check error messages reported by the tool, fallback"; } } - + string optimized_filename = string(nnfusion::tmpnam(nullptr)); if (FLAGS_fort_folding) { @@ -112,6 +111,7 @@ namespace nnfusion dim_params_str += "}\'"; cmd += dim_params_str; } + NNFUSION_LOG(INFO) << "Executing: " << cmd; int sys_ret = system(cmd.c_str()); opt_fin = std::ifstream(optimized_filename.c_str()); if (sys_ret == 0 && opt_fin.is_open()) @@ -128,11 +128,11 @@ namespace nnfusion std::ifstream ifs{m_path, std::ios::in | std::ios::binary}; NNFUSION_CHECK(ifs.is_open()) << "failure opening file:" + path; string model_dir = ""; - string weight_path = FLAGS_fincrease_precision ? m_path : path; - auto pos = weight_path.rfind("/"); + // string weight_path = FLAGS_fincrease_precision ? m_path : path; + auto pos = m_path.rfind("/"); if (pos != std::string::npos) { - model_dir = weight_path.substr(0, pos); + model_dir = m_path.substr(0, pos); } auto graph = load_onnx_model(ifs, model_dir, dim_params); @@ -141,7 +141,10 @@ namespace nnfusion { remove(optimized_filename.c_str()); } - + if (std::ifstream((optimized_filename + ".data").c_str()).good()) + { + remove((optimized_filename + ".data").c_str()); + } return graph; } } // namespace frontend diff --git a/src/nnfusion/frontend/onnx_import/op/const_of_shape.hpp b/src/nnfusion/frontend/onnx_import/op/const_of_shape.hpp index a7128c1d4..af20bef65 100644 --- a/src/nnfusion/frontend/onnx_import/op/const_of_shape.hpp +++ b/src/nnfusion/frontend/onnx_import/op/const_of_shape.hpp @@ -52,19 +52,27 @@ namespace nnfusion NNFUSION_CHECK(nnfusion::shape_size(value.get_shape()) == 1); const_op = make_constant_op( value.get_ng_type(), - Shape(std::begin(output_shape), std::end(output_shape)), + Shape{1}, value); + // const_op = make_constant_op( + // value.get_ng_type(), + // Shape(std::begin(output_shape), std::end(output_shape)), + // value); } else { auto vec = std::vector{0}; const_op = std::make_shared(element::f32, Shape{1}, vec); } - - const_op->set_name(node_proto.output(0)); - const_op->set_global_consistent_name(node_proto.output(0)); + // const_op->set_name(node_proto.output(0)); + // const_op->set_global_consistent_name(node_proto.output(0)); auto const_gnode = m_graph->add_node_and_edge(const_op, graph::GNodeVector({})); + const_gnode = make_broadcast_node(const_gnode, Shape(std::begin(output_shape), std::end(output_shape)), m_graph); + const_gnode->get_op_ptr()->set_name(node_proto.output(0)); + const_gnode->get_op_ptr()->set_global_consistent_name(node_proto.output(0)); + + return {{node_proto.output(0), const_gnode}}; } diff --git a/src/nnfusion/frontend/onnx_import/op/range.hpp b/src/nnfusion/frontend/onnx_import/op/range.hpp index 30efe0215..4ffe3b043 100644 --- a/src/nnfusion/frontend/onnx_import/op/range.hpp +++ b/src/nnfusion/frontend/onnx_import/op/range.hpp @@ -89,6 +89,7 @@ namespace nnfusion else NNFUSION_CHECK_FAIL() << "non-supported data type for Range op: " << element_type.c_type_string(); + return {}; } } // namespace set_11 diff --git a/src/nnfusion/frontend/onnx_import/op/unaryop.hpp b/src/nnfusion/frontend/onnx_import/op/unaryop.hpp index cdaeb9136..8f002f2d6 100644 --- a/src/nnfusion/frontend/onnx_import/op/unaryop.hpp +++ b/src/nnfusion/frontend/onnx_import/op/unaryop.hpp @@ -65,6 +65,12 @@ namespace nnfusion { using set_1::TranslateUnaryOp; } + + namespace set_11 + { + using set_1::TranslateUnaryOp; + } + namespace set_13 { using set_1::TranslateUnaryOp; diff --git a/src/nnfusion/frontend/onnx_import/ops_bridge.cpp b/src/nnfusion/frontend/onnx_import/ops_bridge.cpp index 5d13e87e0..17d98c3d4 100644 --- a/src/nnfusion/frontend/onnx_import/ops_bridge.cpp +++ b/src/nnfusion/frontend/onnx_import/ops_bridge.cpp @@ -433,6 +433,7 @@ namespace nnfusion REGISTER_OPERATOR("Relu", 1, TranslateUnaryOp); REGISTER_OPERATOR("Reshape", 1, TranslateReshapeOp); REGISTER_OPERATOR("ReshapeGrad", 1, TranslateReshapeGradOp); + REGISTER_OPERATOR("Round", 11, TranslateUnaryOp); //REGISTER_OPERATOR("Selu", 1, selu); REGISTER_OPERATOR("Shape", 1, TranslateShapeOp); REGISTER_OPERATOR("Shape", 15, TranslateShapeOp);