Skip to content

Commit

Permalink
fix onnx serializer for lite (#757)
Browse files Browse the repository at this point in the history
Co-authored-by: bzhang <bzhang@openailab.com>
  • Loading branch information
bzhang5 and bzhang authored Jun 22, 2021
1 parent be9ffe3 commit 2b9e2d2
Showing 1 changed file with 55 additions and 25 deletions.
80 changes: 55 additions & 25 deletions tools/convert_tool/onnx/onnx2tengine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char*
return onnx::TensorProto();
}


/*
* ASSIST FUNCTIONS FOR ONNX SERIALIZER END
*/
Expand Down Expand Up @@ -122,18 +121,6 @@ int onnx_serializer::load_model_file(std::string model_file, onnx::ModelProto &m
return 0;
}

/* Look up the tensor-typed attribute named `key` on an ONNX node.
 * Returns the attribute's tensor payload, or a default-constructed
 * (empty) TensorProto when no attribute with that name exists —
 * matching the behaviour of the other copy of this helper in the file. */
static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key)
{
    for (int i = 0; i < node.attribute_size(); i++)
    {
        const onnx::AttributeProto& attr = node.attribute(i);
        if (attr.name() == key)
        {
            return attr.t();
        }
    }
    // BUG FIX: the original fell off the end of a value-returning function
    // when the key was absent, which is undefined behaviour in C++.
    return onnx::TensorProto();
}

