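diff --git a/.new_docs/cn/faq/add_new_model.md b/.new_docs/cn/faq/add_new_model.md
new file mode 100644
index 0000000000..ab5afb9b2f
--- /dev/null
+++ b/.new_docs/cn/faq/add_new_model.md
@@ -0,0 +1,268 @@
+# FastDeploy External Model Integration Guide
+
+Adding a new model to FastDeploy means adding deployment support for it in both C++ and Python. This guide uses the ResNet50 model from torchvision v0.12.0 as an example to show how to perform external [model integration](#modelsupport) with FastDeploy. The process consists of the following 3 steps.
+
+| Step | Description | Files created or modified |
+|:------:|:-------------------------------------:|:---------------------------------------------:|
+| [1](#step2) | Add the model implementation to the corresponding task module under fastdeploy/vision | resnet.h, resnet.cc, vision.h |
+| [2](#step4) | Bind the Python interface through pybind | resnet_pybind.cc, classification_pybind.cc |
+| [3](#step5) | Implement the corresponding Python calling interface | resnet.py, \_\_init\_\_.py |
+
+Once these 3 steps are complete, the external model is integrated.
+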
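As a rough illustration of the table above, the following sketch lists the files each step touches for a given task and model name. The helper `integration_files` and its arguments are hypothetical names used only for this example; the path patterns are assumed from the ResNet example used throughout this guide and are not a FastDeploy API.

```python
from collections import OrderedDict


def integration_files(task, model):
    """Return the files touched by each integration step.

    Hypothetical helper for illustration only: the path patterns are
    assumed from the ResNet example and are not part of FastDeploy.
    """
    name = model.lower()
    cpp_dir = "fastdeploy/vision/{}".format(task)
    py_dir = "python/fastdeploy/vision/{}".format(task)
    steps = OrderedDict()
    # Step 1: model implementation in C++ plus the umbrella header.
    steps["1. C++ implementation"] = [
        "{}/contrib/{}.h".format(cpp_dir, name),
        "{}/contrib/{}.cc".format(cpp_dir, name),
        "fastdeploy/vision.h",
    ]
    # Step 2: the model binding and the task-level binding entry point.
    steps["2. Pybind binding"] = [
        "{}/contrib/{}_pybind.cc".format(cpp_dir, name),
        "{}/{}_pybind.cc".format(cpp_dir, task),
    ]
    # Step 3: the Python wrapper class and the package import.
    steps["3. Python interface"] = [
        "{}/contrib/{}.py".format(py_dir, name),
        "{}/__init__.py".format(py_dir),
    ]
    return steps


if __name__ == "__main__":
    for step, files in integration_files("classification", "ResNet").items():
        print(step)
        for path in files:
            print("  " + path)
```

Calling `integration_files("classification", "ResNet")` yields the seven files named in the table, grouped by step, which can serve as a checklist while working through the guide.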
+If you want to contribute code to FastDeploy, you also need to add test code, documentation, and code comments for the new model; see [Test](#test) for details.
+## Model integration
+
+### Model preparation
+
+
+Before integrating an external model, first convert the trained model (.pt, .pdparams, etc.) into a format that FastDeploy can deploy (.onnx, .pdmodel). Most open-source repositories provide a conversion script that can be used directly. torchvision does not provide one, so the script has to be written by hand. This guide converts `torchvision.models.resnet50` to `resnet50.onnx` with the following reference code:
+
+```python
+import torch
+import torchvision.models as models
+model = models.resnet50(pretrained=True)
+batch_size = 1  # batch size
+input_shape = (3, 224, 224)  # input shape, change it to match your own input
+model.eval()
+x = torch.randn(batch_size, *input_shape)  # generate a dummy input tensor
+export_onnx_file = "resnet50.onnx"  # name of the output ONNX file
+torch.onnx.export(model,
+                  x,
+                  export_onnx_file,
+                  opset_version=12,
+                  input_names=["input"],  # input name
+                  output_names=["output"],  # output name
+                  dynamic_axes={"input": {0: "batch_size"},  # dynamic batch dimension
+                                "output": {0: "batch_size"}})
+```
+Running the script above produces the `resnet50.onnx` file.
+
+### C++ part
+* Create the `resnet.h` file
+  * Location
+    * FastDeploy/fastdeploy/vision/classification/contrib/resnet.h (FastDeploy/C++ source directory/vision models/task name/external models/model name.h)
+  * Content
+    * First create the ResNet class in resnet.h, inheriting from the FastDeployModel base class, then declare `Predict`, `Initialize`, `Preprocess`, `Postprocess`, and the constructor, along with the necessary variables. For the full code, see [resnet.h](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-69128489e918f305c208476ba793d8167e77de2aa7cadf5dcbac30da448bd28e).
+
+```C++
+class FASTDEPLOY_DECL ResNet : public FastDeployModel {
+ public:
+  ResNet(...);
+  virtual bool Predict(...);
+ private:
+  bool Initialize();
+  bool Preprocess(...);
+  bool Postprocess(...);
+};
+```
+
+* Create the `resnet.cc` file
+  * Location
+    * FastDeploy/fastdeploy/vision/classification/contrib/resnet.cc (FastDeploy/C++ source directory/vision models/task name/external models/model name.cc)
+  * Content
+    * Implement in `resnet.cc` the functions declared in `resnet.h`. `Preprocess` and `Postprocess` must reproduce the pre- and post-processing logic of the original official repository. The logic of each ResNet function is outlined below; for the full code, see [resnet.cc](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-d229d702de28345253a53f2a5839fd2c638f3d32fffa6a7d04d23db9da13a871).
+
+```C++
+ResNet::ResNet(...) {
+  // Constructor logic
+  // 1. Specify the Backend 2. Set RuntimeOption 3. Call Initialize()
+}
+bool ResNet::Initialize() {
+  // Initialization logic
+  // 1. Assign values to the global variables 2. Call InitRuntime()
+  return true;
+}
+bool ResNet::Preprocess(Mat* mat, FDTensor* output) {
+  // Preprocessing logic
+  // 1. Resize 2. BGR2RGB 3. Normalize 4. HWC2CHW 5. Store the result in an FDTensor
+  return true;
+}
+bool ResNet::Postprocess(FDTensor& infer_result, ClassifyResult* result, int topk) {
+  // Postprocessing logic
+  // 1. Softmax 2. Choose topk labels 3. Store the result in a ClassifyResult
+  return true;
+}
+bool ResNet::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
+  Preprocess(...);
+  Infer(...);
+  Postprocess(...);
+  return true;
+}
+```
+
+* Add the new model header to `vision.h`
+  * Location
+    * FastDeploy/fastdeploy/vision.h
+  * Change
+
+```C++
+#ifdef ENABLE_VISION
+#include "fastdeploy/vision/classification/contrib/resnet.h"
+#endif
+```
+
+
+### Pybind part
+
+* Create the Pybind file
+  * Location
+    * FastDeploy/fastdeploy/vision/classification/contrib/resnet_pybind.cc (FastDeploy/C++ source directory/vision models/task name/external models/model name_pybind.cc)
+  * Content
+    * Use Pybind to bind the C++ functions and variables to Python. For the full code, see [resnet_pybind.cc](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-270af0d65720310e2cfbd5373c391b2110d65c0f4efa547f7b7eeffcb958bdec).
+```C++
+void BindResNet(pybind11::module& m) {
+  pybind11::class_<vision::classification::ResNet, FastDeployModel>(
+      m, "ResNet")
+      .def(pybind11::init<std::string, std::string, RuntimeOption, ModelFormat>())
+      .def("predict", ...)
+      .def_readwrite("size", &vision::classification::ResNet::size)
+      .def_readwrite("mean_vals", &vision::classification::ResNet::mean_vals)
+      .def_readwrite("std_vals", &vision::classification::ResNet::std_vals);
+}
+```
+
+* Call the Pybind function
+  * Location
+    * FastDeploy/fastdeploy/vision/classification/classification_pybind.cc (FastDeploy/C++ source directory/vision models/task name/task name_pybind.cc)
+  * Change
+```C++
+void BindResNet(pybind11::module& m);
+void BindClassification(pybind11::module& m) {
+  auto classification_module =
+      m.def_submodule("classification", "Image classification models.");
+  BindResNet(classification_module);
+}
+```
+
+
+### Python part
+
+
+* Create the `resnet.py` file
+  * Location
+    * FastDeploy/python/fastdeploy/vision/classification/contrib/resnet.py (FastDeploy/Python source directory/fastdeploy/vision models/task name/external models/model name.py)
+  * Content
+    * Create the ResNet class inheriting from FastDeployModel, and implement `__init__`, the functions bound through Pybind (such as `predict()`), and the getter and setter functions for the variables bound through Pybind. For the full code, see [resnet.py](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-a4dc5ec2d450e91f1c03819bf314c238b37ac678df56d7dea3aab7feac10a157).
+
+```python
+
+class ResNet(FastDeployModel):
+    def __init__(self, ...):
+        self._model = C.vision.classification.ResNet(...)
+    def predict(self, input_image, topk=1):
+        return self._model.predict(input_image, topk)
+    @property
+    def size(self):
+        return self._model.size
+    @size.setter
+    def size(self, wh):
+        ...
+```
+
+* Import the ResNet class
+  * Location
+    * FastDeploy/python/fastdeploy/vision/classification/\_\_init\_\_.py (FastDeploy/Python source directory/fastdeploy/vision models/task name/\_\_init\_\_.py)
+  * Change
+
+```Python
+from .contrib.resnet import ResNet
+```
+
+## Test
+### Build
+ * C++
+   * Location: FastDeploy/
+
+```
+mkdir build && cd build
+cmake .. -DENABLE_ORT_BACKEND=ON -DENABLE_VISION=ON -DCMAKE_INSTALL_PREFIX=${PWD}/fastdeploy-0.0.3 \
+-DENABLE_PADDLE_BACKEND=ON -DENABLE_TRT_BACKEND=ON -DWITH_GPU=ON -DTRT_DIRECTORY=/PATH/TO/TensorRT/
+make -j8
+make install
+```
+
+ The build produces build/fastdeploy-0.0.3/.
+
+ * Python
+   * Location: FastDeploy/python/
+
+```
+export TRT_DIRECTORY=/PATH/TO/TensorRT/    # when using TensorRT, set this to the TensorRT installation path and enable ENABLE_TRT_BACKEND
+export ENABLE_TRT_BACKEND=ON
+export WITH_GPU=ON
+export ENABLE_PADDLE_BACKEND=ON
+export ENABLE_OPENVINO_BACKEND=ON
+export ENABLE_VISION=ON
+export ENABLE_ORT_BACKEND=ON
+python setup.py build
+python setup.py bdist_wheel
+cd dist
+pip install fastdeploy_gpu_python-VERSION-cpxx-cpxxm-ARCH.whl
+```
+
+### Write the test code
+ * Location: FastDeploy/examples/vision/classification/resnet/ (FastDeploy/examples directory/vision models/task name/model name/)
+ * Directory structure to create
+
+```
+.
+├── cpp
+│   ├── CMakeLists.txt
+│   ├── infer.cc    // C++ test code
+│   └── README.md    // C++ usage documentation
+├── python
+│   ├── infer.py    // Python test code
+│   └── README.md    // Python usage documentation
+└── README.md    // ResNet model integration documentation
+```
+
+* C++
+  * For the CMakeLists file, the C++ code, and the README.md content, see [cpp/](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-afcbe607b796509581f89e38b84190717f1eeda2df0419a2ac9034197ead5f96).
+  * Build infer.cc
+    * Location: FastDeploy/examples/vision/classification/resnet/cpp/
+
+```
+mkdir build && cd build
+cmake .. -DFASTDEPLOY_INSTALL_DIR=/PATH/TO/FastDeploy/build/fastdeploy-0.0.3/
+make
+```
+
+* Python
+  * For the Python code and the README.md content, see [python/](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-5a0d6be8c603a8b81454ac14c17fb93555288d9adf92bbe40454449309700135).
+
+### Add comments to the code
+To help users understand the code, the newly added code needs comments. The examples below show how to write them.
+- C++ code
+Add comments to the functions and variables in resnet.h. Three comment styles are available, as shown below; for details, see [resnet.h](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-69128489e918f305c208476ba793d8167e77de2aa7cadf5dcbac30da448bd28e).
+
+```C++
+/** \brief Predict for the input "im", the result will be saved in "result".
+*
+* \param[in] im Input image for inference.
+* \param[out] result Saving the inference result.
+* \param[in] topk The length of return values, e.g., if topk==2, the result will include the 2 most probable class labels for the input image.
+*/
+virtual bool Predict(cv::Mat* im, ClassifyResult* result, int topk = 1);
+
+/// Tuple of (width, height)
+std::vector<int> size;
+/*! @brief Initialize for ResNet model, assign values to the global variables and call InitRuntime()
+*/
+bool Initialize();
+```
+- Python code
+Add appropriate comments to the functions and variables in resnet.py, as in the example below; for details, see [resnet.py](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-a4dc5ec2d450e91f1c03819bf314c238b37ac678df56d7dea3aab7feac10a157).
+
+```python
+    def predict(self, input_image, topk=1):
+        """Classify an input image
+
+        :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
+        :param topk: (int)The topk result by the classify confidence score, default 1
+        :return: ClassifyResult
+        """
+        return self._model.predict(input_image, topk)
+```
+
+For the other files involved in integrating the model, you can likewise add comments that explain implementation details.
diff --git a/docs/api_docs/python/image_classification.md b/docs/api_docs/python/image_classification.md
index 46760eec3a..ad284ca66d 100755
--- a/docs/api_docs/python/image_classification.md
+++ b/docs/api_docs/python/image_classification.md
@@ -15,3 +15,11 @@
    :members:
    :inherited-members:
 ```
+
+## fastdeploy.vision.classification.ResNet
+
+```{eval-rst}
+.. autoclass:: fastdeploy.vision.classification.ResNet
+   :members:
+   :inherited-members:
+```
diff --git a/examples/vision/classification/resnet/README.md b/examples/vision/classification/resnet/README.md
new file mode 100644
index 0000000000..dd20562fb5
--- /dev/null
+++ b/examples/vision/classification/resnet/README.md
@@ -0,0 +1,53 @@
+# Preparing ResNet models for deployment
+
+- The ResNet deployment is based on the code from [Torchvision](https://github.com/pytorch/vision/tree/v0.12.0) and its [models pretrained on ImageNet2012](https://github.com/pytorch/vision/tree/v0.12.0).
+
+  - (1) The *.pt models provided by the [official repository](https://github.com/pytorch/vision/tree/v0.12.0) can be deployed after following the steps in [Export the ONNX model](#export-the-onnx-model);
+  - (2) For a ResNet model trained on your own data, follow the steps in [Export the ONNX model](#export-the-onnx-model), then refer to the [Detailed deployment documentation](#detailed-deployment-documentation) to complete the deployment.
+
+
+## Export the ONNX model
+
+
+  Import [Torchvision](https://github.com/pytorch/vision/tree/v0.12.0), load the pretrained model, and convert it as follows.
+
+  ```python
+  import torch
+  import torchvision.models as models
+
+  model = models.resnet50(pretrained=True)
+  batch_size = 1  # batch size
+  input_shape = (3, 224, 224)  # input shape, change it to match your own input
+  # set the model to inference mode
+  model.eval()
+  x = torch.randn(batch_size, *input_shape)  # generate a dummy input tensor
+  export_onnx_file = "resnet50.onnx"  # name of the output ONNX file
+  torch.onnx.export(model,
+                    x,
+                    export_onnx_file,
+                    opset_version=12,
+                    input_names=["input"],  # input name
+                    output_names=["output"],  # output name
+                    dynamic_axes={"input": {0: "batch_size"},  # dynamic batch dimension
+                                  "output": {0: "batch_size"}})
+  ```
+
+## Download pretrained ONNX models
+
+For the convenience of developers, exported models of the ResNet series are provided below and can be downloaded and used directly. (The accuracy figures in the table come from the original official repository.)
+| Model | Size | Accuracy |
+|:---------------------------------------------------------------- |:----- |:----- |
+| [ResNet-18](https://bj.bcebos.com/paddlehub/fastdeploy/resnet18.onnx) | 45MB | |
+| [ResNet-34](https://bj.bcebos.com/paddlehub/fastdeploy/resnet34.onnx) | 84MB | |
+| [ResNet-50](https://bj.bcebos.com/paddlehub/fastdeploy/resnet50.onnx) | 98MB | |
+| [ResNet-101](https://bj.bcebos.com/paddlehub/fastdeploy/resnet101.onnx) | 170MB | |
+
+
+## Detailed deployment documentation
+
+- [Python deployment](python)
+- [C++ deployment](cpp)
+
+## Version notes
+
+- This version of the documentation and code is based on [Torchvision v0.12.0](https://github.com/pytorch/vision/tree/v0.12.0)
diff --git a/examples/vision/classification/resnet/cpp/CMakeLists.txt b/examples/vision/classification/resnet/cpp/CMakeLists.txt
new file mode 100644
index 0000000000..93540a7e83
--- /dev/null
+++ b/examples/vision/classification/resnet/cpp/CMakeLists.txt
@@ -0,0 +1,14 @@
+PROJECT(infer_demo C CXX)
+CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
+
+# Path of the downloaded and extracted fastdeploy library
+option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
+
+include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
+
+# Add the FastDeploy header dependencies
+include_directories(${FASTDEPLOY_INCS})
+
+add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
+# Add the FastDeploy library dependencies
+target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
diff --git a/examples/vision/classification/resnet/cpp/README.md b/examples/vision/classification/resnet/cpp/README.md
new file mode 100644
index 0000000000..eb3bff6f48
--- /dev/null
+++ b/examples/vision/classification/resnet/cpp/README.md
@@ -0,0 +1,77 @@
+# ResNet C++ deployment example
+
+This directory provides `infer.cc`, an example that quickly deploys ResNet series models on CPU/GPU, and on GPU with TensorRT acceleration.
+
+Before deployment, confirm the following two steps
+
+- 1. The software and hardware environment meets the requirements; see [FastDeploy environment requirements](../../../../../docs/environment.md)
+- 2. Download the prebuilt deployment library and the samples code for your development environment; see [FastDeploy prebuilt libraries](../../../../../docs/quick_start)
+
+Taking ResNet50 inference on Linux as an example, run the following commands in this directory to build and test
+
+```bash
+# Download the SDK and build the model examples code (the SDK includes the examples code)
+wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.2.1.tgz
+tar xvf fastdeploy-linux-x64-gpu-0.2.1.tgz
+cd fastdeploy-linux-x64-gpu-0.2.1/examples/vision/classification/resnet/cpp
+mkdir build
+cd build
+cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../../../../../../../fastdeploy-linux-x64-gpu-0.2.1
+make -j
+
+# Download the ResNet model file and the test image
+wget https://bj.bcebos.com/paddlehub/fastdeploy/resnet50.onnx
+wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
+
+
+# CPU inference
+./infer_demo resnet50.onnx ILSVRC2012_val_00000010.jpeg 0
+# GPU inference
+./infer_demo resnet50.onnx ILSVRC2012_val_00000010.jpeg 1
+# TensorRT inference on GPU
+./infer_demo resnet50.onnx ILSVRC2012_val_00000010.jpeg 2
+```
+
+The commands above only apply to Linux or MacOS. For using the SDK on Windows, see:
+- [How to use the FastDeploy C++ SDK on Windows](../../../../../docs/compile/how_to_use_sdk_on_windows.md)
+
+## ResNet C++ interface
+
+### ResNet class
+
+```c++
+
+fastdeploy::vision::classification::ResNet(
+        const std::string& model_file,
+        const std::string& params_file = "",
+        const RuntimeOption& custom_option = RuntimeOption(),
+        const ModelFormat& model_format = ModelFormat::ONNX)
+```
+
+
+**Parameters**
+
+> * **model_file**(str): Path of the model file
+> * **params_file**(str): Path of the parameters file
+> * **custom_option**(RuntimeOption): Backend inference configuration; defaults to `RuntimeOption()`, i.e. the default configuration
+> * **model_format**(ModelFormat): Model format, ONNX by default
+
+#### Predict function
+
+> ```c++
+> ResNet::Predict(cv::Mat* im, ClassifyResult* result, int topk = 1)
+> ```
+>
+> Model prediction interface: takes an input image and directly outputs the classification result.
+>
+> **Parameters**
+>
+> > * **im**: Input image, which must be in HWC layout and BGR format
+> > * **result**: Classification result, including label_id and the corresponding confidence; for a description of ClassifyResult, see [Vision model prediction results](../../../../../docs/api/vision_results/)
+> > * **topk**(int): Return the topk classification results with the highest predicted probability, 1 by default
+
+
+- [Model description](../../)
+- [Python deployment](../python)
+- [Vision model prediction results](../../../../../docs/api/vision_results/)
+- [How to switch the model inference backend](../../../../../docs/runtime/how_to_change_backend.md)
diff --git a/examples/vision/classification/resnet/cpp/infer.cc b/examples/vision/classification/resnet/cpp/infer.cc
new file mode 100644
index 0000000000..083c9de60a
--- /dev/null
+++ b/examples/vision/classification/resnet/cpp/infer.cc
@@ -0,0 +1,94 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+void CpuInfer(const std::string& model_file, const std::string& image_file) {
+  auto model = fastdeploy::vision::classification::ResNet(model_file);
+  if (!model.Initialized()) {
+    std::cerr << "Failed to initialize." << std::endl;
+    return;
+  }
+
+  auto im = cv::imread(image_file);
+
+  fastdeploy::vision::ClassifyResult res;
+  if (!model.Predict(&im, &res)) {
+    std::cerr << "Failed to predict." << std::endl;
+    return;
+  }
+  std::cout << res.Str() << std::endl;
+
+}
+
+void GpuInfer(const std::string& model_file, const std::string& image_file) {
+  auto option = fastdeploy::RuntimeOption();
+  option.UseGpu();
+  auto model = fastdeploy::vision::classification::ResNet(model_file, "", option);
+  if (!model.Initialized()) {
+    std::cerr << "Failed to initialize." << std::endl;
+    return;
+  }
+
+  auto im = cv::imread(image_file);
+
+  fastdeploy::vision::ClassifyResult res;
+  if (!model.Predict(&im, &res)) {
+    std::cerr << "Failed to predict." << std::endl;
+    return;
+  }
+  std::cout << res.Str() << std::endl;
+}
+
+void TrtInfer(const std::string& model_file, const std::string& image_file) {
+  auto option = fastdeploy::RuntimeOption();
+  option.UseGpu();
+  option.UseTrtBackend();
+  option.SetTrtInputShape("input", {1, 3, 224, 224});
+  auto model = fastdeploy::vision::classification::ResNet(model_file, "", option);
+  if (!model.Initialized()) {
+    std::cerr << "Failed to initialize." << std::endl;
+    return;
+  }
+
+  auto im = cv::imread(image_file);
+
+  fastdeploy::vision::ClassifyResult res;
+  if (!model.Predict(&im, &res)) {
+    std::cerr << "Failed to predict." << std::endl;
+    return;
+  }
+  std::cout << res.Str() << std::endl;
+}
+
+int main(int argc, char* argv[]) {
+  if (argc < 4) {
+    std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
+                 "e.g ./infer_demo ./resnet50.onnx ./test.jpeg 0"
+              << std::endl;
+    std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
+                 "with gpu; 2: run with gpu and use tensorrt backend."
+              << std::endl;
+    return -1;
+  }
+
+  if (std::atoi(argv[3]) == 0) {
+    CpuInfer(argv[1], argv[2]);
+  } else if (std::atoi(argv[3]) == 1) {
+    GpuInfer(argv[1], argv[2]);
+  } else if (std::atoi(argv[3]) == 2) {
+    TrtInfer(argv[1], argv[2]);
+  }
+  return 0;
+}
diff --git a/examples/vision/classification/resnet/python/README.md b/examples/vision/classification/resnet/python/README.md
new file mode 100644
index 0000000000..6315ee06a6
--- /dev/null
+++ b/examples/vision/classification/resnet/python/README.md
@@ -0,0 +1,72 @@
+# ResNet model Python deployment example
+
+Before deployment, confirm the following two steps
+
+- 1. The software and hardware environment meets the requirements; see [FastDeploy environment requirements](../../../../../docs/environment.md)
+- 2. The FastDeploy Python whl package is installed; see [FastDeploy Python installation](../../../../../docs/quick_start)
+
+This directory provides `infer.py`, an example that quickly deploys ResNet50 on CPU/GPU, and on GPU with TensorRT acceleration. Run the following script to complete it
+
+```bash
+# Download the deployment example code
+git clone https://github.com/PaddlePaddle/FastDeploy.git
+cd FastDeploy/examples/vision/classification/resnet/python
+
+# Download the ResNet50 model file and the test image
+wget https://bj.bcebos.com/paddlehub/fastdeploy/resnet50.onnx
+wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
+
+# CPU inference
+python infer.py --model resnet50.onnx --image ILSVRC2012_val_00000010.jpeg --device cpu --topk 1
+# GPU inference
+python infer.py --model resnet50.onnx --image ILSVRC2012_val_00000010.jpeg --device gpu --topk 1
+# TensorRT inference on GPU (note: the first TensorRT run serializes the model, which takes some time, so please be patient)
+python infer.py --model resnet50.onnx --image ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True --topk 1
+```
+
+After the run completes, the returned result looks as follows
+```bash
+ClassifyResult(
+label_ids: 332,
+scores: 0.825349,
+)
+```
+
+## ResNet Python interface
+
+```python
+fd.vision.classification.ResNet(model_file, params_file="", runtime_option=None, model_format=ModelFormat.ONNX)
+```
+
+
+**Parameters**
+
+> * **model_file**(str): Path of the model file
+> * **params_file**(str): Path of the parameters file
+> * **runtime_option**(RuntimeOption): Backend inference configuration; defaults to None, i.e. the default configuration
+> * **model_format**(ModelFormat): Model format, ONNX by default
+
+### predict function
+
+> ```python
+> ResNet.predict(input_image, topk=1)
+> ```
+>
+> Model prediction interface: takes an input image and directly outputs the classification result.
+>
+> **Parameters**
+>
+> > * **input_image**(np.ndarray): Input data, which must be in HWC layout and BGR format
+> > * **topk**(int): Return the topk classification results with the highest predicted probability, 1 by default
+
+> **Returns**
+>
+> > Returns a `fastdeploy.vision.ClassifyResult` struct; for a description of the struct, see [Vision model prediction results](../../../../../docs/api/vision_results/)
+
+
+## Other documents
+
+- [ResNet model description](..)
+- [ResNet C++ deployment](../cpp)
+- [Model prediction results](../../../../../docs/api/vision_results/)
+- [How to switch the model inference backend](../../../../../docs/runtime/how_to_change_backend.md)
diff --git a/examples/vision/classification/resnet/python/infer.py b/examples/vision/classification/resnet/python/infer.py
new file mode 100644
index 0000000000..b8b268f3ab
--- /dev/null
+++ b/examples/vision/classification/resnet/python/infer.py
@@ -0,0 +1,50 @@
+import fastdeploy as fd
+import cv2
+import os
+
+
+def parse_arguments():
+    import argparse
+    import ast
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model", required=True, help="Path of ResNet model file.")
+    parser.add_argument(
+        "--image", type=str, required=True, help="Path of test image file.")
+    parser.add_argument(
+        "--topk", type=int, default=1, help="Return topk results.")
+    parser.add_argument(
+        "--device",
+        type=str,
+        default='cpu',
+        help="Type of inference device, support 'cpu' or 'gpu'.")
+    parser.add_argument(
+        "--use_trt",
+        type=ast.literal_eval,
+        default=False,
+        help="Whether to use tensorrt.")
+    return parser.parse_args()
+
+
+def build_option(args):
+    option = fd.RuntimeOption()
+
+    if args.device.lower() == "gpu":
+        option.use_gpu()
+
+    if args.use_trt:
+        option.use_trt_backend()
+    return option
+
+
+args = parse_arguments()
+
+# Configure the runtime and load the model
+runtime_option = build_option(args)
+
+model = fd.vision.classification.ResNet(
+    args.model, runtime_option=runtime_option)
+# Predict the image classification result
+im = cv2.imread(args.image)
+result = model.predict(im.copy(), args.topk)
+print(result)
diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
index 51fc393e5b..b83fb0f3da 100755
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -17,6 +17,7 @@
 #ifdef ENABLE_VISION
 #include "fastdeploy/vision/classification/contrib/yolov5cls.h"
 #include "fastdeploy/vision/classification/ppcls/model.h"
+#include "fastdeploy/vision/classification/contrib/resnet.h"
 #include "fastdeploy/vision/detection/contrib/nanodet_plus.h"
 #include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
 #include "fastdeploy/vision/detection/contrib/yolor.h"
diff --git a/fastdeploy/vision/classification/classification_pybind.cc b/fastdeploy/vision/classification/classification_pybind.cc
index 497d692c3c..cae130de5e 100644
--- a/fastdeploy/vision/classification/classification_pybind.cc
+++ b/fastdeploy/vision/classification/classification_pybind.cc
@@ -18,11 +18,12 @@ namespace fastdeploy {
 
 void BindYOLOv5Cls(pybind11::module& m);
 void BindPaddleClas(pybind11::module& m);
-
+void BindResNet(pybind11::module& m);
 void BindClassification(pybind11::module& m) {
   auto classification_module =
       m.def_submodule("classification", "Image classification models.");
   BindYOLOv5Cls(classification_module);
   BindPaddleClas(classification_module);
+  BindResNet(classification_module);
 }
 }  // namespace fastdeploy
diff --git a/fastdeploy/vision/classification/contrib/resnet.cc b/fastdeploy/vision/classification/contrib/resnet.cc
new file mode 100644
index 0000000000..8050ceccbf
--- /dev/null
+++ b/fastdeploy/vision/classification/contrib/resnet.cc
@@ -0,0 +1,134 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/classification/contrib/resnet.h"
+#include "fastdeploy/vision/utils/utils.h"
+#include "fastdeploy/utils/perf.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace classification {
+
+ResNet::ResNet(const std::string& model_file,
+               const std::string& params_file,
+               const RuntimeOption& custom_option,
+               const ModelFormat& model_format) {
+  // In constructor, the 3 steps below are necessary.
+  // 1. set the Backend 2. set RuntimeOption 3. call Initialize()
+
+  if (model_format == ModelFormat::ONNX) {
+    valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
+    valid_gpu_backends = {Backend::ORT, Backend::TRT};
+  } else {
+    valid_cpu_backends = {Backend::PDINFER};
+    valid_gpu_backends = {Backend::PDINFER};
+  }
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.model_file = model_file;
+  runtime_option.params_file = params_file;
+  initialized = Initialize();
+}
+
+bool ResNet::Initialize() {
+
+  // In this function, the 2 steps below are necessary.
+  // 1. assign values to the global variables 2. call InitRuntime()
+
+  size = {224, 224};
+  mean_vals = {0.485f, 0.456f, 0.406f};
+  std_vals = {0.229f, 0.224f, 0.225f};
+
+  if (!InitRuntime()) {
+    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+    return false;
+  }
+  return true;
+}
+
+
+bool ResNet::Preprocess(Mat* mat, FDTensor* output) {
+
+// In this function, the preprocessing needs to be implemented according to the original repository.
+// The result of preprocessing has to be saved in an FDTensor variable, because the input of Infer() needs to be std::vector<FDTensor>.
+// 1. Resize 2. BGR2RGB 3. Normalize 4. HWC2CHW 5. Put the result into FDTensor variable.
+
+  if (mat->Height()!=size[0] || mat->Width()!=size[1]){
+    int interp = cv::INTER_LINEAR;
+    Resize::Run(mat, size[1], size[0], -1, -1, interp);
+  }
+
+  BGR2RGB::Run(mat);
+  Normalize::Run(mat, mean_vals, std_vals);
+
+  HWC2CHW::Run(mat);
+  Cast::Run(mat, "float");
+  mat->ShareWithTensor(output);
+  output->shape.insert(output->shape.begin(), 1);  // reshape to n, c, h, w
+  return true;
+}
+
+bool ResNet::Postprocess(FDTensor& infer_result,
+                         ClassifyResult* result, int topk) {
+
+  // In this function, the postprocessing needs to be implemented according to the original repository.
+  // Finally the result of postprocessing should be saved in a ClassifyResult variable.
+  // 1. Softmax 2. Choose topk labels 3. Put the result into ClassifyResult variable.
+
+  int num_classes = infer_result.shape[1];
+  Softmax(infer_result, &infer_result);
+  const float* infer_result_buffer = reinterpret_cast<const float*>(infer_result.Data());
+  topk = std::min(num_classes, topk);
+  result->label_ids =
+      utils::TopKIndices(infer_result_buffer, num_classes, topk);
+  result->scores.resize(topk);
+  for (int i = 0; i < topk; ++i) {
+    result->scores[i] = *(infer_result_buffer + result->label_ids[i]);
+  }
+  return true;
+}
+
+bool ResNet::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
+
+  // In this function, the Preprocess(), Infer(), and Postprocess() are called sequentially.
+
+  Mat mat(*im);
+  std::vector<FDTensor> processed_data(1);
+  if (!Preprocess(&mat, &(processed_data[0]))) {
+    FDERROR << "Failed to preprocess input data while using model:"
+            << ModelName() << "." << std::endl;
+    return false;
+  }
+  processed_data[0].name = InputInfoOfRuntime(0).name;
+
+  std::vector<FDTensor> output_tensors;
+  if (!Infer(processed_data, &output_tensors)) {
+    FDERROR << "Failed to inference while using model:" << ModelName() << "."
+            << std::endl;
+    return false;
+  }
+
+  if (!Postprocess(output_tensors[0], result, topk)) {
+    FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
+            << std::endl;
+    return false;
+  }
+
+  return true;
+}
+
+
+}  // namespace classification
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/classification/contrib/resnet.h b/fastdeploy/vision/classification/contrib/resnet.h
new file mode 100644
index 0000000000..f766f1bf50
--- /dev/null
+++ b/fastdeploy/vision/classification/contrib/resnet.h
@@ -0,0 +1,74 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+// The namespace should be
+// fastdeploy::vision::classification (fastdeploy::vision::${task})
+namespace fastdeploy {
+namespace vision {
+/** \brief All image classification model APIs are defined inside this namespace
+ *
+ */
+namespace classification {
+/*! @brief ResNet series model
+ */
+class FASTDEPLOY_DECL ResNet : public FastDeployModel {
+ public:
+  /** \brief Set path of model file and the configuration of runtime.
+   *
+   * \param[in] model_file Path of model file, e.g ./resnet50.onnx
+   * \param[in] params_file Path of parameter file, e.g ppyoloe/model.pdiparams, if the model format is ONNX, this parameter will be ignored
+   * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends"
+   * \param[in] model_format Model format of the loaded model, default is ONNX format
+   */
+  ResNet(const std::string& model_file,
+         const std::string& params_file = "",
+         const RuntimeOption& custom_option = RuntimeOption(),
+         const ModelFormat& model_format = ModelFormat::ONNX);
+
+  virtual std::string ModelName() const { return "ResNet"; }
+  /** \brief Predict for the input "im", the result will be saved in "result".
+   *
+   * \param[in] im Input image for inference.
+   * \param[out] result Saving the inference result.
+   * \param[in] topk The length of return values, e.g., if topk==2, the result will include the 2 most probable class labels for the input image.
+   */
+  virtual bool Predict(cv::Mat* im, ClassifyResult* result, int topk = 1);
+
+  /// Tuple of (width, height)
+  std::vector<int> size;
+  /// Mean parameters for normalize
+  std::vector<float> mean_vals;
+  /// Std parameters for normalize
+  std::vector<float> std_vals;
+
+
+ private:
+  /*! @brief Initialize for ResNet model, assign values to the global variables and call InitRuntime()
+  */
+  bool Initialize();
+  /// PreProcessing for the input "mat", the result will be saved in "outputs".
+  bool Preprocess(Mat* mat, FDTensor* outputs);
+  /*! @brief PostProcessing for the input "infer_result", the result will be saved in "result".
+  */
+  bool Postprocess(FDTensor& infer_result, ClassifyResult* result,
+                   int topk = 1);
+};
+}  // namespace classification
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/classification/contrib/resnet_pybind.cc b/fastdeploy/vision/classification/contrib/resnet_pybind.cc
new file mode 100644
index 0000000000..7654d5bdca
--- /dev/null
+++ b/fastdeploy/vision/classification/contrib/resnet_pybind.cc
@@ -0,0 +1,40 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/pybind/main.h"
+// namespace should be `fastdeploy`
+namespace fastdeploy {
+// the name of Pybind function should be Bind${model_name}
+void BindResNet(pybind11::module& m) {
+  // the constructor and the predict function are necessary
+  // the constructor is used to initialize the python model class.
+  // the necessary public functions and variables like `size`, `mean_vals` should also be bound.
+  pybind11::class_<vision::classification::ResNet, FastDeployModel>(
+      m, "ResNet")
+      .def(pybind11::init<std::string, std::string, RuntimeOption, ModelFormat>())
+      .def("predict",
+           [](vision::classification::ResNet& self, pybind11::array& data,
+              int topk = 1) {
+             auto mat = PyArrayToCvMat(data);
+             vision::ClassifyResult res;
+             self.Predict(&mat, &res, topk);
+             return res;
+           })
+      .def_readwrite("size", &vision::classification::ResNet::size)
+      .def_readwrite("mean_vals", &vision::classification::ResNet::mean_vals)
+      .def_readwrite("std_vals", &vision::classification::ResNet::std_vals);
+}
+}  // namespace fastdeploy
+
diff --git a/python/fastdeploy/vision/classification/__init__.py b/python/fastdeploy/vision/classification/__init__.py
index ceeaa024a7..0b426fab13 100644
--- a/python/fastdeploy/vision/classification/__init__.py
+++ b/python/fastdeploy/vision/classification/__init__.py
@@ -15,7 +15,7 @@
 
 from .contrib.yolov5cls import YOLOv5Cls
 from .ppcls import PaddleClasModel
-
+from .contrib.resnet import ResNet
 PPLCNet = PaddleClasModel
 PPLCNetv2 = PaddleClasModel
 EfficientNet = PaddleClasModel
diff --git a/python/fastdeploy/vision/classification/contrib/resnet.py b/python/fastdeploy/vision/classification/contrib/resnet.py
new file mode 100644
index 0000000000..52f45933ba
--- /dev/null
+++ b/python/fastdeploy/vision/classification/contrib/resnet.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+from .... import FastDeployModel, ModelFormat
+from .... import c_lib_wrap as C
+
+
+class ResNet(FastDeployModel):
+    def __init__(self,
+                 model_file,
+                 params_file="",
+                 runtime_option=None,
+                 model_format=ModelFormat.ONNX):
+        """Load an image classification model exported by ResNet.
+
+        :param model_file: (str)Path of model file, e.g resnet/resnet50.onnx
+        :param params_file: (str)Path of parameters file, if the model_format is ModelFormat.ONNX, this param will be ignored, can be set as empty string
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model, default is ONNX
+        """
+
+        # call super() to initialize the backend_option
+        # the result of initialization will be saved in self._runtime_option
+        super(ResNet, self).__init__(runtime_option)
+
+        self._model = C.vision.classification.ResNet(
+            model_file, params_file, self._runtime_option, model_format)
+        # self.initialized shows the initialization of the model is successful or not
+
+        assert self.initialized, "ResNet initialize failed."
+
+    # Predict and return the inference result of "input_image".
+    def predict(self, input_image, topk=1):
+        """Classify an input image
+
+        :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
+        :param topk: (int)The topk result by the classify confidence score, default 1
+        :return: ClassifyResult
+        """
+        return self._model.predict(input_image, topk)
+
+    # Implement the setter and getter method for variables
+    @property
+    def size(self):
+        """
+        Returns the preprocess image size
+        """
+        return self._model.size
+
+    @property
+    def mean_vals(self):
+        """
+        Returns the mean value of normalization
+        """
+        return self._model.mean_vals
+
+    @property
+    def std_vals(self):
+        """
+        Returns the std value of normalization
+        """
+        return self._model.std_vals
+
+    @size.setter
+    def size(self, wh):
+        assert isinstance(wh, (list, tuple)),\
+            "The value to set `size` must be type of tuple or list."
+        assert len(wh) == 2,\
+            "The value to set `size` must contain 2 elements, [width, height], but now it contains {} elements.".format(
+            len(wh))
+        self._model.size = wh
+
+    @mean_vals.setter
+    def mean_vals(self, value):
+        assert isinstance(
+            value, list), "The value to set `mean_vals` must be type of list."
+        self._model.mean_vals = value
+
+    @std_vals.setter
+    def std_vals(self, value):
+        assert isinstance(
+            value, list), "The value to set `std_vals` must be type of list."
+        self._model.std_vals = value
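
The property and setter layout in `resnet.py` can be tried without a compiled FastDeploy build. The following is a minimal sketch, not part of this PR: `_StubBackend` and `ResNetSketch` are hypothetical names that stand in for the pybind object `C.vision.classification.ResNet` and the Python wrapper, and the default values are placeholders chosen for illustration only.

```python
class _StubBackend:
    """Hypothetical stand-in for the pybind ResNet object (not a FastDeploy API)."""

    def __init__(self):
        # Placeholder values for illustration; the real defaults live in resnet.h.
        self.size = [224, 224]
        self.mean_vals = [0.485, 0.456, 0.406]
        self.std_vals = [0.229, 0.224, 0.225]


class ResNetSketch:
    """Mirrors the getter/setter pattern of the ResNet wrapper in the diff."""

    def __init__(self, backend=None):
        self._model = backend if backend is not None else _StubBackend()

    @property
    def size(self):
        return self._model.size

    @size.setter
    def size(self, wh):
        # Same checks as the wrapper: a list or tuple with exactly two elements.
        assert isinstance(wh, (list, tuple)), "`size` must be a tuple or list."
        assert len(wh) == 2, "`size` must contain [width, height]."
        self._model.size = wh

    @property
    def mean_vals(self):
        return self._model.mean_vals

    @mean_vals.setter
    def mean_vals(self, value):
        # The wrapper accepts only a list here, so a tuple is rejected.
        assert isinstance(value, list), "`mean_vals` must be a list."
        self._model.mean_vals = value


def rejects(setter_call):
    """Return True if the call fails one of the assert-based checks."""
    try:
        setter_call()
    except AssertionError:
        return True
    return False


model = ResNetSketch()
model.size = (256, 256)  # forwarded to the backend object
size_ok = model.size == (256, 256)
bad_size_rejected = rejects(lambda: setattr(model, "size", [256]))
tuple_mean_rejected = rejects(
    lambda: setattr(model, "mean_vals", (0.5, 0.5, 0.5)))
```

Because the checks are written with `assert`, they disappear when Python runs with `-O`; raising `ValueError` or `TypeError` would be one alternative if the validation has to hold in optimized runs.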