diff --git a/CMakeLists.txt b/CMakeLists.txt
index e308573ba5..5aef0b56ed 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -54,6 +54,23 @@ option(ENABLE_DEBUG "if to enable print debug information, this may reduce perfo
 # Whether to build fastdeploy with vision/text/... examples, only for testings.
 option(WITH_VISION_EXAMPLES "Whether to build fastdeploy with vision examples" OFF)
 
+# Check for 32-bit system
+if(WIN32)
+  if(NOT CMAKE_CL_64)
+    message("***********************Compiling on non 64-bit system now**********************")
+    add_definitions(-DNON_64_PLATFORM)
+    if(WITH_GPU)
+      message(FATAL_ERROR "-DWITH_GPU=ON is not supported on non 64-bit systems now.")
+    endif()
+    if(ENABLE_PADDLE_BACKEND)
+      message(FATAL_ERROR "-DENABLE_PADDLE_BACKEND=ON is not supported on non 64-bit systems now.")
+    endif()
+    if(ENABLE_VISION)
+      message(FATAL_ERROR "-DENABLE_VISION=ON is not supported on non 64-bit systems now.")
+    endif()
+  endif()
+endif()
+
 if(ENABLE_DEBUG)
   add_definitions(-DFASTDEPLOY_DEBUG)
 endif()
diff --git a/csrcs/fastdeploy/backends/ort/ops/multiclass_nms.cc b/csrcs/fastdeploy/backends/ort/ops/multiclass_nms.cc
index 6f9f8f2a7c..a132dbffc3 100644
--- a/csrcs/fastdeploy/backends/ort/ops/multiclass_nms.cc
+++ b/csrcs/fastdeploy/backends/ort/ops/multiclass_nms.cc
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#ifndef NON_64_PLATFORM
+
 #include "fastdeploy/backends/ort/ops/multiclass_nms.h"
 #include
 #include "fastdeploy/core/fd_tensor.h"
@@ -255,3 +257,5 @@ void MultiClassNmsKernel::GetAttribute(const OrtKernelInfo* info) {
   score_threshold = ort_.KernelInfoGetAttribute<float>(info, "score_threshold");
 }
 }  // namespace fastdeploy
+
+#endif
\ No newline at end of file
diff --git a/csrcs/fastdeploy/backends/ort/ops/multiclass_nms.h b/csrcs/fastdeploy/backends/ort/ops/multiclass_nms.h
index 78f9a22557..4e167d669b 100644
--- a/csrcs/fastdeploy/backends/ort/ops/multiclass_nms.h
+++ b/csrcs/fastdeploy/backends/ort/ops/multiclass_nms.h
@@ -13,7 +13,10 @@
 // limitations under the License.
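+// This header is compiled out on 32-bit builds via the NON_64_PLATFORM
+// guard added in CMakeLists.txt above.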
 #pragma once
+
 #include
+
+#ifndef NON_64_PLATFORM
 #include "onnxruntime_cxx_api.h"  // NOLINT
 
 namespace fastdeploy {
@@ -74,3 +77,5 @@ struct MultiClassNmsOp
 };
 
 }  // namespace fastdeploy
+
+#endif
\ No newline at end of file
diff --git a/csrcs/fastdeploy/backends/ort/ort_backend.cc b/csrcs/fastdeploy/backends/ort/ort_backend.cc
index 27c746a9ea..9fdb3c66b7 100644
--- a/csrcs/fastdeploy/backends/ort/ort_backend.cc
+++ b/csrcs/fastdeploy/backends/ort/ort_backend.cc
@@ -292,6 +292,7 @@ TensorInfo OrtBackend::GetOutputInfo(int index) {
 }
 
 void OrtBackend::InitCustomOperators() {
+#ifndef NON_64_PLATFORM
   if (custom_operators_.size() == 0) {
     MultiClassNmsOp* custom_op = new MultiClassNmsOp{};
     custom_operators_.push_back(custom_op);
@@ -300,6 +301,7 @@ void OrtBackend::InitCustomOperators() {
     custom_op_domain_.Add(custom_operators_[i]);
   }
   session_options_.Add(custom_op_domain_);
+#endif
 }
 
 }  // namespace fastdeploy
diff --git a/csrcs/fastdeploy/backends/ort/ort_backend.h b/csrcs/fastdeploy/backends/ort/ort_backend.h
index 2dab030233..6d87114733 100644
--- a/csrcs/fastdeploy/backends/ort/ort_backend.h
+++ b/csrcs/fastdeploy/backends/ort/ort_backend.h
@@ -82,7 +82,9 @@ class OrtBackend : public BaseBackend {
   std::shared_ptr binding_;
   std::vector inputs_desc_;
   std::vector outputs_desc_;
+#ifndef NON_64_PLATFORM
   Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
+#endif
   OrtBackendOption option_;
   void CopyToCpu(const Ort::Value& value, FDTensor* tensor);
 };
diff --git a/csrcs/fastdeploy/vision.h b/csrcs/fastdeploy/vision.h
index 781bf01a2f..3b765c387e 100644
--- a/csrcs/fastdeploy/vision.h
+++ b/csrcs/fastdeploy/vision.h
@@ -15,6 +15,7 @@
 #include "fastdeploy/core/config.h"
 
 #ifdef ENABLE_VISION
+#include "fastdeploy/vision/biubug6/retinaface.h"
 #include "fastdeploy/vision/deepcam/yolov5face.h"
 #include "fastdeploy/vision/linzaer/ultraface.h"
 #include "fastdeploy/vision/megvii/yolox.h"
diff --git a/csrcs/fastdeploy/vision/biubug6/biubug6_pybind.cc b/csrcs/fastdeploy/vision/biubug6/biubug6_pybind.cc
new file mode 100644
index 0000000000..78208280ac
--- /dev/null
+++ b/csrcs/fastdeploy/vision/biubug6/biubug6_pybind.cc
@@ -0,0 +1,40 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
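+// Python bindings for biubug6/Pytorch_Retinaface (vision::biubug6::RetinaFace).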
+
+#include "fastdeploy/pybind/main.h"
+
+namespace fastdeploy {
+void BindBiubug6(pybind11::module& m) {
+  auto biubug6_module = m.def_submodule(
+      "biubug6", "https://github.com/biubug6/Pytorch_Retinaface");
+  pybind11::class_<vision::biubug6::RetinaFace, FastDeployModel>(
+      biubug6_module, "RetinaFace")
+      .def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
+      .def("predict",
+           [](vision::biubug6::RetinaFace& self, pybind11::array& data,
+              float conf_threshold, float nms_iou_threshold) {
+             auto mat = PyArrayToCvMat(data);
+             vision::FaceDetectionResult res;
+             self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
+             return res;
+           })
+      .def_readwrite("size", &vision::biubug6::RetinaFace::size)
+      .def_readwrite("variance", &vision::biubug6::RetinaFace::variance)
+      .def_readwrite("downsample_strides",
+                     &vision::biubug6::RetinaFace::downsample_strides)
+      .def_readwrite("min_sizes", &vision::biubug6::RetinaFace::min_sizes)
+      .def_readwrite("landmarks_per_face",
+                     &vision::biubug6::RetinaFace::landmarks_per_face);
+}
+}  // namespace fastdeploy
diff --git a/csrcs/fastdeploy/vision/biubug6/retinaface.cc b/csrcs/fastdeploy/vision/biubug6/retinaface.cc
new file mode 100644
index 0000000000..2ba1a788e9
--- /dev/null
+++ b/csrcs/fastdeploy/vision/biubug6/retinaface.cc
@@ -0,0 +1,310 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
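+// RetinaFace: prior-box (anchor) generation, preprocessing, and decoding of
+// boxes/landmarks followed by NMS.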
+
+#include "fastdeploy/vision/biubug6/retinaface.h"
+#include "fastdeploy/utils/perf.h"
+#include "fastdeploy/vision/utils/utils.h"
+
+namespace fastdeploy {
+
+namespace vision {
+
+namespace biubug6 {
+
+struct RetinaAnchor {
+  float cx;
+  float cy;
+  float s_kx;
+  float s_ky;
+};
+
+void GenerateRetinaAnchors(const std::vector<int>& size,
+                           const std::vector<int>& downsample_strides,
+                           const std::vector<std::vector<int>>& min_sizes,
+                           std::vector<RetinaAnchor>* anchors) {
+  // size: tuple of input (width, height)
+  // downsample_strides: downsample strides (steps), e.g (8,16,32)
+  // min_sizes: width and height for each anchor,
+  // e.g {{16, 32}, {64, 128}, {256, 512}}
+  int h = size[1];
+  int w = size[0];
+  std::vector<std::vector<int>> feature_maps;
+  for (auto s : downsample_strides) {
+    feature_maps.push_back(
+        {static_cast<int>(
+             std::ceil(static_cast<float>(h) / static_cast<float>(s))),
+         static_cast<int>(
+             std::ceil(static_cast<float>(w) / static_cast<float>(s)))});
+  }
+
+  (*anchors).clear();
+  const size_t num_feature_map = feature_maps.size();
+  // reference: layers/functions/prior_box.py#L21
+  for (size_t k = 0; k < num_feature_map; ++k) {
+    auto f_map = feature_maps.at(k);       // e.g [640//8,640//8]
+    auto tmp_min_sizes = min_sizes.at(k);  // e.g [16,32]
+    int f_h = f_map.at(0);
+    int f_w = f_map.at(1);
+    for (size_t i = 0; i < f_h; ++i) {
+      for (size_t j = 0; j < f_w; ++j) {
+        for (auto min_size : tmp_min_sizes) {
+          float s_kx =
+              static_cast<float>(min_size) / static_cast<float>(w);  // e.g 16/w
+          float s_ky =
+              static_cast<float>(min_size) / static_cast<float>(h);  // e.g 16/h
+          // (x + 0.5) * step / w normalized loc mapping to input width
+          // (y + 0.5) * step / h normalized loc mapping to input height
+          float s = static_cast<float>(downsample_strides.at(k));
+          float cx = (static_cast<float>(j) + 0.5f) * s / static_cast<float>(w);
+          float cy = (static_cast<float>(i) + 0.5f) * s / static_cast<float>(h);
+          (*anchors).emplace_back(
+              RetinaAnchor{cx, cy, s_kx, s_ky});  // without clip
+        }
+      }
+    }
+  }
+}
+
+RetinaFace::RetinaFace(const std::string& model_file,
+                       const std::string& params_file,
+                       const RuntimeOption& custom_option,
+                       const Frontend& model_format) {
+  if (model_format == Frontend::ONNX) {
+    valid_cpu_backends = {Backend::ORT};  // specify available CPU backends
+    valid_gpu_backends = {Backend::ORT,
+                          Backend::TRT};  // specify available GPU backends
+  } else {
+    valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
+    valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
+  }
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.model_file = model_file;
+  runtime_option.params_file = params_file;
+  initialized = Initialize();
+}
+
+bool RetinaFace::Initialize() {
+  // parameters for preprocess
+  size = {640, 640};
+  variance = {0.1f, 0.2f};
+  downsample_strides = {8, 16, 32};
+  min_sizes = {{16, 32}, {64, 128}, {256, 512}};
+  landmarks_per_face = 5;
+
+  if (!InitRuntime()) {
+    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+    return false;
+  }
+  // Check if the input shape is dynamic after the Runtime is initialized.
+  is_dynamic_input_ = false;
+  auto shape = InputInfoOfRuntime(0).shape;
+  for (int i = 0; i < shape.size(); ++i) {
+    // if height or width is dynamic
+    if (i >= 2 && shape[i] <= 0) {
+      is_dynamic_input_ = true;
+      break;
+    }
+  }
+  return true;
+}
+
+bool RetinaFace::Preprocess(
+    Mat* mat, FDTensor* output,
+    std::map<std::string, std::array<float, 2>>* im_info) {
+  // retinaface's preprocess steps
+  // 1. Resize
+  // 2. Convert(opencv style) or Normalize
+  // 3. HWC->CHW
+  int resize_w = size[0];
+  int resize_h = size[1];
+  if (resize_h != mat->Height() || resize_w != mat->Width()) {
+    Resize::Run(mat, resize_w, resize_h);
+  }
+
+  // Compute `result = mat * alpha + beta` directly by channel
+  // Reference: detect.py#L94
+  std::vector<float> alpha = {1.f, 1.f, 1.f};
+  std::vector<float> beta = {-104.f, -117.f, -123.f};  // BGR
+  Convert::Run(mat, alpha, beta);
+
+  // Record output shape of preprocessed image
+  (*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
+                                static_cast<float>(mat->Width())};
+
+  HWC2CHW::Run(mat);
+  Cast::Run(mat, "float");
+  mat->ShareWithTensor(output);
+  output->shape.insert(output->shape.begin(), 1);  // reshape to n, c, h, w
+  return true;
+}
+
+bool RetinaFace::Postprocess(
+    std::vector<FDTensor>& infer_result, FaceDetectionResult* result,
+    const std::map<std::string, std::array<float, 2>>& im_info,
+    float conf_threshold, float nms_iou_threshold) {
+  // retinaface has 3 output tensors, boxes & conf & landmarks
+  FDASSERT(
+      (infer_result.size() == 3),
+      "The default number of output tensors must be 3 according to retinaface.");
+  FDTensor& boxes_tensor = infer_result.at(0);      // (1,n,4)
+  FDTensor& conf_tensor = infer_result.at(1);       // (1,n,2)
+  FDTensor& landmarks_tensor = infer_result.at(2);  // (1,n,10)
+  FDASSERT((boxes_tensor.shape[0] == 1), "Only support batch = 1 now.");
+  if (boxes_tensor.dtype != FDDataType::FP32) {
+    FDERROR << "Only support post process with float32 data." << std::endl;
+    return false;
+  }
+
+  result->Clear();
+  // landmarks_per_face must be set before calling Reserve
+  result->landmarks_per_face = landmarks_per_face;
+  result->Reserve(boxes_tensor.shape[1]);
+
+  float* boxes_ptr = static_cast<float*>(boxes_tensor.Data());
+  float* conf_ptr = static_cast<float*>(conf_tensor.Data());
+  float* landmarks_ptr = static_cast<float*>(landmarks_tensor.Data());
+  const size_t num_bboxes = boxes_tensor.shape[1];  // n
+  // fetch original image shape
+  auto iter_ipt = im_info.find("input_shape");
+  FDASSERT((iter_ipt != im_info.end()),
+           "Cannot find input_shape from im_info.");
+  float ipt_h = iter_ipt->second[0];
+  float ipt_w = iter_ipt->second[1];
+
+  // generate anchors with downsample strides
+  std::vector<RetinaAnchor> anchors;
+  GenerateRetinaAnchors(size, downsample_strides, min_sizes, &anchors);
+
+  // decode bounding boxes
+  for (size_t i = 0; i < num_bboxes; ++i) {
+    float confidence = conf_ptr[2 * i + 1];
+    // filter boxes by conf_threshold
+    if (confidence <= conf_threshold) {
+      continue;
+    }
+    float prior_cx = anchors.at(i).cx;
+    float prior_cy = anchors.at(i).cy;
+    float prior_s_kx = anchors.at(i).s_kx;
+    float prior_s_ky = anchors.at(i).s_ky;
+
+    // fetch offsets (dx,dy,dw,dh)
+    float dx = boxes_ptr[4 * i + 0];
+    float dy = boxes_ptr[4 * i + 1];
+    float dw = boxes_ptr[4 * i + 2];
+    float dh = boxes_ptr[4 * i + 3];
+    // reference: Pytorch_Retinaface/utils/box_utils.py
+    float x = prior_cx + dx * variance[0] * prior_s_kx;
+    float y = prior_cy + dy * variance[0] * prior_s_ky;
+    float w = prior_s_kx * std::exp(dw * variance[1]);
+    float h = prior_s_ky * std::exp(dh * variance[1]);  // (0.~1.)
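+    // Decode recap, following the prior-box formulation above:
+    //   center: c' = c + d * variance[0] * s_k
+    //   size:   s' = s_k * exp(d * variance[1])
+    // All values stay normalized to [0,1] until scaled by ipt_w/ipt_h below.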
+    // from (x,y,w,h) to (x1,y1,x2,y2)
+    float x1 = (x - w / 2.f) * ipt_w;
+    float y1 = (y - h / 2.f) * ipt_h;
+    float x2 = (x + w / 2.f) * ipt_w;
+    float y2 = (y + h / 2.f) * ipt_h;
+    result->boxes.emplace_back(std::array<float, 4>{x1, y1, x2, y2});
+    result->scores.push_back(confidence);
+    // decode landmarks (default 5 landmarks)
+    if (landmarks_per_face > 0) {
+      // reference: utils/box_utils.py#L241
+      for (size_t j = 0; j < landmarks_per_face * 2; j += 2) {
+        float ldx = landmarks_ptr[i * (landmarks_per_face * 2) + (j + 0)];
+        float ldy = landmarks_ptr[i * (landmarks_per_face * 2) + (j + 1)];
+        float lx = (prior_cx + ldx * variance[0] * prior_s_kx) * ipt_w;
+        float ly = (prior_cy + ldy * variance[0] * prior_s_ky) * ipt_h;
+        result->landmarks.emplace_back(std::array<float, 2>{lx, ly});
+      }
+    }
+  }
+
+  if (result->boxes.size() == 0) {
+    return true;
+  }
+
+  utils::NMS(result, nms_iou_threshold);
+
+  // clip boxes to the image size
+  for (size_t i = 0; i < result->boxes.size(); ++i) {
+    result->boxes[i][0] = std::max(result->boxes[i][0], 0.0f);
+    result->boxes[i][1] = std::max(result->boxes[i][1], 0.0f);
+    result->boxes[i][2] = std::max(result->boxes[i][2], 0.0f);
+    result->boxes[i][3] = std::max(result->boxes[i][3], 0.0f);
+    result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f);
+    result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f);
+    result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f);
+    result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f);
+  }
+  // clip landmarks to the image size
+  for (size_t i = 0; i < result->landmarks.size(); ++i) {
+    result->landmarks[i][0] = std::max(result->landmarks[i][0], 0.0f);
+    result->landmarks[i][1] = std::max(result->landmarks[i][1], 0.0f);
+    result->landmarks[i][0] = std::min(result->landmarks[i][0], ipt_w - 1.0f);
+    result->landmarks[i][1] = std::min(result->landmarks[i][1], ipt_h - 1.0f);
+  }
+  return true;
+}
+
+bool RetinaFace::Predict(cv::Mat* im, FaceDetectionResult* result,
+                         float conf_threshold, float nms_iou_threshold) {
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_START(0)
+#endif
+
+  Mat mat(*im);
+  std::vector<FDTensor> input_tensors(1);
+
+  std::map<std::string, std::array<float, 2>> im_info;
+
+  // Record the shape of image and the shape of preprocessed image
+  im_info["input_shape"] = {static_cast<float>(mat.Height()),
+                            static_cast<float>(mat.Width())};
+  im_info["output_shape"] = {static_cast<float>(mat.Height()),
+                             static_cast<float>(mat.Width())};
+
+  if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
+    FDERROR << "Failed to preprocess input image." << std::endl;
+    return false;
+  }
+
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_END(0, "Preprocess")
+  TIMERECORD_START(1)
+#endif
+
+  input_tensors[0].name = InputInfoOfRuntime(0).name;
+  std::vector<FDTensor> output_tensors;
+  if (!Infer(input_tensors, &output_tensors)) {
+    FDERROR << "Failed to run inference." << std::endl;
+    return false;
+  }
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_END(1, "Inference")
+  TIMERECORD_START(2)
+#endif
+
+  if (!Postprocess(output_tensors, result, im_info, conf_threshold,
+                   nms_iou_threshold)) {
+    FDERROR << "Failed to postprocess." << std::endl;
+    return false;
+  }
+
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_END(2, "Postprocess")
+#endif
+  return true;
+}
+
+}  // namespace biubug6
+}  // namespace vision
+}  // namespace fastdeploy
\ No newline at end of file
diff --git a/csrcs/fastdeploy/vision/biubug6/retinaface.h b/csrcs/fastdeploy/vision/biubug6/retinaface.h
new file mode 100644
index 0000000000..d16942f13b
--- /dev/null
+++ b/csrcs/fastdeploy/vision/biubug6/retinaface.h
@@ -0,0 +1,92 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+namespace fastdeploy {
+
+namespace vision {
+
+namespace biubug6 {
+
+class FASTDEPLOY_DECL RetinaFace : public FastDeployModel {
+ public:
+  // When model_format is ONNX, params_file is not required;
+  // when model_format is Paddle, both model_file and params_file are required
+  RetinaFace(const std::string& model_file, const std::string& params_file = "",
+             const RuntimeOption& custom_option = RuntimeOption(),
+             const Frontend& model_format = Frontend::ONNX);
+
+  // Name of the model
+  std::string ModelName() const { return "biubug6/Pytorch_Retinaface"; }
+
+  // Prediction interface, i.e. the API called by users
+  // im: input data; currently defined as cv::Mat for all CV tasks
+  // result: output struct holding the prediction
+  // conf_threshold: postprocessing parameter
+  // nms_iou_threshold: postprocessing parameter
+  virtual bool Predict(cv::Mat* im, FaceDetectionResult* result,
+                       float conf_threshold = 0.25f,
+                       float nms_iou_threshold = 0.4f);
+
+  // The following are parameters used during prediction, mostly for
+  // pre/post-processing. After creating the model, users may adjust them
+  // according to the model's requirements and their own needs.
+  // tuple of (width, height), default (640, 640)
+  std::vector<int> size;
+  // variance in RetinaFace's prior-box(anchor) generate process,
+  // default (0.1, 0.2)
+  std::vector<float> variance;
+  // downsample strides (namely, steps) for RetinaFace to
+  // generate anchors, will take (8,16,32) as default values.
+  std::vector<int> downsample_strides;
+  // min sizes, width and height for each anchor.
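+  // e.g {{16, 32}, {64, 128}, {256, 512}} (default), one group per stride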
+  std::vector<std::vector<int>> min_sizes;
+  // landmarks_per_face, default 5 in RetinaFace
+  int landmarks_per_face;
+
+ private:
+  // Initialization, including backend setup and any other operations
+  // required for inference
+  bool Initialize();
+
+  // Preprocessing of the input image
+  // Mat is the data structure defined by FastDeploy
+  // FDTensor is the preprocessed tensor passed to the backend for inference
+  // im_info holds data recorded during preprocessing, needed by postprocessing
+  bool Preprocess(Mat* mat, FDTensor* output,
+                  std::map<std::string, std::array<float, 2>>* im_info);
+
+  // Postprocessing of the backend inference results, returned to users
+  // infer_result: output tensors from the backend
+  // result: prediction result of the model
+  // im_info: information recorded during preprocessing, used to restore boxes
+  // conf_threshold: confidence threshold used to filter boxes
+  // nms_iou_threshold: IoU threshold used by NMS
+  bool Postprocess(std::vector<FDTensor>& infer_result,
+                   FaceDetectionResult* result,
+                   const std::map<std::string, std::array<float, 2>>& im_info,
+                   float conf_threshold, float nms_iou_threshold);
+
+  // Check whether the input has dynamic dimensions; not recommended for
+  // direct use, since the logic may differ across models
+  bool IsDynamicInput() const { return is_dynamic_input_; }
+
+  bool is_dynamic_input_;
+};
+
+}  // namespace biubug6
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/csrcs/fastdeploy/vision/deepcam/yolov5face.cc b/csrcs/fastdeploy/vision/deepcam/yolov5face.cc
index 5b2c77af9d..599821b4cd 100644
--- a/csrcs/fastdeploy/vision/deepcam/yolov5face.cc
+++ b/csrcs/fastdeploy/vision/deepcam/yolov5face.cc
@@ -155,14 +155,16 @@ bool YOLOv5Face::Postprocess(
     float conf_threshold, float nms_iou_threshold) {
   // infer_result: (1,n,16) 16=4+1+10+1
   FDASSERT(infer_result.shape[0] == 1, "Only support batch =1 now.");
-  result->Clear();
-  // must be setup landmarks_per_face before reserve
-  result->landmarks_per_face = landmarks_per_face;
-  result->Reserve(infer_result.shape[1]);
   if (infer_result.dtype != FDDataType::FP32) {
     FDERROR << "Only support post process with float32 data." << std::endl;
     return false;
   }
+
+  result->Clear();
+  // must be setup landmarks_per_face before reserve
+  result->landmarks_per_face = landmarks_per_face;
+  result->Reserve(infer_result.shape[1]);
+
   float* data = static_cast<float*>(infer_result.Data());
   for (size_t i = 0; i < infer_result.shape[1]; ++i) {
     float* reg_cls_ptr = data + (i * infer_result.shape[2]);
diff --git a/csrcs/fastdeploy/vision/linzaer/ultraface.cc b/csrcs/fastdeploy/vision/linzaer/ultraface.cc
index e9148604f2..7c35059c60 100644
--- a/csrcs/fastdeploy/vision/linzaer/ultraface.cc
+++ b/csrcs/fastdeploy/vision/linzaer/ultraface.cc
@@ -106,11 +106,6 @@ bool UltraFace::Postprocess(
   FDTensor& boxes_tensor = infer_result.at(1);  // (1,4420,4)
   FDASSERT((scores_tensor.shape[0] == 1), "Only support batch =1 now.");
   FDASSERT((boxes_tensor.shape[0] == 1), "Only support batch =1 now.");
-
-  result->Clear();
-  // must be setup landmarks_per_face before reserve.
-  // ultraface detector does not detect landmarks by default.
-  result->landmarks_per_face = 0;
   if (scores_tensor.dtype != FDDataType::FP32) {
     FDERROR << "Only support post process with float32 data." << std::endl;
     return false;
@@ -120,6 +115,12 @@ bool UltraFace::Postprocess(
     return false;
   }
 
+  result->Clear();
+  // must be setup landmarks_per_face before reserve.
+  // ultraface detector does not detect landmarks by default.
+  result->landmarks_per_face = 0;
+  result->Reserve(boxes_tensor.shape[1]);
+
   float* scores_ptr = static_cast<float*>(scores_tensor.Data());
   float* boxes_ptr = static_cast<float*>(boxes_tensor.Data());
   const size_t num_bboxes = boxes_tensor.shape[1];  // e.g 4420
diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc
index 08c4073f29..9c698976e0 100644
--- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc
+++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc
@@ -130,8 +130,8 @@ bool PPYOLOE::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
   (*outputs)[1].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(1).name);
   float* ptr = static_cast<float*>((*outputs)[1].MutableData());
-  ptr[0] = mat->Height() * 1.0 / mat->Height();
-  ptr[1] = mat->Width() * 1.0 / mat->Width();
+  ptr[0] = mat->Height() * 1.0 / origin_h;
+  ptr[1] = mat->Width() * 1.0 / origin_w;
   return true;
 }
@@ -176,8 +176,7 @@ bool PPYOLOE::Postprocess(std::vector<FDTensor>& infer_result,
       result->scores.push_back(nms.out_box_data[i * 6 + 1]);
       result->boxes.emplace_back(std::array<float, 4>{
           nms.out_box_data[i * 6 + 2], nms.out_box_data[i * 6 + 3],
-          nms.out_box_data[i * 6 + 4] - nms.out_box_data[i * 6 + 2],
-          nms.out_box_data[i * 6 + 5] - nms.out_box_data[i * 6 + 3]});
+          nms.out_box_data[i * 6 + 4], nms.out_box_data[i * 6 + 5]});
     }
   } else {
     int box_num = 0;
@@ -197,8 +196,7 @@ bool PPYOLOE::Postprocess(std::vector<FDTensor>& infer_result,
       result->scores.push_back(box_data[i * 6 + 1]);
       result->boxes.emplace_back(
           std::array<float, 4>{box_data[i * 6 + 2], box_data[i * 6 + 3],
-                               box_data[i * 6 + 4] - box_data[i * 6 + 2],
-                               box_data[i * 6 + 5] - box_data[i * 6 + 3]});
+                               box_data[i * 6 + 4], box_data[i * 6 + 5]});
     }
   }
   return true;
diff --git a/csrcs/fastdeploy/vision/vision_pybind.cc b/csrcs/fastdeploy/vision/vision_pybind.cc
index 87c5ce6d84..18b662e68e 100644
--- a/csrcs/fastdeploy/vision/vision_pybind.cc
+++ b/csrcs/fastdeploy/vision/vision_pybind.cc
@@ -26,6 +26,7 @@ void BindMegvii(pybind11::module& m);
 void BindDeepCam(pybind11::module& m);
 void BindRangiLyu(pybind11::module& m);
 void BindLinzaer(pybind11::module& m);
+void BindBiubug6(pybind11::module& m);
 #ifdef ENABLE_VISION_VISUALIZE
 void BindVisualize(pybind11::module& m);
 #endif
@@ -71,6 +72,7 @@ void BindVision(pybind11::module& m) {
   BindDeepCam(m);
   BindRangiLyu(m);
   BindLinzaer(m);
+  BindBiubug6(m);
 #ifdef ENABLE_VISION_VISUALIZE
   BindVisualize(m);
 #endif
diff --git a/examples/vision/biubug6_retinaface.cc b/examples/vision/biubug6_retinaface.cc
new file mode 100644
index 0000000000..65a396ff9b
--- /dev/null
+++ b/examples/vision/biubug6_retinaface.cc
@@ -0,0 +1,55 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
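+// Example: load the biubug6 RetinaFace model, run face detection on one image,
+// and save a visualization of the result.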
+
+#include "fastdeploy/vision.h"
+
+int main() {
+  namespace vis = fastdeploy::vision;
+
+  std::string model_file =
+      "../resources/models/Pytorch_RetinaFace_resnet50-720-1080.onnx";
+  std::string img_path = "../resources/images/test_face_det.jpg";
+  std::string vis_path =
+      "../resources/outputs/biubug6_retinaface_vis_result.jpg";
+
+  auto model = vis::biubug6::RetinaFace(model_file);
+  model.size = {1080, 720};  // (width, height)
+  if (!model.Initialized()) {
+    std::cerr << "Init Failed! Model: " << model_file << std::endl;
+    return -1;
+  } else {
+    std::cout << "Init Done! Model: " << model_file << std::endl;
+  }
+  model.EnableDebug();
+
+  cv::Mat im = cv::imread(img_path);
+  cv::Mat vis_im = im.clone();
+
+  vis::FaceDetectionResult res;
+  if (!model.Predict(&im, &res, 0.3f, 0.3f)) {
+    std::cerr << "Prediction Failed." << std::endl;
+    return -1;
+  } else {
+    std::cout << "Prediction Done!" << std::endl;
+  }
+
+  // print the detection results
+  std::cout << res.Str() << std::endl;
+
+  // visualize the prediction results
+  vis::Visualize::VisFaceDetection(&vis_im, res, 2, 0.3f);
+  cv::imwrite(vis_path, vis_im);
+  std::cout << "Detect Done! Saved: " << vis_path << std::endl;
+  return 0;
+}
diff --git a/external/onnxruntime.cmake b/external/onnxruntime.cmake
index da2ce43684..241fc1aa2e 100644
--- a/external/onnxruntime.cmake
+++ b/external/onnxruntime.cmake
@@ -36,6 +36,9 @@ if(WIN32)
   else()
     set(ONNXRUNTIME_FILENAME "onnxruntime-win-x64-${ONNXRUNTIME_VERSION}.zip")
   endif()
+  if(NOT CMAKE_CL_64)
+    set(ONNXRUNTIME_FILENAME "onnxruntime-win-x86-${ONNXRUNTIME_VERSION}.zip")
+  endif()
 elseif(APPLE)
   if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "arm64")
     set(ONNXRUNTIME_FILENAME "onnxruntime-osx-arm64-${ONNXRUNTIME_VERSION}.tgz")
diff --git a/external/paddle2onnx.cmake b/external/paddle2onnx.cmake
index 97ba169ac3..e226bc6c95 100644
--- a/external/paddle2onnx.cmake
+++ b/external/paddle2onnx.cmake
@@ -46,6 +46,9 @@ set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/paddle2onnx/libs/")
 set(PADDLE2ONNX_VERSION "1.0.0rc2")
 if(WIN32)
   set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip")
+  if(NOT CMAKE_CL_64)
+    set(PADDLE2ONNX_FILE "paddle2onnx-win-x86-${PADDLE2ONNX_VERSION}.zip")
+  endif()
 elseif(APPLE)
   if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "arm64")
     set(PADDLE2ONNX_FILE "paddle2onnx-osx-arm64-${PADDLE2ONNX_VERSION}.tgz")
diff --git a/fastdeploy/vision/__init__.py b/fastdeploy/vision/__init__.py
index 6387f5e39b..a362029832 100644
--- a/fastdeploy/vision/__init__.py
+++ b/fastdeploy/vision/__init__.py
@@ -25,3 +25,4 @@
 from . import deepcam
 from . import rangilyu
 from . import linzaer
+from . import biubug6
diff --git a/fastdeploy/vision/biubug6/__init__.py b/fastdeploy/vision/biubug6/__init__.py
new file mode 100644
index 0000000000..c3772a47e3
--- /dev/null
+++ b/fastdeploy/vision/biubug6/__init__.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+from ... import FastDeployModel, Frontend
+from ... import fastdeploy_main as C
+
+
+class RetinaFace(FastDeployModel):
+    def __init__(self,
+                 model_file,
+                 params_file="",
+                 runtime_option=None,
+                 model_format=Frontend.ONNX):
+        # Call the base class to initialize the backend options;
+        # the initialized options are stored in self._runtime_option
+        super(RetinaFace, self).__init__(runtime_option)
+
+        self._model = C.vision.biubug6.RetinaFace(
+            model_file, params_file, self._runtime_option, model_format)
+        # self.initialized indicates whether the whole model initialized successfully
+        assert self.initialized, "RetinaFace initialize failed."
+
+    def predict(self, input_image, conf_threshold=0.7, nms_iou_threshold=0.3):
+        return self._model.predict(input_image, conf_threshold,
+                                   nms_iou_threshold)
+
+    # Wrappers for properties of the RetinaFace model, mostly preprocessing
+    # related; e.g. model.size = [640, 480] changes the resize target used in
+    # preprocessing (provided the model supports it)
+    @property
+    def size(self):
+        return self._model.size
+
+    @property
+    def variance(self):
+        return self._model.variance
+
+    @property
+    def downsample_strides(self):
+        return self._model.downsample_strides
+
+    @property
+    def min_sizes(self):
+        return self._model.min_sizes
+
+    @property
+    def landmarks_per_face(self):
+        return self._model.landmarks_per_face
+
+    @size.setter
+    def size(self, wh):
+        assert isinstance(wh, (list, tuple)),\
+            "The value to set `size` must be type of tuple or list."
+        assert len(wh) == 2,\
+            "The value to set `size` must contain 2 elements, i.e. [width, height], but now it contains {} elements.".format(
+                len(wh))
+        self._model.size = wh
+
+    @variance.setter
+    def variance(self, value):
+        assert isinstance(value, (list, tuple)),\
+            "The value to set `variance` must be type of tuple or list."
+        assert len(value) == 2,\
+            "The value to set `variance` must contain 2 elements, but now it contains {} elements.".format(
+                len(value))
+        self._model.variance = value
+
+    @downsample_strides.setter
+    def downsample_strides(self, value):
+        assert isinstance(
+            value,
+            list), "The value to set `downsample_strides` must be type of list."
+        self._model.downsample_strides = value
+
+    @min_sizes.setter
+    def min_sizes(self, value):
+        assert isinstance(
+            value, list), "The value to set `min_sizes` must be type of list."
+        self._model.min_sizes = value
+
+    @landmarks_per_face.setter
+    def landmarks_per_face(self, value):
+        assert isinstance(
+            value,
+            int), "The value to set `landmarks_per_face` must be type of int."
+        self._model.landmarks_per_face = value
diff --git a/model_zoo/vision/ppyoloe/ppyoloe.py b/model_zoo/vision/ppyoloe/ppyoloe.py
index 7d79dfd8cf..a3b12c1dc6 100644
--- a/model_zoo/vision/ppyoloe/ppyoloe.py
+++ b/model_zoo/vision/ppyoloe/ppyoloe.py
@@ -14,7 +14,7 @@
 # predict on the image
 im = cv2.imread("000000014439_640x640.jpg")
-result = model.predict(im, conf_threshold=0.5)
+result = model.predict(im)
 
 # visualize the result
 fd.vision.visualize.vis_detection(im, result)
diff --git a/model_zoo/vision/retinaface/README.md b/model_zoo/vision/retinaface/README.md
new file mode 100644
index 0000000000..2b19027406
--- /dev/null
+++ b/model_zoo/vision/retinaface/README.md
@@ -0,0 +1,76 @@
+# RetinaFace Deployment Example
+
+Currently supported model version: [RetinaFace CommitID:b984b4b](https://github.com/biubug6/Pytorch_Retinaface/commit/b984b4b)
+
+This document explains how to quickly deploy and run inference with [RetinaFace](https://github.com/biubug6/Pytorch_Retinaface). The directory structure is as follows
+
+```
+.
+├── cpp                     # C++ code directory
+│   ├── CMakeLists.txt      # CMakeLists file for building the C++ code
+│   ├── README.md           # C++ build and deployment doc
+│   └── retinaface.cc       # C++ example code
+├── api.md                  # API reference
+├── README.md               # RetinaFace deployment doc
+└── retinaface.py           # Python example code
+```
+
+## Install FastDeploy
+
+Install FastDeploy with the following commands. Note that `vision-cpu` is installed here; `vision-gpu` can be installed instead if needed.
+```bash
+# install the fastdeploy-python tool
+pip install fastdeploy-python
+
+# install the vision-cpu module
+fastdeploy install vision-cpu
+```
+
+## Python Deployment
+
+Running the following script automatically downloads the RetinaFace model and a test image
+```bash
+python retinaface.py
+```
+
+## Exporting the ONNX Model Manually
+The automatically downloaded model file was converted in advance. If you need to export ONNX from the official RetinaFace repo, follow the steps below.
+
+* Clone the official repository
+```bash
+git clone https://github.com/biubug6/Pytorch_Retinaface.git
+```
+* Download the pretrained weights and put them in the weights folder
+```text
+./weights/
+      mobilenet0.25_Final.pth
+      mobilenetV1X0.25_pretrain.tar
+      Resnet50_Final.pth
+```
+* Run convert_to_onnx.py to export the ONNX model file
+```bash
+PYTHONPATH=. python convert_to_onnx.py --trained_model ./weights/mobilenet0.25_Final.pth --network mobile0.25 --long_side 640 --cpu
+PYTHONPATH=. python convert_to_onnx.py --trained_model ./weights/Resnet50_Final.pth --network resnet50 --long_side 640 --cpu
+```
+Note: you need to add a type constraint, type=int, to the --long_side argument in the convert_to_onnx.py script first.
+* Simplify the model with onnxsim
+```bash
+onnxsim FaceDetector.onnx Pytorch_RetinaFace_mobile0.25-640-640.onnx  # mobilenet
+onnxsim FaceDetector.onnx Pytorch_RetinaFace_resnet50-640-640.onnx  # resnet50
+```
+
+
+After execution, the visualized result is saved locally as `vis_result.jpg`, and the detection results are printed as follows
+```
+FaceDetectionResult: [xmin, ymin, xmax, ymax, score, (x, y) x 5]
+403.339783,254.192413, 490.002747, 351.931213, 0.999427, (425.657257,293.820740), (467.249451,293.667267), (446.830078,315.016388), (428.903381,326.129425), (465.764648,325.837341)
+296.834564,181.992035, 384.516876, 277.461243, 0.999194, (313.605164,224.800110), (352.888977,219.088043), (333.530182,239.872787), (325.395203,255.463852), (358.417175,250.529892)
+742.206238,263.547424, 840.871765, 366.171387, 0.999068, (762.715759,308.939880), (809.019653,304.544830), (786.174194,329.286163), (771.952271,341.376038), (812.717529,337.528839)
+545.351685,228.015930, 635.423584, 335.458649, 0.998681, (559.295654,269.971619), (598.439758,273.823608), (567.496643,292.894348), (558.160034,306.637238), (592.175781,309.493591)
+180.078125,241.787888, 257.213135, 320.321777, 0.998342, (203.702591,272.032715), (237.497726,271.356445), (222.380402,288.225708), (208.015259,301.360352), (233.943451,300.801636)
+```
+
+## Other Documents
+
+- [C++ deployment](./cpp/README.md)
+- [RetinaFace API reference](./api.md)
diff --git a/model_zoo/vision/retinaface/api.md b/model_zoo/vision/retinaface/api.md
new file mode 100644
index 0000000000..47afddc872
--- /dev/null
+++ b/model_zoo/vision/retinaface/api.md
@@ -0,0 +1,71 @@
+# RetinaFace API Reference
+
+## Python API
+
+### RetinaFace Class
+```
+fastdeploy.vision.biubug6.RetinaFace(model_file, params_file="", runtime_option=None, model_format=fd.Frontend.ONNX)
+```
+Loads and initializes the RetinaFace model. When model_format is `fd.Frontend.ONNX`, only model_file is required, e.g. `Pytorch_RetinaFace_mobile0.25-640-640.onnx`; when model_format is `fd.Frontend.PADDLE`, both model_file and params_file are required.
+
+**Parameters**
+
+> * **model_file**(str): path to the model file
+> * **params_file**(str): path to the parameters file
+> * **runtime_option**(RuntimeOption): backend inference configuration; the default None uses the default configuration
+> * **model_format**(Frontend): model format
+
+#### predict
+> ```
+> RetinaFace.predict(image_data, conf_threshold=0.7, nms_iou_threshold=0.3)
+> ```
+> Prediction interface: takes an image as input and directly outputs detection results.
+>
+> **Parameters**
+>
+> > * **image_data**(np.ndarray): input data; must be in HWC, BGR format
+> > * **conf_threshold**(float): confidence threshold used to filter detection boxes
+> > * **nms_iou_threshold**(float): IoU threshold used during NMS
+
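+A minimal usage sketch of the interface above (the model file name follows this page's examples; `test.jpg` is a placeholder for any local image):
+```python
+import cv2
+import fastdeploy as fd
+
+# load the ONNX model (params_file is not needed for ONNX)
+model = fd.vision.biubug6.RetinaFace("Pytorch_RetinaFace_mobile0.25-640-640.onnx")
+
+# cv2.imread returns an HWC, BGR image, which is what predict expects
+im = cv2.imread("test.jpg")
+result = model.predict(im, conf_threshold=0.7, nms_iou_threshold=0.3)
+print(result)
+```
+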
+See [retinaface.py](./retinaface.py) for complete example code
+
+
+## C++ API
+
+### RetinaFace Class
+```
+fastdeploy::vision::biubug6::RetinaFace(
+        const string& model_file,
+        const string& params_file = "",
+        const RuntimeOption& runtime_option = RuntimeOption(),
+        const Frontend& model_format = Frontend::ONNX)
+```
+Loads and initializes the RetinaFace model. When model_format is `Frontend::ONNX`, only model_file is required, e.g. `Pytorch_RetinaFace_mobile0.25-640-640.onnx`; when model_format is `Frontend::PADDLE`, both model_file and params_file are required.
+
+**Parameters**
+
+> * **model_file**(str): path to the model file
+> * **params_file**(str): path to the parameters file
+> * **runtime_option**(RuntimeOption): backend inference configuration; the default uses the default configuration
+> * **model_format**(Frontend): model format
+
+#### Predict
+> ```
+> RetinaFace::Predict(cv::Mat* im, FaceDetectionResult* result,
+>                     float conf_threshold = 0.25,
+>                     float nms_iou_threshold = 0.4)
+> ```
+> Prediction interface: takes an image as input and directly outputs detection results.
+>
+> **Parameters**
+>
+> > * **im**: input image; must be in HWC, BGR format
+> > * **result**: detection result, including detection boxes and the confidence of each box
+> > * **conf_threshold**: confidence threshold used to filter detection boxes
+> > * **nms_iou_threshold**: IoU threshold used during NMS
+
+See [cpp/retinaface.cc](cpp/retinaface.cc) for complete example code
+
+## Other APIs
+
+- [RuntimeOption configuration for model deployment](../../../docs/api/runtime_option.md)
diff --git a/model_zoo/vision/retinaface/cpp/CMakeLists.txt b/model_zoo/vision/retinaface/cpp/CMakeLists.txt
new file mode 100644
index 0000000000..7ca567b828
--- /dev/null
+++ b/model_zoo/vision/retinaface/cpp/CMakeLists.txt
@@ -0,0 +1,17 @@
+PROJECT(retinaface_demo C CXX)
+CMAKE_MINIMUM_REQUIRED (VERSION 3.16)
+
+# For environments with an old C++ ABI, enable compatible compilation with:
+# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
+
+# Path to the downloaded and extracted fastdeploy library
+set(FASTDEPLOY_INSTALL_DIR ${PROJECT_SOURCE_DIR}/fastdeploy-linux-x64-0.3.0/)
+
+include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
+
+# Add FastDeploy header dependencies
+include_directories(${FASTDEPLOY_INCS})
+
+add_executable(retinaface_demo ${PROJECT_SOURCE_DIR}/retinaface.cc)
+# Link against the FastDeploy libraries
+target_link_libraries(retinaface_demo ${FASTDEPLOY_LIBS})
diff --git a/model_zoo/vision/retinaface/cpp/README.md b/model_zoo/vision/retinaface/cpp/README.md
new file mode 100644
index 0000000000..ba400b5704
--- /dev/null
+++ b/model_zoo/vision/retinaface/cpp/README.md
@@ -0,0 +1,61 @@
+# Building the RetinaFace Example
+
+Currently supported model version: [RetinaFace CommitID:b984b4b](https://github.com/biubug6/Pytorch_Retinaface/commit/b984b4b)
+
+## Download and Extract the Inference Library
+```bash
+wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.0.3.tgz
+tar xvf fastdeploy-linux-x64-0.0.3.tgz
+```
+
+## Build the Example
+```bash
+mkdir build && cd build
+cmake ..
+make -j
+```
+
+## Download the Model and Test Image
+```bash
+wget https://github.com/DefTruth/Pytorch_Retinaface/releases/download/v0.1/Pytorch_RetinaFace_mobile0.25-640-640.onnx
+wget https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/raw/master/imgs/3.jpg
+```
+
+## Exporting the ONNX Model Manually
+The automatically downloaded model file was converted in advance. If you need to export ONNX from the official RetinaFace repo, follow the steps below.
+
+* Clone the official repository
+```bash
+git clone https://github.com/biubug6/Pytorch_Retinaface.git
+```
+* Download the pretrained weights and put them in the weights folder
+```text
+./weights/
+      mobilenet0.25_Final.pth
+      mobilenetV1X0.25_pretrain.tar
+      Resnet50_Final.pth
+```
+* Run convert_to_onnx.py to export the ONNX model file
+```bash
+PYTHONPATH=. python convert_to_onnx.py --trained_model ./weights/mobilenet0.25_Final.pth --network mobile0.25 --long_side 640 --cpu
+PYTHONPATH=. python convert_to_onnx.py --trained_model ./weights/Resnet50_Final.pth --network resnet50 --long_side 640 --cpu
+```
+Note: you need to add a type constraint, type=int, to the --long_side argument in the convert_to_onnx.py script first.
+* Simplify the model with onnxsim
+```bash
+onnxsim FaceDetector.onnx Pytorch_RetinaFace_mobile0.25-640-640.onnx  # mobilenet
+onnxsim FaceDetector.onnx Pytorch_RetinaFace_resnet50-640-640.onnx  # resnet50
+```
+
+## Run
+```bash
+./retinaface_demo
+```
+
+After running, the visualized result is saved locally as `vis_result.jpg`, and the detection boxes are printed to the terminal as follows
+```
+FaceDetectionResult: [xmin, ymin, xmax, ymax, score, (x, y) x 5]
+403.339783,254.192413, 490.002747, 351.931213, 0.999427, (425.657257,293.820740), (467.249451,293.667267), (446.830078,315.016388), (428.903381,326.129425), (465.764648,325.837341)
+296.834564,181.992035, 384.516876, 277.461243, 0.999194, (313.605164,224.800110), (352.888977,219.088043), (333.530182,239.872787), (325.395203,255.463852), (358.417175,250.529892)
+742.206238,263.547424, 840.871765, 366.171387, 0.999068, (762.715759,308.939880), (809.019653,304.544830), (786.174194,329.286163), (771.952271,341.376038), (812.717529,337.528839)
+545.351685,228.015930, 635.423584, 335.458649, 0.998681, (559.295654,269.971619), (598.439758,273.823608), (567.496643,292.894348), (558.160034,306.637238), (592.175781,309.493591)
+180.078125,241.787888, 257.213135, 320.321777, 0.998342, (203.702591,272.032715), (237.497726,271.356445), (222.380402,288.225708), (208.015259,301.360352), (233.943451,300.801636)
+```
diff --git a/model_zoo/vision/retinaface/cpp/retinaface.cc b/model_zoo/vision/retinaface/cpp/retinaface.cc
new file mode 100644
index 0000000000..933b629c4a
--- /dev/null
+++ b/model_zoo/vision/retinaface/cpp/retinaface.cc
@@ -0,0 +1,49 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+int main() {
+  namespace vis = fastdeploy::vision;
+
+  std::string model_file = "Pytorch_RetinaFace_mobile0.25-640-640.onnx";
+  auto model = vis::biubug6::RetinaFace(model_file);
+  if (!model.Initialized()) {
+    std::cerr << "Init Failed! Model: " << model_file << std::endl;
+    return -1;
+  } else {
+    std::cout << "Init Done! Model: " << model_file << std::endl;
+  }
+  model.EnableDebug();
+
+  cv::Mat im = cv::imread("3.jpg");
+  cv::Mat vis_im = im.clone();
+
+  vis::FaceDetectionResult res;
+  if (!model.Predict(&im, &res, 0.7f, 0.3f)) {
+    std::cerr << "Prediction Failed." << std::endl;
+    return -1;
+  } else {
+    std::cout << "Prediction Done!" << std::endl;
+  }
+
+  // print the detection results
+  std::cout << res.Str() << std::endl;
+
+  // visualize the prediction results
+  vis::Visualize::VisFaceDetection(&vis_im, res, 2, 0.3f);
+  std::string vis_path = "vis_result.jpg";
+  cv::imwrite(vis_path, vis_im);
+  std::cout << "Detect Done! Saved: " << vis_path << std::endl;
Saved: " << vis_path << std::endl; + return 0; +} diff --git a/model_zoo/vision/retinaface/retinaface.py b/model_zoo/vision/retinaface/retinaface.py new file mode 100644 index 0000000000..4e5a123c2e --- /dev/null +++ b/model_zoo/vision/retinaface/retinaface.py @@ -0,0 +1,24 @@ +import fastdeploy as fd +import cv2 + +# 下载模型 +model_url = "https://github.com/DefTruth/Pytorch_Retinaface/releases/download/v0.1/Pytorch_RetinaFace_mobile0.25-640-640.onnx" +test_img_url = "https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/raw/master/imgs/3.jpg" +fd.download(model_url, ".", show_progress=True) +fd.download(test_img_url, ".", show_progress=True) + +# 加载模型 +model = fd.vision.biubug6.RetinaFace( + "Pytorch_RetinaFace_mobile0.25-640-640.onnx") + +# 预测图片 +im = cv2.imread("3.jpg") +result = model.predict(im, conf_threshold=0.7, nms_iou_threshold=0.3) + +# 可视化结果 +fd.vision.visualize.vis_face_detection(im, result) +cv2.imwrite("vis_result.jpg", im) + +# 输出预测结果 +print(result) +print(model.runtime_option) diff --git a/model_zoo/vision/ultraface/README.md b/model_zoo/vision/ultraface/README.md index 988f607426..264f1b5cb2 100644 --- a/model_zoo/vision/ultraface/README.md +++ b/model_zoo/vision/ultraface/README.md @@ -28,7 +28,7 @@ fastdeploy install vision-cpu ## Python部署 -执行如下代码即会自动下载YOLOv5Face模型和测试图片 +执行如下代码即会自动下载UltraFace模型和测试图片 ```bash python ultraface.py ``` diff --git a/model_zoo/vision/yolov5face/api.md b/model_zoo/vision/yolov5face/api.md index 384ef23d31..ea32820f6b 100644 --- a/model_zoo/vision/yolov5face/api.md +++ b/model_zoo/vision/yolov5face/api.md @@ -51,7 +51,7 @@ YOLOv5Face模型加载和初始化,当model_format为`Frontend::ONNX`时,只 #### Predict函数 > ``` -> YOLOv5Face::Predict(cv::Mat* im, DetectionResult* result, +> YOLOv5Face::Predict(cv::Mat* im, FaceDetectionResult* result, > float conf_threshold = 0.25, > float nms_iou_threshold = 0.5) > ``` diff --git a/model_zoo/vision/yolov7/README.md b/model_zoo/vision/yolov7/README.md index 8b2f06d761..a7165a0455 100644 --- a/model_zoo/vision/yolov7/README.md +++ b/model_zoo/vision/yolov7/README.md @@ -26,9 +26,12 @@ #下载yolov7模型文件 wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt - # 导出onnx格式文件 + # 导出onnx格式文件 (Tips: 对应 YOLOv7 release v0.1 代码) python models/export.py --grid --dynamic --weights PATH/TO/yolov7.pt + # 如果您的代码版本中有支持NMS的ONNX文件导出,请使用如下命令导出ONNX文件(请暂时不要使用 "--end2end",我们后续将支持带有NMS的ONNX模型的部署) + python export.py --grid --dynamic --weights PATH/TO/yolov7.pt + # 移动onnx文件到demo目录 cp PATH/TO/yolov7.onnx PATH/TO/model_zoo/vision/yolov7/ ``` diff --git a/model_zoo/vision/yolov7/cpp/README.md b/model_zoo/vision/yolov7/cpp/README.md index 655e98678c..6190b3ae7b 100644 --- a/model_zoo/vision/yolov7/cpp/README.md +++ b/model_zoo/vision/yolov7/cpp/README.md @@ -12,9 +12,11 @@ #下载yolov7模型文件 wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt - # 导出onnx格式文件 + # 导出onnx格式文件 (Tips: 对应 YOLOv7 release v0.1 代码) python models/export.py --grid --dynamic --weights PATH/TO/yolov7.pt + # 如果您的代码版本中有支持NMS的ONNX文件导出,请使用如下命令导出ONNX文件(请暂时不要使用 "--end2end",我们后续将支持带有NMS的ONNX模型的部署) + python export.py --grid --dynamic --weights PATH/TO/yolov7.pt ``` diff --git a/setup.py b/setup.py index 8575c42963..7c549fe604 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,8 @@ setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED") setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", "/usr/local/cuda") +if os.getenv("CMAKE_CXX_COMPILER", None) is not None: + setup_configs["CMAKE_CXX_COMPILER"] = 
 
 TOP_DIR = os.path.realpath(os.path.dirname(__file__))
 SRC_DIR = os.path.join(TOP_DIR, "fastdeploy")
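A quick sketch of driving the new `CMAKE_CXX_COMPILER` override in setup.py; the compiler path here is only an assumption for illustration:

```python
import os
import subprocess

# assumption: /usr/bin/g++ is the compiler CMake should use
os.environ["CMAKE_CXX_COMPILER"] = "/usr/bin/g++"

# setup.py reads the env var and forwards it to CMake via setup_configs
subprocess.check_call(["python", "setup.py", "build"])
```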