diff --git a/paddle2onnx/legacy/__init__.py b/paddle2onnx/legacy/__init__.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/command.py b/paddle2onnx/legacy/command.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/convert.py b/paddle2onnx/legacy/convert.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/graph/dygraph_helper.py b/paddle2onnx/legacy/graph/dygraph_helper.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/graph/graph.py b/paddle2onnx/legacy/graph/graph.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/graph/onnx_graph.py b/paddle2onnx/legacy/graph/onnx_graph.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/graph/paddle_graph.py b/paddle2onnx/legacy/graph/paddle_graph.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/activation.py b/paddle2onnx/legacy/op_mapper/activation.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/custom_paddle_op/anchor_generator.py b/paddle2onnx/legacy/op_mapper/custom_paddle_op/anchor_generator.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/custom_paddle_op/box_clip.py b/paddle2onnx/legacy/op_mapper/custom_paddle_op/box_clip.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/custom_paddle_op/collect_fpn_proposals.py b/paddle2onnx/legacy/op_mapper/custom_paddle_op/collect_fpn_proposals.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/custom_paddle_op/deformable_conv.py b/paddle2onnx/legacy/op_mapper/custom_paddle_op/deformable_conv.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/custom_paddle_op/distribute_fpn_proposals.py b/paddle2onnx/legacy/op_mapper/custom_paddle_op/distribute_fpn_proposals.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/custom_paddle_op/generate_proposals.py b/paddle2onnx/legacy/op_mapper/custom_paddle_op/generate_proposals.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/custom_paddle_op/grid_sampler.py b/paddle2onnx/legacy/op_mapper/custom_paddle_op/grid_sampler.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/detection/box_coder.py b/paddle2onnx/legacy/op_mapper/detection/box_coder.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/detection/multiclass_nms.py b/paddle2onnx/legacy/op_mapper/detection/multiclass_nms.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/detection/yolo_box.py b/paddle2onnx/legacy/op_mapper/detection/yolo_box.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/logic.py b/paddle2onnx/legacy/op_mapper/logic.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/mapper_helper.py b/paddle2onnx/legacy/op_mapper/mapper_helper.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/math.py b/paddle2onnx/legacy/op_mapper/math.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/nn.py b/paddle2onnx/legacy/op_mapper/nn.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/op_mapper.py b/paddle2onnx/legacy/op_mapper/op_mapper.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/search.py b/paddle2onnx/legacy/op_mapper/search.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/op_mapper/tensor.py b/paddle2onnx/legacy/op_mapper/tensor.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/passes/__init__.py b/paddle2onnx/legacy/passes/__init__.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/passes/dumplicate_names_pass.py b/paddle2onnx/legacy/passes/dumplicate_names_pass.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/legacy/passes/inplace_node_pass.py b/paddle2onnx/legacy/passes/inplace_node_pass.py
old mode 100755
new mode 100644
diff --git a/paddle2onnx/mapper/elementwise.cc b/paddle2onnx/mapper/elementwise.cc
index e766d7a33..6cbb621b2 100755
--- a/paddle2onnx/mapper/elementwise.cc
+++ b/paddle2onnx/mapper/elementwise.cc
@@ -26,11 +26,18 @@ REGISTER_MAPPER(elementwise_mod, ElementWiseModMapper)
 REGISTER_MAPPER(elementwise_floordiv, ElementWiseFloordivMapper)
 
 int32_t ElementwiseMapper::GetMinOpset(bool verbose) {
+  int opset = 7;
   if (OpType() == "elementwise_min" || OpType() == "elementwise_max") {
-    Logger(verbose, 8) << RequireOpset(8) << std::endl;
-    return 8;
+    auto input_x_info = GetInput("X");
+    if (input_x_info[0].dtype == P2ODataType::INT32)
+    {
+      opset = 12;
+    } else {
+      opset = 8;
+    }
+    Logger(verbose, opset) << RequireOpset(opset) << std::endl;
   }
-  return 7;
+  return opset;
 }
 
 void ElementwiseMapper::Opset7() {
diff --git a/paddle2onnx/mapper/nn/batch_norm.cc b/paddle2onnx/mapper/nn/batch_norm.cc
index dd81622e0..50cb7a76f 100644
--- a/paddle2onnx/mapper/nn/batch_norm.cc
+++ b/paddle2onnx/mapper/nn/batch_norm.cc
@@ -20,6 +20,18 @@ namespace paddle2onnx {
 REGISTER_MAPPER(batch_norm, BatchNormMapper)
 
+int32_t BatchNormMapper::GetMinOpset(bool verbose) {
+  // NHWC is not supported
+  auto input_info = GetInput("X");
+  int opset = 7;
+  if (input_info[0].dtype == P2ODataType::FP16) {
+    opset = 15;
+    Logger(verbose, opset) << RequireOpset(opset) << std::endl;
+  }
+  return opset;
+}
+
+
 void BatchNormMapper::Opset7() {
   auto input_info = GetInput("X");
   auto scale_info = GetInput("Scale");
@@ -31,8 +43,9 @@ void BatchNormMapper::Opset7() {
   auto node = helper_->MakeNode(
       "BatchNormalization",
       {input_info[0].name, scale_info[0].name, bias_info[0].name,
-       mean_info[0].name, variance_info[0].name},
+       mean_info[0].name, variance_info[0].name},
       {output_info[0].name});
+
   if (helper_->GetOpsetVersion() < 9) {
     int64_t spatial = 1;
     AddAttribute(node, "spatial", spatial);
diff --git a/paddle2onnx/mapper/nn/batch_norm.h b/paddle2onnx/mapper/nn/batch_norm.h
index abc649e02..12b2dc5c5 100644
--- a/paddle2onnx/mapper/nn/batch_norm.h
+++ b/paddle2onnx/mapper/nn/batch_norm.h
@@ -29,6 +29,7 @@ class BatchNormMapper : public Mapper {
     GetAttr("momentum", &momentum_);
   }
 
+  int32_t GetMinOpset(bool verbose = false);
   void Opset7();
 
  private:
diff --git a/paddle2onnx/mapper/nn/layer_norm.cc b/paddle2onnx/mapper/nn/layer_norm.cc
index 5fab8da22..b7371a19f 100644
--- a/paddle2onnx/mapper/nn/layer_norm.cc
+++ b/paddle2onnx/mapper/nn/layer_norm.cc
@@ -25,8 +25,8 @@ void LayerNormMapper::Opset7() {
   auto input_info = GetInput("X");
   auto output_info = GetOutput("Y");
 
-  std::string input_name = helper_->AutoCast(
-      input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
+  // LayerNorm supports FP32/FP16
+  std::string input_name = helper_->AutoCast(input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
 
   std::vector<int64_t> input_shape = input_info[0].shape;
   std::vector<int64_t> axes;
diff --git a/paddle2onnx/mapper/nn/pool2d.cc b/paddle2onnx/mapper/nn/pool2d.cc
index 0996bca8a..32a69df15 100755
--- a/paddle2onnx/mapper/nn/pool2d.cc
+++ b/paddle2onnx/mapper/nn/pool2d.cc
@@ -116,12 +116,18 @@ void Pool2dMapper::AdaptivePool(const std::vector<TensorInfo>& input_info,
     onnx_pool_type = iter->second[0];
   }
-  std::shared_ptr<ONNX_NAMESPACE::NodeProto>* node_ptr;
-  auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                 P2ODataType::FP32);
-  auto node = helper_->MakeNode(onnx_pool_type, {input});
-  helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
-                    output_info[0].dtype);
+  // AveragePool only supports fp16 or fp32.
+  auto input_name = input_info[0].name;
+  if ((input_info[0].dtype != P2ODataType::FP32) && (input_info[0].dtype != P2ODataType::FP16)) {
+    input_name = helper_->AutoCast(input_name, input_info[0].dtype, P2ODataType::FP32);
+  }
+  auto output_name = output_info[0].name;
+  if ((output_info[0].dtype != P2ODataType::FP32) && (output_info[0].dtype != P2ODataType::FP16))
+  {
+    output_name = helper_->AutoCast(output_name, output_info[0].dtype, P2ODataType::FP32);
+  }
+  auto node = helper_->MakeNode(onnx_pool_type, {input_name}, {output_name});
+
   std::vector<int64_t> kernel_size = {kernel_h, kernel_w};
   AddAttribute(node, "kernel_shape", kernel_size);
   std::vector<int64_t> strides = {stride_h, stride_w};
@@ -165,12 +171,17 @@ void Pool2dMapper::NoAdaptivePool(const std::vector<TensorInfo>& input_info,
   int64_t max_ksize = *std::max_element(std::begin(k_size_), std::end(k_size_));
   int64_t max_pads = *std::max_element(std::begin(pads_), std::end(pads_));
-  auto input_x = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                   P2ODataType::FP32);
+
+  // AveragePool only supports fp16 or fp32.
+  auto input_name = input_info[0].name;
+  if ((input_info[0].dtype != P2ODataType::FP32) && (input_info[0].dtype != P2ODataType::FP16)) {
+    input_name = helper_->AutoCast(input_name, input_info[0].dtype, P2ODataType::FP32);
+  }
+
   if (max_ksize <= max_pads) {
     std::vector<int64_t> onnx_paddings = {0, 0, pads_[0], pads_[1],
                                           0, 0, pads_[2], pads_[3]};
-    std::vector<std::string> inputs_names = {input_x};
+    std::vector<std::string> inputs_names = {input_name};
     if (helper_->GetOpsetVersion() >= 11) {
       std::string paddings_node =
           helper_->Constant(GetOnnxDtype(P2ODataType::INT64), onnx_paddings);
@@ -188,7 +199,7 @@ void Pool2dMapper::NoAdaptivePool(const std::vector<TensorInfo>& input_info,
       float val = 0.0;
       AddAttribute(node, "value", val);
     }
-    input_x = node->output(0);
+    input_name = node->output(0);
     pads_.clear();
     pads_.resize(4, 0);
   }
@@ -199,9 +210,13 @@ void Pool2dMapper::NoAdaptivePool(const std::vector<TensorInfo>& input_info,
     auto iter = op_mapper_.find(pooling_type_);
     onnx_pool_type = iter->second[0];
   }
-  auto node = helper_->MakeNode(onnx_pool_type, {input_x});
-  helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
-                    output_info[0].dtype);
+
+  // AveragePool only supports fp16 or fp32.
+  auto output_name = output_info[0].name;
+  if ((output_info[0].dtype != P2ODataType::FP32) && (output_info[0].dtype != P2ODataType::FP16)) {
+    output_name = helper_->AutoCast(output_name, output_info[0].dtype, P2ODataType::FP32);
+  }
+  auto node = helper_->MakeNode(onnx_pool_type, {input_name}, {output_name});
 
   AddAttribute(node, "kernel_shape", k_size_);
   AddAttribute(node, "strides", strides_);
@@ -317,11 +332,9 @@ void Pool2dMapper::Opset7() {
     auto iter = op_mapper_.find(pooling_type_);
     onnx_pool_type = iter->second[1];
   }
-    auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                   P2ODataType::FP32);
+    auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
     auto output = helper_->MakeNode(onnx_pool_type, {input})->output(0);
-    helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32,
-                      output_info[0].dtype);
+    helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32, output_info[0].dtype);
   } else if (adaptive_) {
     AdaptivePool(input_info, output_info);
   } else {
diff --git a/paddle2onnx/mapper/onnx_helper.h b/paddle2onnx/mapper/onnx_helper.h
index 5958eb38f..949b5b60f 100644
--- a/paddle2onnx/mapper/onnx_helper.h
+++ b/paddle2onnx/mapper/onnx_helper.h
@@ -281,6 +281,12 @@ std::string OnnxHelper::Constant(const std::string& output,
       data.push_back(static_cast<float>(item));
     }
     tensor->set_raw_data(std::string((const char*)(data.data()), numel * 4));
+  } else if (dtype == ONNX_NAMESPACE::TensorProto::FLOAT16) {
+    std::vector<uint16_t> data;
+    for (auto& item : value) {
+      data.push_back(FP32ToFP16(item));
+    }
+    tensor->set_raw_data(std::string((const char*)(data.data()), numel * sizeof(uint16_t)));
   } else if (dtype == ONNX_NAMESPACE::TensorProto::DOUBLE) {
     std::vector<double> data;
     for (auto& item : value) {
@@ -315,7 +321,7 @@ std::string OnnxHelper::Constant(const std::string& output,
     tensor->set_raw_data(std::string((const char*)(data.data()), numel));
   } else {
     Assert(false,
-           "Only support data type of BOOL/FLOAT/DOUBLE/INT32/INT64/INT8 in "
+           "Only support data type of BOOL/FLOAT/FLOAT16/DOUBLE/INT32/INT64/INT8 in "
           "Constant "
           "function.");
   }
@@ -353,6 +359,9 @@ std::string OnnxHelper::Constant(const std::string& output,
   if (dtype == ONNX_NAMESPACE::TensorProto::FLOAT) {
     std::vector<float> data(numel, static_cast<float>(value));
     tensor->set_raw_data(std::string((const char*)(data.data()), numel * 4));
+  } else if (dtype == ONNX_NAMESPACE::TensorProto::FLOAT16) {
+    std::vector<uint16_t> data(numel, FP32ToFP16(value));
+    tensor->set_raw_data(std::string((const char*)(data.data()), numel * sizeof(uint16_t)));
   } else if (dtype == ONNX_NAMESPACE::TensorProto::DOUBLE) {
     std::vector<double> data(numel, static_cast<double>(value));
     tensor->set_raw_data(std::string((const char*)(data.data()), numel * 8));
@@ -375,7 +384,7 @@ std::string OnnxHelper::Constant(const std::string& output,
   } else {
     Assert(
         false,
-        "Only support data type of BOOL/FLOAT/DOUBLE/INT32/INT64 in Constant "
+        "Only support data type of BOOL/FLOAT/FLOAT16/DOUBLE/INT32/INT64 in Constant "
        "function.");
   }
   nodes.push_back(node);
diff --git a/paddle2onnx/mapper/tensor/fill_constant.cc b/paddle2onnx/mapper/tensor/fill_constant.cc
index 13b68b528..4a5076858 100644
--- a/paddle2onnx/mapper/tensor/fill_constant.cc
+++ b/paddle2onnx/mapper/tensor/fill_constant.cc
@@ -25,9 +25,10 @@ int32_t FillConstantMapper::GetMinOpset(bool verbose) {
   auto onnx_dtype = GetOnnxDtype(out_info[0].dtype);
   if (onnx_dtype != ONNX_NAMESPACE::TensorProto::INT32 &&
       onnx_dtype != ONNX_NAMESPACE::TensorProto::INT64 &&
+      onnx_dtype != ONNX_NAMESPACE::TensorProto::FLOAT16 &&
       onnx_dtype != ONNX_NAMESPACE::TensorProto::FLOAT &&
       onnx_dtype != ONNX_NAMESPACE::TensorProto::DOUBLE) {
-    Error() << "Only support int32/int64/float32/float64 data type in "
+    Error() << "Only support int32/int64/float16/float32/float64 data type in "
               "fill_constant operator."
            << std::endl;
     return -1;
@@ -124,6 +125,11 @@ void FillConstantMapper::Opset9() {
     data[0] = static_cast<int64_t>(value);
     auto ptr = reinterpret_cast<const char*>(data.data());
     tensor->set_raw_data(std::string(ptr, sizeof(int64_t)));
+  } else if (onnx_dtype == ONNX_NAMESPACE::TensorProto::FLOAT16) {
+    std::vector<uint16_t> data(1);
+    data[0] = FP32ToFP16(value);
+    auto ptr = reinterpret_cast<const char*>(data.data());
+    tensor->set_raw_data(std::string(ptr, sizeof(int16_t)));
   } else if (onnx_dtype == ONNX_NAMESPACE::TensorProto::FLOAT) {
     std::vector<float> data(1, value_);
     auto ptr = reinterpret_cast<const char*>(data.data());
diff --git a/paddle2onnx/mapper/tensor/matmul.cc b/paddle2onnx/mapper/tensor/matmul.cc
index fd8bb1264..5e80d1e0a 100644
--- a/paddle2onnx/mapper/tensor/matmul.cc
+++ b/paddle2onnx/mapper/tensor/matmul.cc
@@ -19,11 +19,8 @@ namespace paddle2onnx {
 REGISTER_MAPPER(matmul, MatmulMapper)
 
 std::string MatmulMapper::GetTrans(std::vector<TensorInfo>& input_info) {
-  std::string castd_name = input_info[0].name;
-  if (input_info[0].dtype == P2ODataType::FP64) {
-    castd_name = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                   P2ODataType::FP32);
-  }
+  std::string castd_name = helper_->AutoCast(
+      input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
   std::vector<int64_t> perm = Arange(0, input_info[0].Rank());
   std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);
   auto transpose_node = helper_->MakeNode("Transpose", {castd_name});
@@ -35,26 +32,31 @@ void MatmulMapper::Opset7() {
   auto input_x_info = GetInput("X");
   auto input_y_info = GetInput("Y");
   auto output_info = GetOutput("Out");
-  std::string input_x = input_x_info[0].name;
+
+  // When the data types of the two inputs are inconsistent,
+  // cast them to a common type before MatMul.
   if (transpose_X_) {
-    input_x = GetTrans(input_x_info);
+    input_x_info[0].name = GetTrans(input_x_info);
+    input_x_info[0].dtype = P2ODataType::FP32;
   }
-  std::string input_y = input_y_info[0].name;
   if (transpose_Y_) {
-    input_y = GetTrans(input_y_info);
+    input_y_info[0].name = GetTrans(input_y_info);
+    input_y_info[0].dtype = P2ODataType::FP32;
   }
+
+  if (input_x_info[0].dtype != input_y_info[0].dtype)
+  {
+    input_y_info[0].name = helper_->AutoCast(
+        input_y_info[0].name, input_y_info[0].dtype, input_x_info[0].dtype);
+  }
+
   if (fabs(alpha_ - 1.0) < 1e-6) {
-    auto node = helper_->MakeNode("MatMul", {input_x, input_y});
-    helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
-                      input_y_info[0].dtype);
+    helper_->MakeNode("MatMul", {input_x_info[0].name, input_y_info[0].name}, {output_info[0].name});
   } else {
-    auto mutmul_node = helper_->MakeNode("MatMul", {input_x, input_y});
+    auto mutmul_node = helper_->MakeNode("MatMul", {input_x_info[0].name, input_y_info[0].name});
     std::string scale_node =
         helper_->Constant({1}, GetOnnxDtype(input_x_info[0].dtype), alpha_);
-    auto mul_node =
-        helper_->MakeNode("Mul", {mutmul_node->output(0), scale_node});
-    helper_->AutoCast(mul_node->output(0), output_info[0].name,
-                      P2ODataType::FP32, input_y_info[0].dtype);
+    helper_->MakeNode("Mul", {mutmul_node->output(0), scale_node}, {output_info[0].name});
   }
 }
diff --git a/paddle2onnx/mapper/tensor/matmul_v2.cc b/paddle2onnx/mapper/tensor/matmul_v2.cc
index f78a26f42..ff0427abb 100644
--- a/paddle2onnx/mapper/tensor/matmul_v2.cc
+++ b/paddle2onnx/mapper/tensor/matmul_v2.cc
@@ -35,17 +35,22 @@ void MatmulV2Mapper::Opset7() {
   auto input_x_info = GetInput("X");
   auto input_y_info = GetInput("Y");
   auto output_info = GetOutput("Out");
-  std::string input_x = input_x_info[0].name;
   if (trans_x_) {
-    input_x = GetTrans(input_x_info);
+    input_x_info[0].name = GetTrans(input_x_info);
+    input_x_info[0].dtype = P2ODataType::FP32;
   }
-  std::string input_y = input_y_info[0].name;
   if (trans_y_) {
-    input_y = GetTrans(input_y_info);
+    input_y_info[0].name = GetTrans(input_y_info);
+    input_y_info[0].dtype = P2ODataType::FP32;
   }
-  auto node = helper_->MakeNode("MatMul", {input_x, input_y});
-  helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
-                    input_y_info[0].dtype);
+
+  if (input_x_info[0].dtype != input_y_info[0].dtype)
+  {
+    input_y_info[0].name = helper_->AutoCast(
+        input_y_info[0].name, input_y_info[0].dtype, input_x_info[0].dtype);
+  }
+
+  helper_->MakeNode("MatMul", {input_x_info[0].name, input_y_info[0].name}, {output_info[0].name});
 }
 
 }  // namespace paddle2onnx
diff --git a/paddle2onnx/mapper/tensor/scale.cc b/paddle2onnx/mapper/tensor/scale.cc
index 810b3472c..83f34de14 100644
--- a/paddle2onnx/mapper/tensor/scale.cc
+++ b/paddle2onnx/mapper/tensor/scale.cc
@@ -29,48 +29,54 @@ void ScaleMapper::Opset7() {
   if (!has_scale_tensor && is_scale_1 && is_bias_0) {
     helper_->MakeNode("Identity", {input_info[0].name}, {output_info[0].name});
   } else {
-    auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                   P2ODataType::FP32);
-    std::string out = input;
+    // Scale supports FP32/FP16, so the scale and bias constants only need to match the input dtype.
+    auto p2o_support_type = P2ODataType::FP32;
+    auto onnx_support_type = ONNX_NAMESPACE::TensorProto::FLOAT;
+    auto input_name = input_info[0].name;
+    if (input_info[0].dtype == P2ODataType::FP16) {
+      p2o_support_type = P2ODataType::FP16;
+      onnx_support_type = ONNX_NAMESPACE::TensorProto::FLOAT16;
+    } else {
+      input_name = helper_->AutoCast(input_name, input_info[0].dtype, p2o_support_type);
+    }
+    std::string out = input_name;
     if (bias_after_scale_) {
       if (!is_scale_1 || HasInput("ScaleTensor")) {
         if (HasInput("ScaleTensor")) {
           auto scale_info = GetInput("ScaleTensor");
-          auto scale = helper_->AutoCast(
-              scale_info[0].name, scale_info[0].dtype, P2ODataType::FP32);
+          auto scale = helper_->AutoCast(scale_info[0].name, scale_info[0].dtype, p2o_support_type);
           out = helper_->MakeNode("Mul", {out, scale})->output(0);
         } else {
           auto scale =
-              helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, scale_);
+              helper_->Constant({}, onnx_support_type, scale_);
           out = helper_->MakeNode("Mul", {out, scale})->output(0);
         }
       }
       if (!is_bias_0) {
         auto bias =
-            helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, bias_);
+            helper_->Constant({}, onnx_support_type, bias_);
         out = helper_->MakeNode("Add", {out, bias})->output(0);
       }
     } else {
       if (!is_bias_0) {
         auto bias =
-            helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, bias_);
+            helper_->Constant({}, onnx_support_type, bias_);
         out = helper_->MakeNode("Add", {out, bias})->output(0);
       }
       if (!is_scale_1 || HasInput("ScaleTensor")) {
         if (HasInput("ScaleTensor")) {
           auto scale_info = GetInput("ScaleTensor");
           auto scale = helper_->AutoCast(
-              scale_info[0].name, scale_info[0].dtype, P2ODataType::FP32);
+              scale_info[0].name, scale_info[0].dtype, p2o_support_type);
           out = helper_->MakeNode("Mul", {out, scale})->output(0);
         } else {
           auto scale =
-              helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, scale_);
+              helper_->Constant({}, onnx_support_type, scale_);
           out = helper_->MakeNode("Mul", {out, scale})->output(0);
         }
       }
     }
-    helper_->AutoCast(out, output_info[0].name, P2ODataType::FP32,
-                      output_info[0].dtype);
+    helper_->AutoCast(out, output_info[0].name, p2o_support_type, output_info[0].dtype);
   }
 }
 }  // namespace paddle2onnx
diff --git a/paddle2onnx/parser/parser.cc b/paddle2onnx/parser/parser.cc
index d2c743a2d..a43bacf0a 100755
--- a/paddle2onnx/parser/parser.cc
+++ b/paddle2onnx/parser/parser.cc
@@ -815,7 +815,6 @@ void PaddleParser::GetGlobalBlockInputOutputInfo() {
 }
 int32_t PaddleDataTypeSize(int32_t paddle_dtype) {
-  Assert(paddle_dtype != FP16, "Float16 is not supported.");
   if (paddle_dtype == P2ODataType::BOOL) {
     return sizeof(bool);
   } else if (paddle_dtype == P2ODataType::INT8) {
     return sizeof(int8_t);
@@ -826,6 +825,9 @@ int32_t PaddleDataTypeSize(int32_t paddle_dtype) {
     return sizeof(int32_t);
   } else if (paddle_dtype == P2ODataType::INT64) {
     return sizeof(int64_t);
+  } else if (paddle_dtype == P2ODataType::FP16) {
+    // C++ does not have native support for FP16, so int16 is used instead.
+    return sizeof(int16_t);
   } else if (paddle_dtype == P2ODataType::FP32) {
     return sizeof(float);
   } else if (paddle_dtype == P2ODataType::FP64) {
diff --git a/paddle2onnx/utils/utils.h b/paddle2onnx/utils/utils.h
index 463760703..04158c351 100644
--- a/paddle2onnx/utils/utils.h
+++ b/paddle2onnx/utils/utils.h
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include <cstring>
 
 namespace paddle2onnx {
@@ -33,6 +34,19 @@ inline const std::string RequireOpset(const int32_t& opset_version) {
          std::to_string(opset_version) + ".";
 }
 
+// from https://blog.csdn.net/q2519008/article/details/129264884
+inline uint16_t FP32ToFP16(float fp32_num)
+{
+  uint32_t temp_data;
+  memcpy(&temp_data, &fp32_num, sizeof(float));
+  uint16_t t = ((temp_data & 0x007fffff) >> 13) | ((temp_data & 0x80000000) >> 16) | (((temp_data & 0x7f800000) >> 13) - (112 << 10));
+  if (temp_data & 0x1000) {  // round up when the highest dropped mantissa bit is set
+    t++;
+  }
+  uint16_t fp16 = *(uint16_t*)(&t);
+  return fp16;
+}
+
 class P2OLogger {
  public:
   P2OLogger() {
diff --git a/tests/run.sh b/tests/run.sh
index 42057b17a..d21735562 100755
--- a/tests/run.sh
+++ b/tests/run.sh
@@ -21,6 +21,11 @@ wget -P ~/.cache/paddle/dataset/int8/download/ http://paddle-inference-dist.bj.b
 mkdir ~/.cache/paddle/dataset/int8/download/small_data/ && tar xf ~/.cache/paddle/dataset/int8/download/calibration_test_data.tar.gz -C ~/.cache/paddle/dataset/int8/download/small_data/
 wget https://bj.bcebos.com/paddle2onnx/tests/quantized_models.tar.gz
 tar xf quantized_models.tar.gz
+wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
+wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet18_infer.tar
+tar xf ResNet18_infer.tar
+wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV3_small_x0_35_infer.tar
+tar xf MobileNetV3_small_x0_35_infer.tar
 cases=`find . -name "test*.py" | sort`
-name "test*.py" | sort` ignore="test_auto_scan_affine_channel.py \ diff --git a/tests/test_fp16_model.py b/tests/test_fp16_model.py new file mode 100644 index 000000000..21507bc1a --- /dev/null +++ b/tests/test_fp16_model.py @@ -0,0 +1,133 @@ +import unittest +import os +import onnxruntime as rt +import numpy as np +import paddle.inference as paddle_infer +import cv2 +from paddle.inference import PrecisionType, PlaceType +from paddle.inference import convert_to_mixed_precision + + +def preprocess(image_path): + """ Preprocess input image file + Args: + image_path(str): Path of input image file + + + Returns: + preprocessed data(np.ndarray): Shape of [N, C, H, W] + """ + + def resize_by_short(im, resize_size): + short_size = min(im.shape[0], im.shape[1]) + scale = 256 / short_size + new_w = int(round(im.shape[1] * scale)) + new_h = int(round(im.shape[0] * scale)) + return cv2.resize(im, (new_w, new_h), interpolation=cv2.INTER_LINEAR) + + def center_crop(im, crop_size): + h, w, c = im.shape + w_start = (w - crop_size) // 2 + h_start = (h - crop_size) // 2 + w_end = w_start + crop_size + h_end = h_start + crop_size + return im[h_start:h_end, w_start:w_end, :] + + def normalize(im, mean, std): + im = im.astype("float32") / 255.0 + # to rgb + im = im[:, :, ::-1] + mean = np.array(mean).reshape((1, 1, 3)).astype("float32") + std = np.array(std).reshape((1, 1, 3)).astype("float32") + return (im - mean) / std + + # resize the short edge to `resize_size` + im = cv2.imread(image_path) + resized_im = resize_by_short(im, 256) + + # crop from center + croped_im = center_crop(resized_im, 224) + + # normalize + normalized_im = normalize(croped_im, [0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + + # transpose to NCHW + data = np.expand_dims(normalized_im, axis=0) + data = np.transpose(data, (0, 3, 1, 2)) + + return data + + +def predict_by_onnx(input_data): + sess = rt.InferenceSession('inference_fp16.onnx', providers=['CPUExecutionProvider']) + input_name = sess.get_inputs()[0].name + output_name = sess.get_outputs()[0].name + pred_onnx = sess.run([output_name], {input_name: input_data}) + return np.array(pred_onnx[0]) + + +def predict_by_paddle_inference(input_data): + config = paddle_infer.Config("inference_fp16.pdmodel", "inference_fp16.pdiparams") + config.enable_use_gpu(500, 0) + + predictor = paddle_infer.create_predictor(config) + # 获取输入的名称 + input_names = predictor.get_input_names() + input_handle = predictor.get_input_handle(input_names[0]) + + # 设置输入 + input_handle.copy_from_cpu(input_data) + + # 运行predictor + predictor.run() + + # 获取输出 + output_names = predictor.get_output_names() + output_handle = predictor.get_output_handle(output_names[0]) + output_data = output_handle.copy_to_cpu() # numpy.ndarray类型 + return output_data + + +def creat_fp16_model(src_model, src_params): + black_list = set() + dst_model = "./inference_fp16.pdmodel" + dst_params = "./inference_fp16.pdiparams" + + convert_to_mixed_precision( + src_model, # fp32模型文件路径 + src_params, # fp32权重文件路径 + dst_model, # 混合精度模型文件保存路径 + dst_params, # 混合精度权重文件保存路径 + PrecisionType.Half, # 转换精度,如 PrecisionType.Half + PlaceType.GPU, # 后端,如 PlaceType.GPU + True, # 保留输入输出精度信息,若为 True 则输入输出保留为 fp32 类型,否则转为 precision 类型 + black_list # 黑名单列表,哪些 op 不需要进行精度类型转换 + ) + os.system( + "paddle2onnx --model_dir . 
+        "paddle2onnx --model_dir . --model_filename inference_fp16.pdmodel --params_filename "
+        "inference_fp16.pdiparams --save_file inference_fp16.onnx --enable_dev_version True --enable_onnx_checker "
+        "True")
+
+
+def clas_test(src_model, src_params, image_path, rtol=1e-07, atol=0.01):
+    create_fp16_model(src_model, src_params)
+    input_data = preprocess(image_path)
+    data_onnx = predict_by_onnx(input_data)
+    data_paddle_inference = predict_by_paddle_inference(input_data)
+    np.testing.assert_allclose(data_onnx, data_paddle_inference, rtol=rtol, atol=atol)
+
+
+class TestClas(unittest.TestCase):
+    def test_resnet(self):
+        clas_test("./ResNet18_infer/inference.pdmodel", "./ResNet18_infer/inference.pdiparams",
+                  "ILSVRC2012_val_00000010.jpeg")
+
+    def test_mobilenetV3(self):
+        clas_test("./MobileNetV3_small_x0_35_infer/inference.pdmodel",
+                  "./MobileNetV3_small_x0_35_infer/inference.pdiparams",
+                  "ILSVRC2012_val_00000010.jpeg", atol=0.02)
+
+
+if __name__ == '__main__':
+    unittest.main()
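
Note on the FP32ToFP16 helper added in paddle2onnx/utils/utils.h: the Python sketch below is not part of the patch; it simply mirrors the same bit repacking (sign, exponent re-biased from 127 to 15, top 10 mantissa bits, round up on the highest dropped bit) and prints the result next to NumPy's own float16 cast for a few normal-range values. The name fp32_to_fp16_bits is illustrative only, and like the C++ helper it does not handle NaN/Inf or subnormals.

import struct

import numpy as np


def fp32_to_fp16_bits(fp32_num):
    # Reinterpret the float32 as its raw 32-bit pattern.
    bits = struct.unpack("<I", struct.pack("<f", fp32_num))[0]
    # Repack sign, re-biased exponent and the top 10 mantissa bits.
    t = (((bits & 0x007fffff) >> 13)
         | ((bits & 0x80000000) >> 16)
         | (((bits & 0x7f800000) >> 13) - (112 << 10)))
    if bits & 0x1000:  # round up when the highest dropped mantissa bit is set
        t += 1
    return t & 0xffff  # truncate to 16 bits, as the uint16_t assignment does in C++


for v in [1.0, -2.5, 0.1, 3.14159, 65504.0]:
    fp16_val = np.array([fp32_to_fp16_bits(v)], dtype=np.uint16).view(np.float16)[0]
    print(v, float(fp16_val), float(np.float16(v)))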