[Hackathon No.82] Support fp16 #1151

Closed · wants to merge 11 commits

Changes from all commits
Empty file modified paddle2onnx/legacy/__init__.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/command.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/convert.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/graph/dygraph_helper.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/graph/graph.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/graph/onnx_graph.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/graph/paddle_graph.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/activation.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/custom_paddle_op/box_clip.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/custom_paddle_op/grid_sampler.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/detection/box_coder.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/detection/multiclass_nms.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/detection/yolo_box.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/logic.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/mapper_helper.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/math.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/nn.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/op_mapper.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/search.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/op_mapper/tensor.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/passes/__init__.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/passes/dumplicate_names_pass.py
100755 → 100644
Empty file.
Empty file modified paddle2onnx/legacy/passes/inplace_node_pass.py
100755 → 100644
Empty file.
13 changes: 10 additions & 3 deletions paddle2onnx/mapper/elementwise.cc
@@ -26,11 +26,18 @@ REGISTER_MAPPER(elementwise_mod, ElementWiseModMapper)
 REGISTER_MAPPER(elementwise_floordiv, ElementWiseFloordivMapper)
 
 int32_t ElementwiseMapper::GetMinOpset(bool verbose) {
+  int opset = 7;
   if (OpType() == "elementwise_min" || OpType() == "elementwise_max") {
-    Logger(verbose, 8) << RequireOpset(8) << std::endl;
-    return 8;
+    auto input_x_info = GetInput("X");
+    if (input_x_info[0].dtype == P2ODataType::INT32) {
+      opset = 12;
+    } else {
+      opset = 8;
+    }
+    Logger(verbose, opset) << RequireOpset(opset) << std::endl;
   }
-  return 7;
+  return opset;
 }
 
 void ElementwiseMapper::Opset7() {
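Reviewer's note: ONNX `Min`/`Max` accept integer tensors only from opset 12 onward (and lose the opset-7 broadcasting restrictions at opset 8), which appears to be the rule this dispatch encodes. A standalone restatement, for illustration only — `MinOpsetForMinMax` is a hypothetical name, not part of the PR:

#include <cstdint>

// Hypothetical restatement of the opset floor chosen in GetMinOpset above.
int32_t MinOpsetForMinMax(bool input_is_int32) {
  return input_is_int32 ? 12 : 8;  // int inputs need Min/Max-12, floats need 8
}

int main() {
  return (MinOpsetForMinMax(true) == 12 && MinOpsetForMinMax(false) == 8) ? 0 : 1;
}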
15 changes: 14 additions & 1 deletion paddle2onnx/mapper/nn/batch_norm.cc
@@ -20,6 +20,18 @@
 namespace paddle2onnx {
 REGISTER_MAPPER(batch_norm, BatchNormMapper)
 
+int32_t BatchNormMapper::GetMinOpset(bool verbose) {
+  // NHWC is not supported
+  auto input_info = GetInput("X");
+  int opset = 7;
+  if (input_info[0].dtype == P2ODataType::FP16) {
+    opset = 15;
+    Logger(verbose, opset) << RequireOpset(opset) << std::endl;
+  }
+  return opset;
+}
+
 void BatchNormMapper::Opset7() {
   auto input_info = GetInput("X");
   auto scale_info = GetInput("Scale");
@@ -31,8 +43,9 @@ void BatchNormMapper::Opset7() {
   auto node = helper_->MakeNode(
       "BatchNormalization",
       {input_info[0].name, scale_info[0].name, bias_info[0].name,
-      mean_info[0].name, variance_info[0].name},
+       mean_info[0].name, variance_info[0].name},
       {output_info[0].name});
+
   if (helper_->GetOpsetVersion() < 9) {
     int64_t spatial = 1;
     AddAttribute(node, "spatial", spatial);
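Reviewer's aside: the opset-15 floor for FP16 inputs is presumably because BatchNormalization-15 splits its type constraints (T/T1/T2), allowing an FP16 `X` to keep FP32 scale/bias/mean/variance, whereas earlier opsets require all five inputs to share one element type.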
1 change: 1 addition & 0 deletions paddle2onnx/mapper/nn/batch_norm.h
@@ -29,6 +29,7 @@ class BatchNormMapper : public Mapper {
     GetAttr("momentum", &momentum_);
   }
 
+  int32_t GetMinOpset(bool verbose = false);
   void Opset7();
 
  private:
4 changes: 2 additions & 2 deletions paddle2onnx/mapper/nn/layer_norm.cc
@@ -25,8 +25,8 @@ void LayerNormMapper::Opset7() {
   auto input_info = GetInput("X");
   auto output_info = GetOutput("Y");
 
-  std::string input_name = helper_->AutoCast(
-      input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
+  // LayerNorm supports FP32/FP16
+  std::string input_name = helper_->AutoCast(input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
 
   std::vector<int64_t> input_shape = input_info[0].shape;
   std::vector<int64_t> axes;
47 changes: 30 additions & 17 deletions paddle2onnx/mapper/nn/pool2d.cc
@@ -116,12 +116,18 @@ void Pool2dMapper::AdaptivePool(const std::vector<TensorInfo>& input_info,
     onnx_pool_type = iter->second[0];
   }
 
-  std::shared_ptr<ONNX_NAMESPACE::NodeProto>* node_ptr;
-  auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                 P2ODataType::FP32);
-  auto node = helper_->MakeNode(onnx_pool_type, {input});
-  helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
-                    output_info[0].dtype);
+  // AveragePool only supports fp16 or fp32.
+  auto input_name = input_info[0].name;
+  if ((input_info[0].dtype != P2ODataType::FP32) && (input_info[0].dtype != P2ODataType::FP16)) {
+    input_name = helper_->AutoCast(input_name, input_info[0].dtype, P2ODataType::FP32);
+  }
+  auto output_name = output_info[0].name;
+  if ((output_info[0].dtype != P2ODataType::FP32) && (output_info[0].dtype != P2ODataType::FP16)) {
+    output_name = helper_->AutoCast(output_name, output_info[0].dtype, P2ODataType::FP32);
+  }
+  auto node = helper_->MakeNode(onnx_pool_type, {input_name}, {output_name});
 
   std::vector<int64_t> kernel_size = {kernel_h, kernel_w};
   AddAttribute(node, "kernel_shape", kernel_size);
   std::vector<int64_t> strides = {stride_h, stride_w};
@@ -165,12 +171,17 @@ void Pool2dMapper::NoAdaptivePool(const std::vector<TensorInfo>& input_info,
 
   int64_t max_ksize = *std::max_element(std::begin(k_size_), std::end(k_size_));
   int64_t max_pads = *std::max_element(std::begin(pads_), std::end(pads_));
-  auto input_x = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                   P2ODataType::FP32);
+
+  // AveragePool only supports fp16 or fp32.
+  auto input_name = input_info[0].name;
+  if ((input_info[0].dtype != P2ODataType::FP32) && (input_info[0].dtype != P2ODataType::FP16)) {
+    input_name = helper_->AutoCast(input_name, input_info[0].dtype, P2ODataType::FP32);
+  }
+
   if (max_ksize <= max_pads) {
     std::vector<int64_t> onnx_paddings = {0, 0, pads_[0], pads_[1],
                                           0, 0, pads_[2], pads_[3]};
-    std::vector<std::string> inputs_names = {input_x};
+    std::vector<std::string> inputs_names = {input_name};
     if (helper_->GetOpsetVersion() >= 11) {
       std::string paddings_node =
           helper_->Constant(GetOnnxDtype(P2ODataType::INT64), onnx_paddings);
@@ -188,7 +199,7 @@ void Pool2dMapper::NoAdaptivePool(const std::vector<TensorInfo>& input_info,
       float val = 0.0;
       AddAttribute(node, "value", val);
     }
-    input_x = node->output(0);
+    input_name = node->output(0);
     pads_.clear();
     pads_.resize(4, 0);
   }
@@ -199,9 +210,13 @@ void Pool2dMapper::NoAdaptivePool(const std::vector<TensorInfo>& input_info,
     auto iter = op_mapper_.find(pooling_type_);
     onnx_pool_type = iter->second[0];
   }
-  auto node = helper_->MakeNode(onnx_pool_type, {input_x});
-  helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
-                    output_info[0].dtype);
+
+  // AveragePool only supports fp16 or fp32.
+  auto output_name = output_info[0].name;
+  if ((output_info[0].dtype != P2ODataType::FP32) && (output_info[0].dtype != P2ODataType::FP16)) {
+    output_name = helper_->AutoCast(output_name, output_info[0].dtype, P2ODataType::FP32);
+  }
+  auto node = helper_->MakeNode(onnx_pool_type, {input_name}, {output_name});
 
   AddAttribute(node, "kernel_shape", k_size_);
   AddAttribute(node, "strides", strides_);
@@ -317,11 +332,9 @@ void Pool2dMapper::Opset7() {
     auto iter = op_mapper_.find(pooling_type_);
     onnx_pool_type = iter->second[1];
   }
-    auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                   P2ODataType::FP32);
+    auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
     auto output = helper_->MakeNode(onnx_pool_type, {input})->output(0);
-    helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32,
-                      output_info[0].dtype);
+    helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32, output_info[0].dtype);
   } else if (adaptive_) {
     AdaptivePool(input_info, output_info);
   } else {
13 changes: 11 additions & 2 deletions paddle2onnx/mapper/onnx_helper.h
@@ -281,6 +281,12 @@ std::string OnnxHelper::Constant(const std::string& output,
       data.push_back(static_cast<float>(item));
     }
     tensor->set_raw_data(std::string((const char*)(data.data()), numel * 4));
+  } else if (dtype == ONNX_NAMESPACE::TensorProto::FLOAT16) {
+    std::vector<uint16_t> data;
+    for (auto& item : value) {
+      data.push_back(FP32ToFP16(item));
+    }
+    tensor->set_raw_data(std::string((const char*)(data.data()), numel * sizeof(uint16_t)));
   } else if (dtype == ONNX_NAMESPACE::TensorProto::DOUBLE) {
     std::vector<double> data;
     for (auto& item : value) {
@@ -315,7 +321,7 @@ std::string OnnxHelper::Constant(const std::string& output,
     tensor->set_raw_data(std::string((const char*)(data.data()), numel));
   } else {
     Assert(false,
-           "Only support data type of BOOL/FLOAT/DOUBLE/INT32/INT64/INT8 in "
+           "Only support data type of BOOL/FLOAT/FLOAT16/DOUBLE/INT32/INT64/INT8 in "
            "Constant "
            "function.");
   }
@@ -353,6 +359,9 @@ std::string OnnxHelper::Constant(const std::string& output,
   if (dtype == ONNX_NAMESPACE::TensorProto::FLOAT) {
     std::vector<float> data(numel, static_cast<float>(value));
     tensor->set_raw_data(std::string((const char*)(data.data()), numel * 4));
+  } else if (dtype == ONNX_NAMESPACE::TensorProto::FLOAT16) {
+    std::vector<uint16_t> data(numel, FP32ToFP16(value));
+    tensor->set_raw_data(std::string((const char*)(data.data()), numel * sizeof(uint16_t)));
   } else if (dtype == ONNX_NAMESPACE::TensorProto::DOUBLE) {
     std::vector<double> data(numel, static_cast<double>(value));
     tensor->set_raw_data(std::string((const char*)(data.data()), numel * 8));
@@ -375,7 +384,7 @@ std::string OnnxHelper::Constant(const std::string& output,
   } else {
     Assert(
         false,
-        "Only support data type of BOOL/FLOAT/DOUBLE/INT32/INT64 in Constant "
+        "Only support data type of BOOL/FLOAT/FLOAT16/DOUBLE/INT32/INT64 in Constant "
         "function.");
   }
   nodes.push_back(node);
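Reviewer's note: the new branches rely on `FP32ToFP16`, whose definition is not part of this diff. A minimal sketch of what such a helper typically does — truncating float32 bits down to IEEE 754 half precision. This is an illustration under that assumption, not the PR's actual implementation (which should also round to nearest and preserve NaN):

#include <cstdint>
#include <cstring>

// Illustration only: a truncating float32 -> IEEE 754 half conversion.
static uint16_t FP32ToFP16Sketch(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));                  // type-pun safely
  uint16_t sign = static_cast<uint16_t>((bits >> 16) & 0x8000);
  int32_t exp = static_cast<int32_t>((bits >> 23) & 0xFF) - 127 + 15;
  uint32_t mantissa = bits & 0x7FFFFF;
  if (exp >= 31) return sign | 0x7C00;                       // overflow -> inf
  if (exp <= 0) return sign;                                 // underflow -> signed zero
  return sign | static_cast<uint16_t>(exp << 10) |
         static_cast<uint16_t>(mantissa >> 13);              // drop low mantissa bits
}

int main() {
  return FP32ToFP16Sketch(1.0f) == 0x3C00 ? 0 : 1;           // 1.0 encodes as 0x3C00
}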
8 changes: 7 additions & 1 deletion paddle2onnx/mapper/tensor/fill_constant.cc
@@ -25,9 +25,10 @@ int32_t FillConstantMapper::GetMinOpset(bool verbose) {
   auto onnx_dtype = GetOnnxDtype(out_info[0].dtype);
   if (onnx_dtype != ONNX_NAMESPACE::TensorProto::INT32 &&
       onnx_dtype != ONNX_NAMESPACE::TensorProto::INT64 &&
+      onnx_dtype != ONNX_NAMESPACE::TensorProto::FLOAT16 &&
       onnx_dtype != ONNX_NAMESPACE::TensorProto::FLOAT &&
       onnx_dtype != ONNX_NAMESPACE::TensorProto::DOUBLE) {
-    Error() << "Only support int32/int64/float32/float64 data type in "
+    Error() << "Only support int32/int64/float16/float32/float64 data type in "
                "fill_constant operator."
             << std::endl;
     return -1;
@@ -124,6 +125,11 @@ void FillConstantMapper::Opset9() {
     data[0] = static_cast<int64_t>(value);
     auto ptr = reinterpret_cast<char*>(data.data());
     tensor->set_raw_data(std::string(ptr, sizeof(int64_t)));
+  } else if (onnx_dtype == ONNX_NAMESPACE::TensorProto::FLOAT16) {
+    std::vector<int16_t> data(1);
+    data[0] = FP32ToFP16(value);
+    auto ptr = reinterpret_cast<char*>(data.data());
+    tensor->set_raw_data(std::string(ptr, sizeof(int16_t)));
   } else if (onnx_dtype == ONNX_NAMESPACE::TensorProto::FLOAT) {
     std::vector<float> data(1, value_);
     auto ptr = reinterpret_cast<char*>(data.data());
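An aside for reviewers: this branch stores the half bits in `int16_t`, while the `onnx_helper.h` additions use `uint16_t`. Both are two bytes wide with the same object representation, so the serialized `raw_data` comes out identical either way. A standalone check (illustrative, not part of the PR):

#include <cstdint>
#include <cstring>

int main() {
  uint16_t u = 0x3C00;                        // fp16 bit pattern of 1.0
  int16_t s;
  std::memcpy(&s, &u, sizeof(s));             // same object representation
  char a[2], b[2];
  std::memcpy(a, &u, 2);
  std::memcpy(b, &s, 2);
  return std::memcmp(a, b, 2) == 0 ? 0 : 1;   // exits 0: identical bytes
}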
36 changes: 19 additions & 17 deletions paddle2onnx/mapper/tensor/matmul.cc
@@ -19,11 +19,8 @@ namespace paddle2onnx {
 REGISTER_MAPPER(matmul, MatmulMapper)
 
 std::string MatmulMapper::GetTrans(std::vector<TensorInfo>& input_info) {
-  std::string castd_name = input_info[0].name;
-  if (input_info[0].dtype == P2ODataType::FP64) {
-    castd_name = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                   P2ODataType::FP32);
-  }
+  std::string castd_name = helper_->AutoCast(
+      input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
   std::vector<int64_t> perm = Arange(0, input_info[0].Rank());
   std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);
   auto transpose_node = helper_->MakeNode("Transpose", {castd_name});
@@ -35,26 +32,31 @@ void MatmulMapper::Opset7() {
   auto input_x_info = GetInput("X");
   auto input_y_info = GetInput("Y");
   auto output_info = GetOutput("Out");
-  std::string input_x = input_x_info[0].name;
+
+  // When the data types of the input parameters are inconsistent,
+  // it is necessary to synchronize the data types.
   if (transpose_X_) {
-    input_x = GetTrans(input_x_info);
+    input_x_info[0].name = GetTrans(input_x_info);
+    input_x_info[0].dtype = P2ODataType::FP32;
   }
-  std::string input_y = input_y_info[0].name;
   if (transpose_Y_) {
-    input_y = GetTrans(input_y_info);
+    input_y_info[0].name = GetTrans(input_y_info);
+    input_y_info[0].dtype = P2ODataType::FP32;
   }
+
+  if (input_x_info[0].dtype != input_y_info[0].dtype) {
+    input_y_info[0].name = helper_->AutoCast(
+        input_y_info[0].name, input_y_info[0].dtype, input_x_info[0].dtype);
+  }
+
   if (fabs(alpha_ - 1.0) < 1e-6) {
-    auto node = helper_->MakeNode("MatMul", {input_x, input_y});
-    helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
-                      input_y_info[0].dtype);
+    helper_->MakeNode("MatMul", {input_x_info[0].name, input_y_info[0].name},
+                      {output_info[0].name});
   } else {
-    auto mutmul_node = helper_->MakeNode("MatMul", {input_x, input_y});
+    auto mutmul_node = helper_->MakeNode(
+        "MatMul", {input_x_info[0].name, input_y_info[0].name});
     std::string scale_node =
         helper_->Constant({1}, GetOnnxDtype(input_x_info[0].dtype), alpha_);
-    auto mul_node =
-        helper_->MakeNode("Mul", {mutmul_node->output(0), scale_node});
-    helper_->AutoCast(mul_node->output(0), output_info[0].name,
-                      P2ODataType::FP32, input_y_info[0].dtype);
+    helper_->MakeNode("Mul", {mutmul_node->output(0), scale_node},
+                      {output_info[0].name});
   }
 }
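The new alignment rule is easiest to see in isolation: after `GetTrans` (which now always casts its argument to FP32), X and Y can disagree, so Y is cast to X's element type before the `MatMul` is emitted. A standalone sketch of the rule, with stand-in types and a hypothetical `AutoCast` (illustration, not the PR's API):

#include <iostream>
#include <string>

enum class DType { FP16, FP32, FP64 };

// Stand-in for helper_->AutoCast: returns the name of a tensor of dtype `to`,
// pretending a Cast node was inserted when the types differ.
std::string AutoCast(const std::string& name, DType from, DType to) {
  return from == to ? name : name + "_cast";
}

int main() {
  DType x_dtype = DType::FP16;   // e.g. an fp16 model input
  DType y_dtype = DType::FP32;   // e.g. Y passed through GetTrans
  std::string y = "Y";
  if (x_dtype != y_dtype) {
    y = AutoCast(y, y_dtype, x_dtype);      // mirror the synchronization in Opset7
  }
  std::cout << "MatMul(X, " << y << ")\n";  // -> MatMul(X, Y_cast)
  return 0;
}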
19 changes: 12 additions & 7 deletions paddle2onnx/mapper/tensor/matmul_v2.cc
@@ -35,17 +35,22 @@ void MatmulV2Mapper::Opset7() {
   auto input_x_info = GetInput("X");
   auto input_y_info = GetInput("Y");
   auto output_info = GetOutput("Out");
-  std::string input_x = input_x_info[0].name;
   if (trans_x_) {
-    input_x = GetTrans(input_x_info);
+    input_x_info[0].name = GetTrans(input_x_info);
+    input_x_info[0].dtype = P2ODataType::FP32;
   }
-  std::string input_y = input_y_info[0].name;
   if (trans_y_) {
-    input_y = GetTrans(input_y_info);
+    input_y_info[0].name = GetTrans(input_y_info);
+    input_y_info[0].dtype = P2ODataType::FP32;
   }
-  auto node = helper_->MakeNode("MatMul", {input_x, input_y});
-  helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
-                    input_y_info[0].dtype);
+
+  if (input_x_info[0].dtype != input_y_info[0].dtype) {
+    input_y_info[0].name = helper_->AutoCast(
+        input_y_info[0].name, input_y_info[0].dtype, input_x_info[0].dtype);
+  }
+
+  helper_->MakeNode("MatMul", {input_x_info[0].name, input_y_info[0].name},
+                    {output_info[0].name});
 }
 
 } // namespace paddle2onnx
30 changes: 18 additions & 12 deletions paddle2onnx/mapper/tensor/scale.cc
@@ -29,48 +29,54 @@ void ScaleMapper::Opset7() {
   if (!has_scale_tensor && is_scale_1 && is_bias_0) {
     helper_->MakeNode("Identity", {input_info[0].name}, {output_info[0].name});
   } else {
-    auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
-                                   P2ODataType::FP32);
-    std::string out = input;
+    // Scale supports FP32/FP16, so the scale/bias constants below only need
+    // to be aligned with the input's data type.
+    auto p2o_support_type = P2ODataType::FP32;
+    auto onnx_support_type = ONNX_NAMESPACE::TensorProto::FLOAT;
+    auto input_name = input_info[0].name;
+    if (input_info[0].dtype == P2ODataType::FP16) {
+      p2o_support_type = P2ODataType::FP16;
+      onnx_support_type = ONNX_NAMESPACE::TensorProto::FLOAT16;
+    } else {
+      input_name = helper_->AutoCast(input_name, input_info[0].dtype, p2o_support_type);
+    }
+    std::string out = input_name;
     if (bias_after_scale_) {
       if (!is_scale_1 || HasInput("ScaleTensor")) {
         if (HasInput("ScaleTensor")) {
           auto scale_info = GetInput("ScaleTensor");
-          auto scale = helper_->AutoCast(
-              scale_info[0].name, scale_info[0].dtype, P2ODataType::FP32);
+          auto scale = helper_->AutoCast(scale_info[0].name, scale_info[0].dtype,
+                                         p2o_support_type);
           out = helper_->MakeNode("Mul", {out, scale})->output(0);
         } else {
           auto scale =
-              helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, scale_);
+              helper_->Constant({}, onnx_support_type, scale_);
           out = helper_->MakeNode("Mul", {out, scale})->output(0);
         }
       }
       if (!is_bias_0) {
         auto bias =
-            helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, bias_);
+            helper_->Constant({}, onnx_support_type, bias_);
         out = helper_->MakeNode("Add", {out, bias})->output(0);
       }
     } else {
       if (!is_bias_0) {
         auto bias =
-            helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, bias_);
+            helper_->Constant({}, onnx_support_type, bias_);
         out = helper_->MakeNode("Add", {out, bias})->output(0);
       }
       if (!is_scale_1 || HasInput("ScaleTensor")) {
         if (HasInput("ScaleTensor")) {
           auto scale_info = GetInput("ScaleTensor");
           auto scale = helper_->AutoCast(
-              scale_info[0].name, scale_info[0].dtype, P2ODataType::FP32);
+              scale_info[0].name, scale_info[0].dtype, p2o_support_type);
           out = helper_->MakeNode("Mul", {out, scale})->output(0);
         } else {
           auto scale =
-              helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, scale_);
+              helper_->Constant({}, onnx_support_type, scale_);
           out = helper_->MakeNode("Mul", {out, scale})->output(0);
         }
       }
     }
-    helper_->AutoCast(out, output_info[0].name, P2ODataType::FP32,
-                      output_info[0].dtype);
+    helper_->AutoCast(out, output_info[0].name, p2o_support_type, output_info[0].dtype);
   }
 }
 }  // namespace paddle2onnx
4 changes: 3 additions & 1 deletion paddle2onnx/parser/parser.cc
@@ -815,7 +815,6 @@ void PaddleParser::GetGlobalBlockInputOutputInfo() {
 }
 
 int32_t PaddleDataTypeSize(int32_t paddle_dtype) {
-  Assert(paddle_dtype != FP16, "Float16 is not supported.");
   if (paddle_dtype == P2ODataType::BOOL) {
     return sizeof(bool);
   } else if (paddle_dtype == P2ODataType::INT8) {
@@ -826,6 +825,9 @@ int32_t PaddleDataTypeSize(int32_t paddle_dtype) {
     return sizeof(int32_t);
  } else if (paddle_dtype == P2ODataType::INT64) {
     return sizeof(int64_t);
+  } else if (paddle_dtype == P2ODataType::FP16) {
+    // C++ does not have native support for FP16, so int16 is used instead.
+    return sizeof(int16_t);
   } else if (paddle_dtype == P2ODataType::FP32) {
     return sizeof(float);
   } else if (paddle_dtype == P2ODataType::FP64) {
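Since `PaddleDataTypeSize` now reports 2 bytes for FP16, buffers can be sized with any 16-bit POD type; the element stride is all that matters, and the half payload is copied byte-for-byte. A small illustration (not from the PR; the bit patterns are standard IEEE 754 half encodings):

#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  // fp16 little-endian payload for {1.0, -2.0}: 0x3C00 and 0xC000.
  const unsigned char raw[] = {0x00, 0x3C, 0x00, 0xC0};
  std::vector<uint16_t> buf(sizeof(raw) / sizeof(uint16_t));  // stride = 2 bytes
  std::memcpy(buf.data(), raw, sizeof(raw));                  // copy 2 elements
  return (buf[0] == 0x3C00 && buf[1] == 0xC000) ? 0 : 1;      // exits 0 on little-endian hosts
}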