diff --git a/.gitignore b/.gitignore
index 967c01a0d9..f49b53b0e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,4 @@ fastdeploy/LICENSE*
 fastdeploy/ThirdPartyNotices*
 *.so*
 fastdeploy/libs/third_libs
+csrcs/fastdeploy/core/config.h
\ No newline at end of file
diff --git a/csrcs/fastdeploy/vision.h b/csrcs/fastdeploy/vision.h
index a64e974779..781bf01a2f 100644
--- a/csrcs/fastdeploy/vision.h
+++ b/csrcs/fastdeploy/vision.h
@@ -16,6 +16,7 @@
 #include "fastdeploy/core/config.h"
 #ifdef ENABLE_VISION
 #include "fastdeploy/vision/deepcam/yolov5face.h"
+#include "fastdeploy/vision/linzaer/ultraface.h"
 #include "fastdeploy/vision/megvii/yolox.h"
 #include "fastdeploy/vision/meituan/yolov6.h"
 #include "fastdeploy/vision/ppcls/model.h"
diff --git a/csrcs/fastdeploy/vision/linzaer/linzaer_pybind.cc b/csrcs/fastdeploy/vision/linzaer/linzaer_pybind.cc
new file mode 100644
index 0000000000..89751bca21
--- /dev/null
+++ b/csrcs/fastdeploy/vision/linzaer/linzaer_pybind.cc
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/pybind/main.h"
+
+namespace fastdeploy {
+void BindLinzaer(pybind11::module& m) {
+  auto linzaer_module = m.def_submodule(
+      "linzaer",
+      "https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB");
+  pybind11::class_<vision::linzaer::UltraFace, FastDeployModel>(linzaer_module,
+                                                                "UltraFace")
+      .def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
+      .def("predict",
+           [](vision::linzaer::UltraFace& self, pybind11::array& data,
+              float conf_threshold, float nms_iou_threshold) {
+             auto mat = PyArrayToCvMat(data);
+             vision::FaceDetectionResult res;
+             self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
+             return res;
+           })
+      .def_readwrite("size", &vision::linzaer::UltraFace::size);
+}
+}  // namespace fastdeploy
diff --git a/csrcs/fastdeploy/vision/linzaer/ultraface.cc b/csrcs/fastdeploy/vision/linzaer/ultraface.cc
new file mode 100644
index 0000000000..e9148604f2
--- /dev/null
+++ b/csrcs/fastdeploy/vision/linzaer/ultraface.cc
@@ -0,0 +1,220 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/linzaer/ultraface.h"
+#include "fastdeploy/utils/perf.h"
+#include "fastdeploy/vision/utils/utils.h"
+
+namespace fastdeploy {
+
+namespace vision {
+
+namespace linzaer {
+
+UltraFace::UltraFace(const std::string& model_file,
+                     const std::string& params_file,
+                     const RuntimeOption& custom_option,
+                     const Frontend& model_format) {
+  if (model_format == Frontend::ONNX) {
+    valid_cpu_backends = {Backend::ORT};  // supported CPU backends
+    valid_gpu_backends = {Backend::ORT, Backend::TRT};  // supported GPU backends
+  } else {
+    valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
+    valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
+  }
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.model_file = model_file;
+  runtime_option.params_file = params_file;
+  initialized = Initialize();
+}
+
+bool UltraFace::Initialize() {
+  // parameters for preprocess
+  size = {320, 240};
+
+  if (!InitRuntime()) {
+    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+    return false;
+  }
+  // Check whether the input shape is dynamic after the runtime is initialized
+  is_dynamic_input_ = false;
+  auto shape = InputInfoOfRuntime(0).shape;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    // if height or width is dynamic
+    if (i >= 2 && shape[i] <= 0) {
+      is_dynamic_input_ = true;
+      break;
+    }
+  }
+  return true;
+}
+
+bool UltraFace::Preprocess(
+    Mat* mat, FDTensor* output,
+    std::map<std::string, std::array<float, 2>>* im_info) {
+  // ultraface's preprocess steps
+  // 1. resize
+  // 2. BGR->RGB
+  // 3. HWC->CHW
+  int resize_w = size[0];
+  int resize_h = size[1];
+  if (resize_h != mat->Height() || resize_w != mat->Width()) {
+    Resize::Run(mat, resize_w, resize_h);
+  }
+
+  BGR2RGB::Run(mat);
+  // Compute `result = mat * alpha + beta` directly by channel
+  // Reference: detect_imgs_onnx.py#L73
+  std::vector<float> alpha = {1.0f / 128.0f, 1.0f / 128.0f, 1.0f / 128.0f};
+  std::vector<float> beta = {-127.0f * (1.0f / 128.0f),
+                             -127.0f * (1.0f / 128.0f),
+                             -127.0f * (1.0f / 128.0f)};  // RGB
+  Convert::Run(mat, alpha, beta);
+
+  // Record output shape of preprocessed image
+  (*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
+                                static_cast<float>(mat->Width())};
+
+  HWC2CHW::Run(mat);
+  Cast::Run(mat, "float");
+  mat->ShareWithTensor(output);
+  output->shape.insert(output->shape.begin(), 1);  // reshape to n, c, h, w
+  return true;
+}
+
+bool UltraFace::Postprocess(
+    std::vector<FDTensor>& infer_result, FaceDetectionResult* result,
+    const std::map<std::string, std::array<float, 2>>& im_info,
+    float conf_threshold, float nms_iou_threshold) {
+  // ultraface has 2 output tensors, scores & boxes
+  FDASSERT(
+      (infer_result.size() == 2),
+      "The default number of output tensor must be 2 according to ultraface.");
+  FDTensor& scores_tensor = infer_result.at(0);  // (1,4420,2)
+  FDTensor& boxes_tensor = infer_result.at(1);   // (1,4420,4)
+  FDASSERT((scores_tensor.shape[0] == 1), "Only support batch = 1 now.");
+  FDASSERT((boxes_tensor.shape[0] == 1), "Only support batch = 1 now.");
+
+  result->Clear();
+  // landmarks_per_face must be set up before reserving;
+  // the ultraface detector does not detect landmarks by default.
+  result->landmarks_per_face = 0;
+  if (scores_tensor.dtype != FDDataType::FP32) {
+    FDERROR << "Only support post process with float32 data." << std::endl;
+    return false;
+  }
+  if (boxes_tensor.dtype != FDDataType::FP32) {
+    FDERROR << "Only support post process with float32 data."
+            << std::endl;
+    return false;
+  }
+
+  float* scores_ptr = static_cast<float*>(scores_tensor.Data());
+  float* boxes_ptr = static_cast<float*>(boxes_tensor.Data());
+  const size_t num_bboxes = boxes_tensor.shape[1];  // e.g. 4420
+  // fetch original image shape
+  auto iter_ipt = im_info.find("input_shape");
+  FDASSERT((iter_ipt != im_info.end()),
+           "Cannot find input_shape from im_info.");
+  float ipt_h = iter_ipt->second[0];
+  float ipt_w = iter_ipt->second[1];
+
+  // decode bounding boxes
+  for (size_t i = 0; i < num_bboxes; ++i) {
+    float confidence = scores_ptr[2 * i + 1];
+    // filter boxes by conf_threshold
+    if (confidence <= conf_threshold) {
+      continue;
+    }
+    float x1 = boxes_ptr[4 * i + 0] * ipt_w;
+    float y1 = boxes_ptr[4 * i + 1] * ipt_h;
+    float x2 = boxes_ptr[4 * i + 2] * ipt_w;
+    float y2 = boxes_ptr[4 * i + 3] * ipt_h;
+    result->boxes.emplace_back(std::array<float, 4>{x1, y1, x2, y2});
+    result->scores.push_back(confidence);
+  }
+
+  if (result->boxes.size() == 0) {
+    return true;
+  }
+
+  utils::NMS(result, nms_iou_threshold);
+
+  // clip boxes to the original image bounds
+  for (size_t i = 0; i < result->boxes.size(); ++i) {
+    result->boxes[i][0] = std::max(result->boxes[i][0], 0.0f);
+    result->boxes[i][1] = std::max(result->boxes[i][1], 0.0f);
+    result->boxes[i][2] = std::max(result->boxes[i][2], 0.0f);
+    result->boxes[i][3] = std::max(result->boxes[i][3], 0.0f);
+    result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f);
+    result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f);
+    result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f);
+    result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f);
+  }
+  return true;
+}
+
+bool UltraFace::Predict(cv::Mat* im, FaceDetectionResult* result,
+                        float conf_threshold, float nms_iou_threshold) {
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_START(0)
+#endif
+
+  Mat mat(*im);
+  std::vector<FDTensor> input_tensors(1);
+
+  std::map<std::string, std::array<float, 2>> im_info;
+
+  // Record the shape of the original image and of the preprocessed image
+  im_info["input_shape"] = {static_cast<float>(mat.Height()),
+                            static_cast<float>(mat.Width())};
+  im_info["output_shape"] = {static_cast<float>(mat.Height()),
+                             static_cast<float>(mat.Width())};
+
+  if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
+    FDERROR << "Failed to preprocess input image." << std::endl;
+    return false;
+  }
+
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_END(0, "Preprocess")
+  TIMERECORD_START(1)
+#endif
+
+  input_tensors[0].name = InputInfoOfRuntime(0).name;
+  std::vector<FDTensor> output_tensors;
+  if (!Infer(input_tensors, &output_tensors)) {
+    FDERROR << "Failed to inference." << std::endl;
+    return false;
+  }
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_END(1, "Inference")
+  TIMERECORD_START(2)
+#endif
+
+  if (!Postprocess(output_tensors, result, im_info, conf_threshold,
+                   nms_iou_threshold)) {
+    FDERROR << "Failed to post process." << std::endl;
+    return false;
+  }
+
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_END(2, "Postprocess")
+#endif
+  return true;
+}
+
+}  // namespace linzaer
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/csrcs/fastdeploy/vision/linzaer/ultraface.h b/csrcs/fastdeploy/vision/linzaer/ultraface.h
new file mode 100644
index 0000000000..c3e499d99e
--- /dev/null
+++ b/csrcs/fastdeploy/vision/linzaer/ultraface.h
@@ -0,0 +1,84 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+namespace fastdeploy {
+
+namespace vision {
+
+namespace linzaer {
+
+class FASTDEPLOY_DECL UltraFace : public FastDeployModel {
+ public:
+  // When model_format is ONNX, params_file is not required;
+  // when model_format is Paddle, both model_file and params_file must be set.
+  UltraFace(const std::string& model_file, const std::string& params_file = "",
+            const RuntimeOption& custom_option = RuntimeOption(),
+            const Frontend& model_format = Frontend::ONNX);
+
+  // the name of this model
+  std::string ModelName() const {
+    return "Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB";
+  }
+
+  // The prediction interface called by users.
+  // im is the input data; for CV models it is a cv::Mat.
+  // result is the structure filled with the model prediction.
+  // conf_threshold and nms_iou_threshold are post-processing parameters.
+  virtual bool Predict(cv::Mat* im, FaceDetectionResult* result,
+                       float conf_threshold = 0.7f,
+                       float nms_iou_threshold = 0.3f);
+
+  // Parameters used during prediction, mostly for pre/post-processing.
+  // After creating the model, users may modify them according to the
+  // requirements of the model and their own needs.
+  // tuple of (width, height), default (320, 240)
+  std::vector<int> size;
+
+ private:
+  // Initialization: sets up the backend and everything else inference needs.
+  bool Initialize();
+
+  // Preprocessing of the input image.
+  // Mat is the data structure defined by FastDeploy.
+  // FDTensor holds the preprocessed tensor passed to the backend.
+  // im_info stores data gathered in preprocessing, needed by postprocessing.
+  bool Preprocess(Mat* mat, FDTensor* outputs,
+                  std::map<std::string, std::array<float, 2>>* im_info);
+
+  // Postprocessing of the backend inference output, returned to users.
+  // infer_result is the output tensors from the backend.
+  // result is the model prediction.
+  // im_info is the info recorded during preprocessing, used to restore boxes.
+  // conf_threshold filters boxes by confidence during postprocessing.
+  // nms_iou_threshold is the IoU threshold used by NMS.
+  bool Postprocess(std::vector<FDTensor>& infer_result,
+                   FaceDetectionResult* result,
+                   const std::map<std::string, std::array<float, 2>>& im_info,
+                   float conf_threshold, float nms_iou_threshold);
+
+  // Whether the input has dynamic dimensions. Not recommended for direct
+  // use; the logic may differ between models.
+  bool IsDynamicInput() const { return is_dynamic_input_; }
+
+  bool is_dynamic_input_;
+};
+
+}  // namespace linzaer
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/csrcs/fastdeploy/vision/vision_pybind.cc b/csrcs/fastdeploy/vision/vision_pybind.cc
index 3b426ebd8c..87c5ce6d84 100644
--- a/csrcs/fastdeploy/vision/vision_pybind.cc
+++ b/csrcs/fastdeploy/vision/vision_pybind.cc
@@ -25,6 +25,7 @@ void BindMeituan(pybind11::module& m);
 void BindMegvii(pybind11::module& m);
 void BindDeepCam(pybind11::module& m);
 void BindRangiLyu(pybind11::module& m);
+void BindLinzaer(pybind11::module& m);
 #ifdef ENABLE_VISION_VISUALIZE
 void BindVisualize(pybind11::module& m);
 #endif
@@ -69,6 +70,7 @@ void BindVision(pybind11::module& m) {
   BindMegvii(m);
   BindDeepCam(m);
   BindRangiLyu(m);
+  BindLinzaer(m);
 #ifdef ENABLE_VISION_VISUALIZE
   BindVisualize(m);
 #endif
diff --git a/examples/vision/linzaer_ultraface.cc b/examples/vision/linzaer_ultraface.cc
new file mode 100644
index 0000000000..eb1cbafe8c
--- /dev/null
+++ b/examples/vision/linzaer_ultraface.cc
@@ -0,0 +1,53 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+int main() {
+  namespace vis = fastdeploy::vision;
+
+  std::string model_file = "../resources/models/version-RFB-320.onnx";
+  std::string img_path = "../resources/images/test_face_det_0.jpg";
+  std::string vis_path =
+      "../resources/outputs/linzaer_ultraface_vis_result.jpg";
+
+  auto model = vis::linzaer::UltraFace(model_file);
+  if (!model.Initialized()) {
+    std::cerr << "Init Failed! Model: " << model_file << std::endl;
+    return -1;
+  } else {
+    std::cout << "Init Done! Model: " << model_file << std::endl;
+  }
+  model.EnableDebug();
+
+  cv::Mat im = cv::imread(img_path);
+  cv::Mat vis_im = im.clone();
+
+  vis::FaceDetectionResult res;
+  if (!model.Predict(&im, &res, 0.7f, 0.3f)) {
+    std::cerr << "Prediction Failed." << std::endl;
+    return -1;
+  } else {
+    std::cout << "Prediction Done!" << std::endl;
+  }
+
+  // print the detected boxes
+  std::cout << res.Str() << std::endl;
+
+  // visualize the prediction result
+  vis::Visualize::VisFaceDetection(&vis_im, res, 2, 0.3f);
+  cv::imwrite(vis_path, vis_im);
+  std::cout << "Detect Done! Saved: " << vis_path << std::endl;
+  return 0;
+}
diff --git a/fastdeploy/vision/__init__.py b/fastdeploy/vision/__init__.py
index f9c9423701..6387f5e39b 100644
--- a/fastdeploy/vision/__init__.py
+++ b/fastdeploy/vision/__init__.py
@@ -24,3 +24,4 @@
 from . import wongkinyiu
 from . import deepcam
 from . import rangilyu
+from . import linzaer
diff --git a/fastdeploy/vision/evaluation/__init__.py b/fastdeploy/vision/evaluation/__init__.py
index 1158095ec5..d2c9a79116 100644
--- a/fastdeploy/vision/evaluation/__init__.py
+++ b/fastdeploy/vision/evaluation/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 from __future__ import absolute_import
 from .classify import eval_classify
+from .detection import eval_detection
diff --git a/fastdeploy/vision/evaluation/detection.py b/fastdeploy/vision/evaluation/detection.py
new file mode 100644
index 0000000000..4aaaaaaa56
--- /dev/null
+++ b/fastdeploy/vision/evaluation/detection.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
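+
+# A minimal usage sketch for eval_detection (illustrative only; the model and
+# dataset paths below are hypothetical). Any FastDeploy detection model whose
+# predict() result carries boxes, scores and label_ids should work:
+#
+#   import fastdeploy as fd
+#   from fastdeploy.vision.evaluation import eval_detection
+#
+#   model = fd.vision.ultralytics.YOLOv5("yolov5s.onnx")
+#   metrics = eval_detection(model, conf_threshold=0.25, nms_iou_threshold=0.5,
+#                            data_dir="coco/val2017",
+#                            ann_file="coco/annotations/instances_val2017.json")
+#   print(metrics)  # e.g. OrderedDict([('bbox_mmap', ...)])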
+
+from tqdm import trange
+import cv2
+import numpy as np
+from .utils import CocoDetection
+from .utils import COCOMetric
+import copy
+import collections
+
+
+def eval_detection(model,
+                   conf_threshold,
+                   nms_iou_threshold,
+                   data_dir,
+                   ann_file,
+                   plot=False):
+    assert isinstance(conf_threshold, (
+        float, int
+    )), "The conf_threshold:{} must be int or float".format(conf_threshold)
+    assert isinstance(nms_iou_threshold, (
+        float,
+        int)), "The nms_iou_threshold:{} must be int or float".format(
+            nms_iou_threshold)
+    eval_dataset = CocoDetection(
+        data_dir=data_dir, ann_file=ann_file, shuffle=False)
+    all_image_info = eval_dataset.file_list
+    image_num = eval_dataset.num_samples
+    eval_dataset.data_fields = {
+        'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class', 'is_crowd'
+    }
+    eval_metric = COCOMetric(
+        coco_gt=copy.deepcopy(eval_dataset.coco_gt), classwise=False)
+    scores = collections.OrderedDict()
+    for image_info, i in zip(all_image_info,
+                             trange(
+                                 image_num, desc="Inference Progress")):
+        im = cv2.imread(image_info["image"])
+        im_id = image_info["im_id"]
+        result = model.predict(im, conf_threshold, nms_iou_threshold)
+        # each prediction row is [class_id, score, x1, y1, x2, y2]
+        pred = {
+            'bbox': [[c] + [s] + b
+                     for b, s, c in zip(result.boxes, result.scores,
+                                        result.label_ids)],
+            'bbox_num': len(result.boxes),
+            'im_id': im_id
+        }
+        eval_metric.update(im_id, pred)
+    eval_metric.accumulate()
+    eval_details = eval_metric.details
+    scores.update(eval_metric.get())
+    eval_metric.reset()
+    return scores
diff --git a/fastdeploy/vision/evaluation/utils/__init__.py b/fastdeploy/vision/evaluation/utils/__init__.py
new file mode 100644
index 0000000000..dfcb419bad
--- /dev/null
+++ b/fastdeploy/vision/evaluation/utils/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import fd_logging
+from .util import *
+from .metrics import *
+from .json_results import *
+from .map_utils import *
+from .coco_utils import *
+from .coco import *
+from .cityscapes import Cityscapes
diff --git a/fastdeploy/vision/evaluation/utils/coco.py b/fastdeploy/vision/evaluation/utils/coco.py
new file mode 100644
index 0000000000..c675790557
--- /dev/null
+++ b/fastdeploy/vision/evaluation/utils/coco.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import copy
+import os.path as osp
+import six
+import sys
+import numpy as np
+from . import fd_logging as logging
+from .util import is_pic, get_num_workers
+
+
+class CocoDetection(object):
+    """Reads a detection dataset in MSCOCO format and applies the
+    corresponding processing to its samples. Datasets in this format can
+    also be used to train instance segmentation models.
+
+    Args:
+        data_dir (str): Directory the dataset lives in.
+        ann_file (str): Annotation file of the dataset, a standalone JSON
+            file.
+        num_workers (int|str): Number of threads or processes used while
+            preprocessing samples. Defaults to 'auto', which derives
+            `num_workers` from the actual CPU count: half the number of CPU
+            cores, capped at 2 (see `get_num_workers`).
+        shuffle (bool): Whether to shuffle the samples. Defaults to False.
+        allow_empty (bool): Whether to load negative (empty) samples.
+            Defaults to False.
+        empty_ratio (float): Ratio of negative samples to the total. If less
+            than 0 or no less than 1, all negative samples are kept.
+            Defaults to 1.
+    """
+
+    def __init__(self,
+                 data_dir,
+                 ann_file,
+                 num_workers='auto',
+                 shuffle=False,
+                 allow_empty=False,
+                 empty_ratio=1.):
+
+        from pycocotools.coco import COCO
+        self.data_dir = data_dir
+        self.data_fields = None
+        self.num_max_boxes = 1000
+        self.num_workers = get_num_workers(num_workers)
+        self.shuffle = shuffle
+        self.allow_empty = allow_empty
+        self.empty_ratio = empty_ratio
+        self.file_list = list()
+        neg_file_list = list()
+        self.labels = list()
+
+        coco = COCO(ann_file)
+        self.coco_gt = coco
+        img_ids = sorted(coco.getImgIds())
+        cat_ids = coco.getCatIds()
+        catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
+        cname2clsid = dict({
+            coco.loadCats(catid)[0]['name']: clsid
+            for catid, clsid in catid2clsid.items()
+        })
+        for label, cid in sorted(cname2clsid.items(), key=lambda d: d[1]):
+            self.labels.append(label)
+        logging.info("Starting to read file list from dataset...")
+
+        ct = 0
+        for img_id in img_ids:
+            is_empty = False
+            img_anno = coco.loadImgs(img_id)[0]
+            im_fname = osp.join(data_dir, img_anno['file_name'])
+            if not is_pic(im_fname):
+                continue
+            im_w = float(img_anno['width'])
+            im_h = float(img_anno['height'])
+            ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
+            instances = coco.loadAnns(ins_anno_ids)
+
+            bboxes = []
+            for inst in instances:
+                x, y, box_w, box_h = inst['bbox']
+                x1 = max(0, x)
+                y1 = max(0, y)
+                x2 = min(im_w - 1, x1 + max(0, box_w))
+                y2 = min(im_h - 1, y1 + max(0, box_h))
+                if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
+                    inst['clean_bbox'] = [x1, y1, x2, y2]
+                    bboxes.append(inst)
+                else:
+                    logging.warning(
+                        "Found an invalid bbox in annotations: "
+                        "im_id: {}, area: {} x1: {}, y1: {}, x2: {}, y2: {}."
+                        .format(img_id, float(inst['area']), x1, y1, x2, y2))
+            num_bbox = len(bboxes)
+            if num_bbox == 0 and not self.allow_empty:
+                continue
+            elif num_bbox == 0:
+                is_empty = True
+
+            gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
+            gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
+            gt_score = np.ones((num_bbox, 1), dtype=np.float32)
+            is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
+            difficult = np.zeros((num_bbox, 1), dtype=np.int32)
+            gt_poly = [None] * num_bbox
+
+            has_segmentation = False
+            for i, box in reversed(list(enumerate(bboxes))):
+                catid = box['category_id']
+                gt_class[i][0] = catid2clsid[catid]
+                gt_bbox[i, :] = box['clean_bbox']
+                is_crowd[i][0] = box['iscrowd']
+                if 'segmentation' in box and box['iscrowd'] == 1:
+                    gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
+                elif 'segmentation' in box and box['segmentation']:
+                    if not np.array(
+                            box['segmentation'],
+                            dtype=object).size > 0 and not self.allow_empty:
+                        gt_poly.pop(i)
+                        is_crowd = np.delete(is_crowd, i)
+                        gt_class = np.delete(gt_class, i)
+                        gt_bbox = np.delete(gt_bbox, i)
+                    else:
+                        gt_poly[i] = box['segmentation']
+                        has_segmentation = True
+            if has_segmentation and not any(gt_poly) and not self.allow_empty:
+                continue
+
+            im_info = {
+                'im_id': np.array([img_id]).astype('int32'),
+                'image_shape': np.array([im_h, im_w]).astype('int32'),
+            }
+            label_info = {
+                'is_crowd': is_crowd,
+                'gt_class': gt_class,
+                'gt_bbox': gt_bbox,
+                'gt_score': gt_score,
+                'gt_poly': gt_poly,
+                'difficult': difficult
+            }
+
+            if is_empty:
+                neg_file_list.append(
+                    {'image': im_fname, **im_info, **label_info})
+            else:
+                self.file_list.append(
+                    {'image': im_fname, **im_info, **label_info})
+                ct += 1
+
+            self.num_max_boxes = max(self.num_max_boxes, len(instances))
+
+        if not ct:
+            logging.error(
+                "No coco record found in {}".format(ann_file), exit=True)
+        self.pos_num = len(self.file_list)
+        if self.allow_empty and neg_file_list:
+            self.file_list += self._sample_empty(neg_file_list)
+        logging.info(
+            "{} samples in file {}, including {} positive samples and {} negative samples.".
+            format(
+                len(self.file_list), ann_file, self.pos_num,
+                len(self.file_list) - self.pos_num))
+        self.num_samples = len(self.file_list)
+
+        self._epoch = 0
diff --git a/fastdeploy/vision/evaluation/utils/coco_utils.py b/fastdeploy/vision/evaluation/utils/coco_utils.py
new file mode 100644
index 0000000000..9d551f253f
--- /dev/null
+++ b/fastdeploy/vision/evaluation/utils/coco_utils.py
@@ -0,0 +1,217 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+import numpy as np
+import itertools
+from .map_utils import draw_pr_curve
+from .json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res
+import logging
+import copy
+
+
+def loadRes(coco_obj, anns):
+    """
+    Load a result list and return a result api object.
+    :param coco_obj (obj): the COCO api object holding the ground truth
+    :param anns (list): a list of result annotations
+    :return: res (obj): result api object
+    """
+
+    # This function has the same functionality as pycocotools.COCO.loadRes,
+    # except that the input anns is a list of results rather than a json file.
+    # Refer to
+    # https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/coco.py#L305,
+
+    # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
+    # or matplotlib.backends is imported for the first time
+    # pycocotools import matplotlib
+    import matplotlib
+    matplotlib.use('Agg')
+    from pycocotools.coco import COCO
+    import pycocotools.mask as maskUtils
+    import time
+    res = COCO()
+    res.dataset['images'] = [img for img in coco_obj.dataset['images']]
+
+    tic = time.time()
+    assert type(anns) == list, 'results is not an array of objects'
+    annsImgIds = [ann['image_id'] for ann in anns]
+    assert set(annsImgIds) == (set(annsImgIds) & set(coco_obj.getImgIds())), \
+        'Results do not correspond to current coco set'
+    if 'caption' in anns[0]:
+        imgIds = set([img['id'] for img in res.dataset['images']]) & set(
+            [ann['image_id'] for ann in anns])
+        res.dataset['images'] = [
+            img for img in res.dataset['images'] if img['id'] in imgIds
+        ]
+        for id, ann in enumerate(anns):
+            ann['id'] = id + 1
+    elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
+        res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
+            'categories'])
+        for id, ann in enumerate(anns):
+            bb = ann['bbox']
+            x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
+            if not 'segmentation' in ann:
+                ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
+            ann['area'] = bb[2] * bb[3]
+            ann['id'] = id + 1
+            ann['iscrowd'] = 0
+    elif 'segmentation' in anns[0]:
+        res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
+            'categories'])
+        for id, ann in enumerate(anns):
+            # now only support compressed RLE format as segmentation results
+            ann['area'] = maskUtils.area(ann['segmentation'])
+            if not 'bbox' in ann:
+                ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
+            ann['id'] = id + 1
+            ann['iscrowd'] = 0
+    elif 'keypoints' in anns[0]:
+        res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
+            'categories'])
+        for id, ann in enumerate(anns):
+            s = ann['keypoints']
+            x = s[0::3]
+            y = s[1::3]
+            x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
+            ann['area'] = (x1 - x0) * (y1 - y0)
+            ann['id'] = id + 1
+            ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]
+
+    res.dataset['annotations'] = anns
+    res.createIndex()
+    return res
+
+
+def get_infer_results(outs, catid, bias=0):
+    """
+    Get result at the stage of inference.
+    The output format is a dictionary containing bbox or mask results.
+
+    For example, a bbox result is a list whose elements each contain
+    image_id, category_id, bbox and score.
+    """
+    if outs is None or len(outs) == 0:
+        raise ValueError(
+            'The number of valid detection results is zero. Please use a reasonable model and check the input data.'
+ ) + + im_id = outs['im_id'] + + infer_res = {} + if 'bbox' in outs: + if len(outs['bbox']) > 0 and len(outs['bbox'][0]) > 6: + infer_res['bbox'] = get_det_poly_res( + outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias) + else: + infer_res['bbox'] = get_det_res( + outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias) + + if 'mask' in outs: + # mask post process + infer_res['mask'] = get_seg_res(outs['mask'], outs['bbox'], + outs['bbox_num'], im_id, catid) + + if 'segm' in outs: + infer_res['segm'] = get_solov2_segm_res(outs, im_id, catid) + + return infer_res + + +def cocoapi_eval(anns, + style, + coco_gt=None, + anno_file=None, + max_dets=(100, 300, 1000), + classwise=False): + """ + Args: + anns: Evaluation result. + style (str): COCOeval style, can be `bbox` , `segm` and `proposal`. + coco_gt (str): Whether to load COCOAPI through anno_file, + eg: coco_gt = COCO(anno_file) + anno_file (str): COCO annotations file. + max_dets (tuple): COCO evaluation maxDets. + classwise (bool): Whether per-category AP and draw P-R Curve or not. + """ + assert coco_gt is not None or anno_file is not None + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + + if coco_gt is None: + coco_gt = COCO(anno_file) + logging.info("Start evaluate...") + coco_dt = loadRes(coco_gt, anns) + if style == 'proposal': + coco_eval = COCOeval(coco_gt, coco_dt, 'bbox') + coco_eval.params.useCats = 0 + coco_eval.params.maxDets = list(max_dets) + else: + coco_eval = COCOeval(coco_gt, coco_dt, style) + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + if classwise: + # Compute per-category AP and PR curve + try: + from terminaltables import AsciiTable + except Exception as e: + logging.error( + 'terminaltables not found, plaese install terminaltables. 
+                'for example: `pip install terminaltables`.')
+            raise e
+        precisions = coco_eval.eval['precision']
+        cat_ids = coco_gt.getCatIds()
+        # precision: (iou, recall, cls, area range, max dets)
+        assert len(cat_ids) == precisions.shape[2]
+        results_per_category = []
+        for idx, catId in enumerate(cat_ids):
+            # area range index 0: all area ranges
+            # max dets index -1: typically 100 per image
+            nm = coco_gt.loadCats(catId)[0]
+            precision = precisions[:, :, idx, 0, -1]
+            precision = precision[precision > -1]
+            if precision.size:
+                ap = np.mean(precision)
+            else:
+                ap = float('nan')
+            results_per_category.append(
+                (str(nm["name"]), '{:0.3f}'.format(float(ap))))
+            pr_array = precisions[0, :, idx, 0, 2]
+            recall_array = np.arange(0.0, 1.01, 0.01)
+            draw_pr_curve(
+                pr_array,
+                recall_array,
+                out_dir=style + '_pr_curve',
+                file_name='{}_precision_recall_curve.jpg'.format(nm["name"]))
+
+        num_columns = min(6, len(results_per_category) * 2)
+        results_flatten = list(itertools.chain(*results_per_category))
+        headers = ['category', 'AP'] * (num_columns // 2)
+        results_2d = itertools.zip_longest(
+            *[results_flatten[i::num_columns] for i in range(num_columns)])
+        table_data = [headers]
+        table_data += [result for result in results_2d]
+        table = AsciiTable(table_data)
+        logging.info('Per-category of {} AP: \n{}'.format(style, table.table))
+        logging.info("The per-category PR curves have been output to the {} folder.".format(
+            style + '_pr_curve'))
+    # flush coco evaluation result
+    sys.stdout.flush()
+    return coco_eval.stats
diff --git a/fastdeploy/vision/evaluation/utils/fd_logging.py b/fastdeploy/vision/evaluation/utils/fd_logging.py
new file mode 100644
index 0000000000..12091a4f75
--- /dev/null
+++ b/fastdeploy/vision/evaluation/utils/fd_logging.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
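+
+# Usage sketch (illustrative): these helpers print a timestamped, levelled
+# message to stdout; note that error() exits the process unless exit=False
+# is passed.
+#
+#   from fastdeploy.vision.evaluation.utils import fd_logging as logging
+#   logging.info("Starting to read file list from dataset...")
+#   logging.warning("Found an invalid bbox in annotations...")
+#   logging.error("No coco record found", exit=False)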
+
+import time
+import os
+import sys
+import colorama
+from colorama import init
+
+init(autoreset=True)
+levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}
+
+
+def log(level=2, message="", use_color=False):
+    current_time = time.time()
+    time_array = time.localtime(current_time)
+    current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
+    if use_color:
+        print("\033[1;31;40m{} [{}]\t{}\033[0m".format(current_time, levels[
+            level], message).encode("utf-8").decode("latin1"))
+    else:
+        print("{} [{}]\t{}".format(current_time, levels[level], message)
+              .encode("utf-8").decode("latin1"))
+    sys.stdout.flush()
+
+
+def debug(message="", use_color=False):
+    log(level=3, message=message, use_color=use_color)
+
+
+def info(message="", use_color=False):
+    log(level=2, message=message, use_color=use_color)
+
+
+def warning(message="", use_color=True):
+    log(level=1, message=message, use_color=use_color)
+
+
+def error(message="", use_color=True, exit=True):
+    log(level=0, message=message, use_color=use_color)
+    if exit:
+        sys.exit(-1)
diff --git a/fastdeploy/vision/evaluation/utils/json_results.py b/fastdeploy/vision/evaluation/utils/json_results.py
new file mode 100644
index 0000000000..b2e816025b
--- /dev/null
+++ b/fastdeploy/vision/evaluation/utils/json_results.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
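+
+# Input/output sketch for get_det_res below (illustrative values): each row of
+# `bboxes` is [class_id, score, xmin, ymin, xmax, ymax], and the output is a
+# list of COCO-style dicts whose boxes are in xywh layout.
+#
+#   bboxes = [[0, 0.9, 10.0, 20.0, 50.0, 80.0]]
+#   res = get_det_res(bboxes, bbox_nums=1, image_id=42,
+#                     label_to_cat_id_map={0: 1})
+#   # -> [{'image_id': 42, 'category_id': 1,
+#   #      'bbox': [10.0, 20.0, 40.0, 60.0], 'score': 0.9}]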
+
+import six
+import numpy as np
+
+
+def get_det_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0):
+    det_res = []
+    for i in range(bbox_nums):
+        cur_image_id = int(image_id)
+        dt = bboxes[i]
+        num_id, score, xmin, ymin, xmax, ymax = dt
+        if int(num_id) < 0:
+            continue
+        category_id = label_to_cat_id_map[int(num_id)]
+        w = xmax - xmin + bias
+        h = ymax - ymin + bias
+        bbox = [xmin, ymin, w, h]
+        dt_res = {
+            'image_id': cur_image_id,
+            'category_id': category_id,
+            'bbox': bbox,
+            'score': score
+        }
+        det_res.append(dt_res)
+    return det_res
+
+
+def get_det_poly_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0):
+    det_res = []
+    k = 0
+    for i in range(len(bbox_nums)):
+        cur_image_id = int(image_id[i][0])
+        det_nums = bbox_nums[i]
+        for j in range(det_nums):
+            dt = bboxes[k]
+            k = k + 1
+            num_id, score, x1, y1, x2, y2, x3, y3, x4, y4 = dt.tolist()
+            if int(num_id) < 0:
+                continue
+            category_id = label_to_cat_id_map[int(num_id)]
+            rbox = [x1, y1, x2, y2, x3, y3, x4, y4]
+            dt_res = {
+                'image_id': cur_image_id,
+                'category_id': category_id,
+                'bbox': rbox,
+                'score': score
+            }
+            det_res.append(dt_res)
+    return det_res
+
+
+def strip_mask(mask):
+    row = mask[0, 0, :]
+    col = mask[0, :, 0]
+    im_h = len(col) - np.count_nonzero(col == -1)
+    im_w = len(row) - np.count_nonzero(row == -1)
+    return mask[:, :im_h, :im_w]
+
+
+def get_seg_res(masks, bboxes, mask_nums, image_id, label_to_cat_id_map):
+    import pycocotools.mask as mask_util
+    seg_res = []
+    k = 0
+    for i in range(len(mask_nums)):
+        cur_image_id = int(image_id[i][0])
+        det_nums = mask_nums[i]
+        mask_i = masks[k:k + det_nums]
+        mask_i = strip_mask(mask_i)
+        for j in range(det_nums):
+            mask = mask_i[j].astype(np.uint8)
+            score = float(bboxes[k][1])
+            label = int(bboxes[k][0])
+            k = k + 1
+            if label == -1:
+                continue
+            cat_id = label_to_cat_id_map[label]
+            rle = mask_util.encode(
+                np.array(
+                    mask[:, :, None], order="F", dtype="uint8"))[0]
+            if six.PY3:
+                if 'counts' in rle:
+                    rle['counts'] = rle['counts'].decode("utf8")
+            sg_res = {
+                'image_id': cur_image_id,
+                'category_id': cat_id,
+                'segmentation': rle,
+                'score': score
+            }
+            seg_res.append(sg_res)
+    return seg_res
+
+
+def get_solov2_segm_res(results, image_id, num_id_to_cat_id_map):
+    import pycocotools.mask as mask_util
+    segm_res = []
+    # for each batch
+    segms = results['segm'].astype(np.uint8)
+    clsid_labels = results['cate_label']
+    clsid_scores = results['cate_score']
+    lengths = segms.shape[0]
+    im_id = int(image_id[0][0])
+    if lengths == 0 or segms is None:
+        return None
+    # for each sample
+    for i in range(lengths - 1):
+        clsid = int(clsid_labels[i])
+        catid = num_id_to_cat_id_map[clsid]
+        score = float(clsid_scores[i])
+        mask = segms[i]
+        segm = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0]
+        segm['counts'] = segm['counts'].decode('utf8')
+        coco_res = {
+            'image_id': im_id,
+            'category_id': catid,
+            'segmentation': segm,
+            'score': score
+        }
+        segm_res.append(coco_res)
+    return segm_res
+
+
+def get_keypoint_res(results, im_id):
+    anns = []
+    preds = results['keypoint']
+    for idx in range(im_id.shape[0]):
+        image_id = im_id[idx].item()
+        kpts, scores = preds[idx]
+        for kpt, score in zip(kpts, scores):
+            kpt = kpt.flatten()
+            ann = {
+                'image_id': image_id,
+                'category_id': 1,  # XXX hard code
+                'keypoints': kpt.tolist(),
+                'score': float(score)
+            }
+            x = kpt[0::3]
+            y = kpt[1::3]
+            x0, x1, y0, y1 = np.min(x).item(), np.max(x).item(), np.min(
+                y).item(), np.max(y).item()
+            ann['area'] = (x1 - x0) * (y1 - y0)
+            ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]
+            anns.append(ann)
+    return anns
diff --git a/fastdeploy/vision/evaluation/utils/map_utils.py b/fastdeploy/vision/evaluation/utils/map_utils.py
new file mode 100644
index 0000000000..12ea43d3c6
--- /dev/null
+++ b/fastdeploy/vision/evaluation/utils/map_utils.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import os
+
+
+def draw_pr_curve(precision,
+                  recall,
+                  iou=0.5,
+                  out_dir='pr_curve',
+                  file_name='precision_recall_curve.jpg'):
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+    output_path = os.path.join(out_dir, file_name)
+    try:
+        import matplotlib.pyplot as plt
+    except Exception as e:
+        # logger.error('Matplotlib not found, please install matplotlib. '
+        #              'for example: `pip install matplotlib`.')
+        raise e
+    plt.cla()
+    plt.figure('P-R Curve')
+    plt.title('Precision/Recall Curve(IoU={})'.format(iou))
+    plt.xlabel('Recall')
+    plt.ylabel('Precision')
+    plt.grid(True)
+    plt.plot(recall, precision)
+    plt.savefig(output_path)
diff --git a/fastdeploy/vision/evaluation/utils/metrics.py b/fastdeploy/vision/evaluation/utils/metrics.py
new file mode 100644
index 0000000000..ece5036937
--- /dev/null
+++ b/fastdeploy/vision/evaluation/utils/metrics.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
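+
+# Usage sketch (illustrative, hypothetical annotation path): COCOMetric
+# consumes per-image predictions in the same dict layout that eval_detection
+# builds, i.e. rows of [class_id, score, x1, y1, x2, y2].
+#
+#   from pycocotools.coco import COCO
+#   coco_gt = COCO("annotations/instances_val2017.json")
+#   metric = COCOMetric(coco_gt=coco_gt, classwise=False)
+#   pred = {'bbox': [[0, 0.9, 10.0, 20.0, 40.0, 60.0]],
+#           'bbox_num': 1, 'im_id': 42}
+#   metric.update(42, pred)
+#   metric.accumulate()
+#   print(metric.get())  # {'bbox_mmap': ...}
+#   metric.reset()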
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import sys
+from collections import OrderedDict
+from .coco_utils import get_infer_results, cocoapi_eval
+
+
+class COCOMetric(object):
+    def __init__(self, coco_gt, **kwargs):
+        self.clsid2catid = {
+            i: cat['id']
+            for i, cat in enumerate(coco_gt.loadCats(coco_gt.getCatIds()))
+        }
+        self.coco_gt = coco_gt
+        self.classwise = kwargs.get('classwise', False)
+        self.bias = 0
+        self.reset()
+
+    def reset(self):
+        # only bbox and mask evaluation are supported currently
+        self.details = {
+            'gt': copy.deepcopy(self.coco_gt.dataset),
+            'bbox': [],
+            'mask': []
+        }
+        self.eval_stats = {}
+
+    def update(self, im_id, outputs):
+        outs = {}
+        # outputs Tensor -> numpy.ndarray
+        for k, v in outputs.items():
+            outs[k] = v
+
+        outs['im_id'] = im_id
+        infer_results = get_infer_results(
+            outs, self.clsid2catid, bias=self.bias)
+        self.details['bbox'] += infer_results[
+            'bbox'] if 'bbox' in infer_results else []
+        self.details['mask'] += infer_results[
+            'mask'] if 'mask' in infer_results else []
+
+    def accumulate(self):
+        if len(self.details['bbox']) > 0:
+            bbox_stats = cocoapi_eval(
+                copy.deepcopy(self.details['bbox']),
+                'bbox',
+                coco_gt=self.coco_gt,
+                classwise=self.classwise)
+            self.eval_stats['bbox'] = bbox_stats
+            sys.stdout.flush()
+
+        if len(self.details['mask']) > 0:
+            seg_stats = cocoapi_eval(
+                copy.deepcopy(self.details['mask']),
+                'segm',
+                coco_gt=self.coco_gt,
+                classwise=self.classwise)
+            self.eval_stats['mask'] = seg_stats
+            sys.stdout.flush()
+
+    def log(self):
+        pass
+
+    def get(self):
+        if 'bbox' not in self.eval_stats:
+            return {'bbox_mmap': 0.}
+        if 'mask' in self.eval_stats:
+            return OrderedDict(
+                zip(['bbox_mmap', 'segm_mmap'],
+                    [self.eval_stats['bbox'][0], self.eval_stats['mask'][0]]))
+        else:
+            return {'bbox_mmap': self.eval_stats['bbox'][0]}
diff --git a/fastdeploy/vision/evaluation/utils/util.py b/fastdeploy/vision/evaluation/utils/util.py
new file mode 100644
index 0000000000..700ac2cbed
--- /dev/null
+++ b/fastdeploy/vision/evaluation/utils/util.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import platform
+import multiprocessing as mp
+
+
+def is_pic(img_name):
+    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
+    suffix = img_name.split('.')[-1]
+    if suffix not in valid_suffix:
+        return False
+    return True
+
+
+def get_num_workers(num_workers):
+    if not platform.system() == 'Linux':
+        # Dataloader with multi-process mode is not supported
+        # on MacOS and Windows currently.
+        return 0
+    if num_workers == 'auto':
+        num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 2 else 2
+    return num_workers
diff --git a/fastdeploy/vision/linzaer/__init__.py b/fastdeploy/vision/linzaer/__init__.py
new file mode 100644
index 0000000000..d31e43682c
--- /dev/null
+++ b/fastdeploy/vision/linzaer/__init__.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+from ... import FastDeployModel, Frontend
+from ... import fastdeploy_main as C
+
+
+class UltraFace(FastDeployModel):
+    def __init__(self,
+                 model_file,
+                 params_file="",
+                 runtime_option=None,
+                 model_format=Frontend.ONNX):
+        # Let the base class initialize the backend option;
+        # the initialized option is stored in self._runtime_option.
+        super(UltraFace, self).__init__(runtime_option)
+
+        self._model = C.vision.linzaer.UltraFace(
+            model_file, params_file, self._runtime_option, model_format)
+        # self.initialized tells whether the whole model initialized successfully
+        assert self.initialized, "UltraFace initialize failed."
+
+    def predict(self, input_image, conf_threshold=0.7, nms_iou_threshold=0.3):
+        return self._model.predict(input_image, conf_threshold,
+                                   nms_iou_threshold)
+
+    # Attributes of the UltraFace model, mostly preprocessing-related.
+    # For example, model.size = [640, 480] changes the resize target during
+    # preprocessing (provided the model supports it).
+    @property
+    def size(self):
+        return self._model.size
+
+    @size.setter
+    def size(self, wh):
+        assert isinstance(wh, (list, tuple)),\
+            "The value to set `size` must be type of tuple or list."
+        assert len(wh) == 2,\
+            "The value to set `size` must contain 2 elements meaning [width, height], but now it contains {} elements.".format(
+                len(wh))
+        self._model.size = wh
diff --git a/model_zoo/vision/ultraface/README.md b/model_zoo/vision/ultraface/README.md
new file mode 100644
index 0000000000..988f607426
--- /dev/null
+++ b/model_zoo/vision/ultraface/README.md
@@ -0,0 +1,49 @@
+# UltraFace Deployment Example
+
+Currently supported model version: [UltraFace CommitID:dffdddd](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/commit/dffdddd)
+
+This document describes how to quickly deploy [UltraFace](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/) for inference. This directory is organized as follows
+
+```
+.
+├── cpp                  # C++ code directory
+│   ├── CMakeLists.txt   # CMakeLists file for building the C++ code
+│   ├── README.md        # C++ build and deployment guide
+│   └── ultraface.cc     # C++ example code
+├── api.md               # API reference
+├── README.md            # UltraFace deployment guide
+└── ultraface.py         # Python example code
+```
+
+## Install FastDeploy
+
+Install FastDeploy with the commands below. Note that this installs the `vision-cpu` module; you can install `vision-gpu` instead if needed.
+```bash
+# install the fastdeploy-python tool
+pip install fastdeploy-python
+
+# install the vision-cpu module
+fastdeploy install vision-cpu
+```
+
+## Python Deployment
+
+Running the script below automatically downloads the UltraFace model and a test image
+```bash
+python ultraface.py
+```
+
+After it finishes, the visualized result is saved locally as `vis_result.jpg` and the detection results are printed as follows
+```
+FaceDetectionResult: [xmin, ymin, xmax, ymax, score]
+742.528931,261.309937, 837.749146, 365.145599, 0.999833
+408.159332,253.410889, 484.747284, 353.378052, 0.999832
+549.409424,225.051819, 636.311890, 337.824707, 0.999782
+185.562805,233.364044, 252.001801, 323.948669, 0.999709
+304.065918,180.468140, 377.097961, 278.932861, 0.999645
+```
+
+## Other Documents
+
+- [C++ Deployment](./cpp/README.md)
+- [UltraFace API Reference](./api.md)
diff --git a/model_zoo/vision/ultraface/api.md b/model_zoo/vision/ultraface/api.md
new file mode 100644
index 0000000000..8dc7d2fb71
--- /dev/null
+++ b/model_zoo/vision/ultraface/api.md
@@ -0,0 +1,71 @@
+# UltraFace API Reference
+
+## Python API
+
+### UltraFace Class
+```
+fastdeploy.vision.linzaer.UltraFace(model_file, params_file=None, runtime_option=None, model_format=fd.Frontend.ONNX)
+```
+Loads and initializes an UltraFace model. When model_format is `fd.Frontend.ONNX`, only model_file is required, e.g. `version-RFB-320.onnx`; when model_format is `fd.Frontend.PADDLE`, both model_file and params_file must be provided.
+
+**Parameters**
+
+> * **model_file**(str): path to the model file
+> * **params_file**(str): path to the parameters file
+> * **runtime_option**(RuntimeOption): backend inference options; None means the default configuration
+> * **model_format**(Frontend): model format
+
+#### predict Function
+> ```
+> UltraFace.predict(image_data, conf_threshold=0.7, nms_iou_threshold=0.3)
+> ```
> Prediction interface: takes an image and directly returns the detection results.
+>
+> **Parameters**
+>
+> > * **image_data**(np.ndarray): input data, which must be in HWC, BGR layout
+> > * **conf_threshold**(float): confidence threshold for filtering detection boxes
+> > * **nms_iou_threshold**(float): IoU threshold used during NMS
+
+See [ultraface.py](./ultraface.py) for example code.
+
+
+## C++ API
+
+### UltraFace Class
+```
+fastdeploy::vision::linzaer::UltraFace(
+        const string& model_file,
+        const string& params_file = "",
+        const RuntimeOption& runtime_option = RuntimeOption(),
+        const Frontend& model_format = Frontend::ONNX)
+```
+Loads and initializes an UltraFace model. When model_format is `Frontend::ONNX`, only model_file is required, e.g. `version-RFB-320.onnx`; when model_format is `Frontend::PADDLE`, both model_file and params_file must be provided.
+
+**Parameters**
+
+> * **model_file**(str): path to the model file
+> * **params_file**(str): path to the parameters file
+> * **runtime_option**(RuntimeOption): backend inference options; the default value means the default configuration
+> * **model_format**(Frontend): model format
+
+#### Predict Function
+> ```
+> UltraFace::Predict(cv::Mat* im, FaceDetectionResult* result,
+>                    float conf_threshold = 0.7,
+>                    float nms_iou_threshold = 0.3)
+> ```
+> Prediction interface: takes an image and directly returns the detection results.
+>
+> **Parameters**
+>
+> > * **im**: input image, which must be in HWC, BGR layout
+> > * **result**: detection results, including the boxes and the confidence of each box
+> > * **conf_threshold**: confidence threshold for filtering detection boxes
+> > * **nms_iou_threshold**: IoU threshold used during NMS
+
+See [cpp/ultraface.cc](cpp/ultraface.cc) for example code.
+
+## Other APIs
+
+- [RuntimeOption configuration for model deployment](../../../docs/api/runtime_option.md)
diff --git a/model_zoo/vision/ultraface/cpp/CMakeLists.txt b/model_zoo/vision/ultraface/cpp/CMakeLists.txt
new file mode 100644
index 0000000000..a33967dee7
--- /dev/null
+++ b/model_zoo/vision/ultraface/cpp/CMakeLists.txt
@@ -0,0 +1,17 @@
+PROJECT(ultraface_demo C CXX)
+CMAKE_MINIMUM_REQUIRED (VERSION 3.16)
+
+# In an environment with the older C++ ABI, uncomment the line below for
+# compatible compilation
+# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
+
+# Path of the downloaded and extracted FastDeploy library
+set(FASTDEPLOY_INSTALL_DIR
+    ${PROJECT_SOURCE_DIR}/fastdeploy-linux-x64-0.0.3/)
+
+include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
+
+# Add the FastDeploy header files
+include_directories(${FASTDEPLOY_INCS})
+
+add_executable(ultraface_demo ${PROJECT_SOURCE_DIR}/ultraface.cc)
+# Link the FastDeploy libraries
+target_link_libraries(ultraface_demo ${FASTDEPLOY_LIBS})
diff --git a/model_zoo/vision/ultraface/cpp/README.md b/model_zoo/vision/ultraface/cpp/README.md
new file mode 100644
index 0000000000..d2098d8386
--- /dev/null
+++ b/model_zoo/vision/ultraface/cpp/README.md
@@ -0,0 +1,36 @@
+# Building the UltraFace Example
+
+Currently supported model version: [UltraFace CommitID:dffdddd](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/commit/dffdddd)
+
+## Download and Extract the Inference Library
+```bash
+wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.0.3.tgz
+tar xvf fastdeploy-linux-x64-0.0.3.tgz
+```
+
+## Build the Example Code
+```bash
+mkdir build && cd build
+cmake ..
+make -j
+```
+
+## Download the Model and Test Image
+```bash
+wget https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/raw/master/models/onnx/version-RFB-320.onnx
+wget https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/raw/master/imgs/3.jpg
+```
+
+## Run
+```bash
+./ultraface_demo
+```
+
+After it finishes, the visualized result is saved locally as `vis_result.jpg` and the detected boxes are printed to the terminal as follows
+```
+FaceDetectionResult: [xmin, ymin, xmax, ymax, score]
+742.528931,261.309937, 837.749146, 365.145599, 0.999833
+408.159332,253.410889, 484.747284, 353.378052, 0.999832
+549.409424,225.051819, 636.311890, 337.824707, 0.999782
+185.562805,233.364044, 252.001801, 323.948669, 0.999709
+304.065918,180.468140, 377.097961, 278.932861, 0.999645
+```
diff --git a/model_zoo/vision/ultraface/cpp/ultraface.cc b/model_zoo/vision/ultraface/cpp/ultraface.cc
new file mode 100644
index 0000000000..9f1aa8a9b9
--- /dev/null
+++ b/model_zoo/vision/ultraface/cpp/ultraface.cc
@@ -0,0 +1,48 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+int main() {
+  namespace vis = fastdeploy::vision;
+
+  std::string model_file = "version-RFB-320.onnx";
+  std::string vis_path = "vis_result.jpg";
+
+  auto model = vis::linzaer::UltraFace(model_file);
+  if (!model.Initialized()) {
+    std::cerr << "Init Failed! Model: " << model_file << std::endl;
+    return -1;
+  } else {
+    std::cout << "Init Done! Model: " << model_file << std::endl;
+  }
+  model.EnableDebug();
+
+  cv::Mat im = cv::imread("3.jpg");
+  cv::Mat vis_im = im.clone();
+
+  vis::FaceDetectionResult res;
+  if (!model.Predict(&im, &res, 0.7f, 0.3f)) {
+    std::cerr << "Prediction Failed." << std::endl;
+    return -1;
+  } else {
+    std::cout << "Prediction Done!" << std::endl;
+  }
+
+  // print the detected boxes
+  std::cout << res.Str() << std::endl;
+
+  // visualize the prediction result
+  vis::Visualize::VisFaceDetection(&vis_im, res, 2, 0.3f);
+  cv::imwrite(vis_path, vis_im);
+  std::cout << "Detect Done! Saved: " << vis_path << std::endl;
Saved: " << vis_path << std::endl; + return 0; +} diff --git a/model_zoo/vision/ultraface/ultraface.py b/model_zoo/vision/ultraface/ultraface.py new file mode 100644 index 0000000000..ceb4c313fa --- /dev/null +++ b/model_zoo/vision/ultraface/ultraface.py @@ -0,0 +1,23 @@ +import fastdeploy as fd +import cv2 + +# 下载模型 +model_url = "https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/raw/master/models/onnx/version-RFB-320.onnx" +test_img_url = "https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/raw/master/imgs/3.jpg" +fd.download(model_url, ".", show_progress=True) +fd.download(test_img_url, ".", show_progress=True) + +# 加载模型 +model = fd.vision.linzaer.UltraFace("version-RFB-320.onnx") + +# 预测图片 +im = cv2.imread("3.jpg") +result = model.predict(im, conf_threshold=0.7, nms_iou_threshold=0.3) + +# 可视化结果 +fd.vision.visualize.vis_face_detection(im, result) +cv2.imwrite("vis_result.jpg", im) + +# 输出预测结果 +print(result) +print(model.runtime_option) diff --git a/model_zoo/vision/yolov5face/cpp/README.md b/model_zoo/vision/yolov5face/cpp/README.md index 4f5788458f..60d46cb878 100644 --- a/model_zoo/vision/yolov5face/cpp/README.md +++ b/model_zoo/vision/yolov5face/cpp/README.md @@ -50,7 +50,7 @@ make -j 执行完后可视化的结果保存在本地`vis_result.jpg`,同时会将检测框输出在终端,如下所示 ``` -aceDetectionResult: [xmin, ymin, xmax, ymax, score, (x, y) x 5] +FaceDetectionResult: [xmin, ymin, xmax, ymax, score, (x, y) x 5] 749.575256,375.122162, 775.008850, 407.858215, 0.851824, (756.933838,388.423157), (767.810974,387.932922), (762.617065,394.212341), (758.053101,399.073639), (767.370300,398.769470) 897.833862,380.372864, 924.725281, 409.566803, 0.847505, (903.757202,390.221741), (914.575867,389.495911), (908.998901,395.983307), (905.803223,400.871429), (914.674438,400.268066) 281.558197,367.739349, 305.474701, 397.860535, 0.840915, (287.018768,379.771088), (297.285004,378.755280), (292.057831,385.207367), (289.110962,390.010437), (297.535339,389.412048)