int onnx_serializer::load_constant_tensor(ir_graph_t* graph, const onnx::GraphProto& onnx_graph)
{
std::map<std::string, onnx::TensorProto> node_tensor;
Expand Down Expand Up @@ -185,11 +172,9 @@ int onnx_serializer::load_constant_tensor(ir_graph_t* graph, const onnx::GraphPr
set_ir_tensor_shape(ir_tensor, dims, dim_num);
ir_tensor->tensor_type = TENSOR_TYPE_CONST;
// set tensor data

int tensor_size = ir_tensor->elem_size * ir_tensor->elem_num;

if ( 7 == onnx_tensor.data_type())
{
int tensor_size = ir_tensor->elem_num * sizeof(int64_t);
ir_tensor->data = sys_malloc(tensor_size);
int64_t* mem_buf = (int64_t*)ir_tensor->data;
if(onnx_tensor.has_raw_data())
Expand All @@ -211,6 +196,7 @@ int onnx_serializer::load_constant_tensor(ir_graph_t* graph, const onnx::GraphPr
}
else
{
int tensor_size = ir_tensor->elem_num * sizeof(uint8_t);
ir_tensor->data = sys_malloc(tensor_size);
uint8_t* mem_buf = (uint8_t*)ir_tensor->data;
if(onnx_tensor.has_raw_data())
Expand Down Expand Up @@ -275,32 +261,52 @@ int onnx_serializer::load_initializer_tensor(ir_graph_t* graph, const onnx::Grap

if (onnx_tensor.has_raw_data())
{
int tensor_size = ir_tensor->elem_size * ir_tensor->elem_num;
if (onnx_tensor.data_type() == 1) //fp32
{
int tensor_size = ir_tensor->elem_num * sizeof(float);
ir_tensor->data = sys_malloc(tensor_size);
uint8_t* mem_buf = (uint8_t*)ir_tensor->data;
uint8_t* raw_data = (uint8_t*)onnx_tensor.raw_data().c_str();
float* mem_buf = (float*)ir_tensor->data;
float* raw_data = (float*)onnx_tensor.raw_data().c_str();
for (int j = 0; j < ir_tensor->elem_num; j++)
{
mem_buf[j] = raw_data[j];
}
}
else // int64
{
tensor_data_copy(ir_tensor, onnx_tensor, ir_tensor->elem_num, tensor_size, 3);
int tensor_size = ir_tensor->elem_num * sizeof(int64_t);
ir_tensor->data = sys_malloc(tensor_size);
int64_t* mem_buf = (int64_t*)ir_tensor->data;
int64_t* raw_data = (int64_t*)onnx_tensor.raw_data().data();
for (int j = 0; j < ir_tensor->elem_num; j++)
{
mem_buf[j] = raw_data[j];
}
}
}
else
{
int tensor_size = ir_tensor->elem_size * ir_tensor->elem_num;
if (onnx_tensor.data_type() == 1) //fp32
{
tensor_data_copy(ir_tensor, onnx_tensor, ir_tensor->elem_num, tensor_size, 4);
int tensor_size = ir_tensor->elem_num * sizeof(float);
ir_tensor->data = sys_malloc(tensor_size);
float* mem_buf = (float*)ir_tensor->data;
float* raw_data = (float*)onnx_tensor.float_data().data();
for (int j = 0; j < ir_tensor->elem_num; j++)
{
mem_buf[j] = raw_data[j];
}
}
else // int32
{
tensor_data_copy(ir_tensor, onnx_tensor, ir_tensor->elem_num, tensor_size, 5);
int tensor_size = ir_tensor->elem_num * sizeof(int32_t);
ir_tensor->data = sys_malloc(tensor_size);
int32_t* mem_buf = (int32_t*)ir_tensor->data;
int32_t* raw_data = (int32_t*)onnx_tensor.int32_data().data();
for (int j = 0; j < ir_tensor->elem_num; j++)
{
mem_buf[j] = raw_data[j];
}
}
}

Expand All @@ -310,6 +316,32 @@ int onnx_serializer::load_initializer_tensor(ir_graph_t* graph, const onnx::Grap
return 0;
}

/* Scan every node input in the ONNX graph and print (to stdout) each input
 * name that already resolves to a tensor in the Tengine IR graph.
 * @param graph       Tengine IR graph being built.
 * @param onnx_graph  Source ONNX graph.
 * @return Always 0 (kept for interface compatibility with the other loaders). */
int onnx_serializer::check_same_tensor(ir_graph_t* graph, const onnx::GraphProto& onnx_graph)
{
    for (int i = 0; i < onnx_graph.node_size(); i++)
    {
        const onnx::NodeProto& onnx_node = onnx_graph.node(i);
        // BUG FIX: the inner index previously shadowed the outer loop's `i`.
        for (int j = 0; j < onnx_node.input_size(); j++)
        {
            const std::string& input_name = onnx_node.input(j);
            if (input_name.empty())
            {
                continue; // optional (unset) input slot
            }
            int tensor_id = get_ir_tensor_index_from_name(graph, input_name.c_str());
            if (tensor_id < 0)
            {
                // Name not registered in the IR graph; calling
                // get_ir_graph_tensor with a negative index would read
                // out of bounds — TODO confirm against tengine's API.
                continue;
            }
            ir_tensor_t* tensor = get_ir_graph_tensor(graph, tensor_id);
            if (tensor != NULL)
            {
                printf("%s \n", input_name.c_str());
            }
        }
    }
    return 0;
}

int onnx_serializer::set_graph_input(ir_graph_t* graph, const onnx::GraphProto& onnx_graph)
{
Expand Down Expand Up @@ -397,7 +429,6 @@ int onnx_serializer::load_graph_node(ir_graph_t* graph, const onnx::GraphProto&
if (ir_node == NULL)
return -1;
/* set ir node io */

for (int j = 0; j < onnx_node.input_size(); j++)
{
const std::string& input_name = onnx_node.input(j);
Expand Down Expand Up @@ -453,7 +484,6 @@ int onnx_serializer::load_graph_node(ir_graph_t* graph, const onnx::GraphProto&
}
/* exec op load func */
op_load_t loader = op_load_map[op_name].second;
// printf("%s \n", op_name.c_str());
if (loader(graph, ir_node, onnx_node) < 0)
{
TLOG_ERR("load op %s func failed in node %s .\n", op_name.c_str(), node_name.c_str());
Expand Down

0 comments on commit 2b9e2d2

Please sign in to comment.