diff --git a/.new_docs/cn/faq/add_new_model.md b/.new_docs/cn/faq/add_new_model.md
new file mode 100644
index 0000000000..ab5afb9b2f
--- /dev/null
+++ b/.new_docs/cn/faq/add_new_model.md
@@ -0,0 +1,268 @@
+# FastDeploy外部模型集成指引
+
+在FastDeploy里面新增一个模型,包括增加C++/Python的部署支持。 本文以torchvision v0.12.0中的ResNet50模型为例,介绍使用FastDeploy做外部[模型集成](#modelsupport),具体包括如下3步。
+
+| 步骤 | 说明 | 创建或修改的文件 |
+|:------:|:-------------------------------------:|:---------------------------------------------:|
+| [1](#step2) | 在fastdeploy/vision相应任务模块增加模型实现 | resnet.h、resnet.cc、vision.h |
+| [2](#step4) | 通过pybind完成Python接口绑定 | resnet_pybind.cc、classification_pybind.cc |
+| [3](#step5) | 实现Python相应调用接口 | resnet.py、\_\_init\_\_.py |
+
+在完成上述3步之后,一个外部模型就集成好了。
+
+如果您想为FastDeploy贡献代码,还需要为新增模型添加测试代码、说明文档和代码注释,可在[测试](#test)中查看。
+## 模型集成
+
+### 模型准备
+
+
+在集成外部模型之前,先要将训练好的模型(.pt,.pdparams 等)转换成FastDeploy支持部署的模型格式(.onnx,.pdmodel)。多数开源仓库会提供模型转换脚本,可以直接利用脚本做模型的转换。由于torchvision没有提供转换脚本,因此手动编写转换脚本,本文中将 `torchvison.models.resnet50` 转换为 `resnet50.onnx`, 参考代码如下:
+
+```python
+import torch
+import torchvision.models as models
+model = models.resnet50(pretrained=True)
+batch_size = 1 #批处理大小
+input_shape = (3, 224, 224) #输入数据,改成自己的输入shape
+model.eval()
+x = torch.randn(batch_size, *input_shape) # 生成张量
+export_onnx_file = "resnet50.onnx" # 目的ONNX文件名
+torch.onnx.export(model,
+ x,
+ export_onnx_file,
+ opset_version=12,
+ input_names=["input"], # 输入名
+ output_names=["output"], # 输出名
+ dynamic_axes={"input":{0:"batch_size"}, # 批处理变量
+ "output":{0:"batch_size"}})
+```
+执行上述脚本将会得到 `resnet50.onnx` 文件。
+
+### C++部分
+* 创建`resnet.h`文件
+ * 创建位置
+ * FastDeploy/fastdeploy/vision/classification/contrib/resnet.h (FastDeploy/C++代码存放位置/视觉模型/任务名称/外部模型/模型名.h)
+ * 创建内容
+ * 首先在resnet.h中创建 ResNet类并继承FastDeployModel父类,之后声明`Predict`、`Initialize`、`Preprocess`、`Postprocess`和`构造函数`,以及必要的变量,具体的代码细节请参考[resnet.h](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-69128489e918f305c208476ba793d8167e77de2aa7cadf5dcbac30da448bd28e)。
+
+```C++
+class FASTDEPLOY_DECL ResNet : public FastDeployModel {
+ public:
+ ResNet(...);
+ virtual bool Predict(...);
+ private:
+ bool Initialize();
+ bool Preprocess(...);
+ bool Postprocess(...);
+};
+```
+
+* 创建`resnet.cc`文件
+ * 创建位置
+ * FastDeploy/fastdeploy/vision/classification/contrib/resnet.cc (FastDeploy/C++代码存放位置/视觉模型/任务名称/外部模型/模型名.cc)
+ * 创建内容
+ * 在`resnet.cc`中实现`resnet.h`中声明函数的具体逻辑,其中`PreProcess` 和 `PostProcess`需要参考源官方库的前后处理逻辑复现,ResNet每个函数具体逻辑如下,具体的代码请参考[resnet.cc](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-d229d702de28345253a53f2a5839fd2c638f3d32fffa6a7d04d23db9da13a871)。
+
+```C++
+ResNet::ResNet(...) {
+ // 构造函数逻辑
+ // 1. 指定 Backend 2. 设置RuntimeOption 3. 调用Initialize()函数
+}
+bool ResNet::Initialize() {
+ // 初始化逻辑
+ // 1. 全局变量赋值 2. 调用InitRuntime()函数
+ return true;
+}
+bool ResNet::Preprocess(Mat* mat, FDTensor* output) {
+// 前处理逻辑
+// 1. Resize 2. BGR2RGB 3. Normalize 4. HWC2CHW 5. 处理结果存入 FDTensor类中
+ return true;
+}
+bool ResNet::Postprocess(FDTensor& infer_result, ClassifyResult* result, int topk) {
+ //后处理逻辑
+ // 1. Softmax 2. Choose topk labels 3. 结果存入 ClassifyResult类
+ return true;
+}
+bool ResNet::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
+ Preprocess(...)
+ Infer(...)
+ Postprocess(...)
+ return true;
+}
+```
+
+* 在`vision.h`文件中加入新增模型文件
+ * 修改位置
+ * FastDeploy/fastdeploy/vision.h
+ * 修改内容
+
+```C++
+#ifdef ENABLE_VISION
+#include "fastdeploy/vision/classification/contrib/resnet.h"
+#endif
+```
+
+
+### Pybind部分
+
+* 创建Pybind文件
+ * 创建位置
+ * FastDeploy/fastdeploy/vision/classification/contrib/resnet_pybind.cc (FastDeploy/C++代码存放位置/视觉模型/任务名称/外部模型/模型名_pybind.cc)
+ * 创建内容
+ * 利用Pybind将C++中的函数变量绑定到Python中,具体代码请参考[resnet_pybind.cc](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-270af0d65720310e2cfbd5373c391b2110d65c0f4efa547f7b7eeffcb958bdec)。
+```C++
+void BindResNet(pybind11::module& m) {
+ pybind11::class_(
+ m, "ResNet")
+ .def(pybind11::init())
+ .def("predict", ...)
+ .def_readwrite("size", &vision::classification::ResNet::size)
+ .def_readwrite("mean_vals", &vision::classification::ResNet::mean_vals)
+ .def_readwrite("std_vals", &vision::classification::ResNet::std_vals);
+}
+```
+
+* 调用Pybind函数
+ * 修改位置
+ * FastDeploy/fastdeploy/vision/classification/classification_pybind.cc (FastDeploy/C++代码存放位置/视觉模型/任务名称/任务名称}_pybind.cc)
+ * 修改内容
+```C++
+void BindResNet(pybind11::module& m);
+void BindClassification(pybind11::module& m) {
+ auto classification_module =
+ m.def_submodule("classification", "Image classification models.");
+ BindResNet(classification_module);
+}
+```
+
+
+### Python部分
+
+
+* 创建`resnet.py`文件
+ * 创建位置
+ * FastDeploy/python/fastdeploy/vision/classification/contrib/resnet.py (FastDeploy/Python代码存放位置/fastdeploy/视觉模型/任务名称/外部模型/模型名.py)
+ * 创建内容
+ * 创建ResNet类继承自FastDeployModel,实现 `\_\_init\_\_`、Pybind绑定的函数(如`predict()`)、以及`对Pybind绑定的全局变量进行赋值和获取的函数`,具体代码请参考[resnet.py](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-a4dc5ec2d450e91f1c03819bf314c238b37ac678df56d7dea3aab7feac10a157)。
+
+```python
+
+class ResNet(FastDeployModel):
+ def __init__(self, ...):
+ self._model = C.vision.classification.ResNet(...)
+ def predict(self, input_image, topk=1):
+ return self._model.predict(input_image, topk)
+ @property
+ def size(self):
+ return self._model.size
+ @size.setter
+ def size(self, wh):
+ ...
+```
+
+* 导入ResNet类
+ * 修改位置
+ * FastDeploy/python/fastdeploy/vision/classification/\_\_init\_\_.py (FastDeploy/Python代码存放位置/fastdeploy/视觉模型/任务名称/\_\_init\_\_.py)
+ * 修改内容
+
+```Python
+from .contrib.resnet import ResNet
+```
+
+## 测试
+### 编译
+ * C++
+ * 位置:FastDeploy/
+
+```
+mkdir build & cd build
+cmake .. -DENABLE_ORT_BACKEND=ON -DENABLE_VISION=ON -DCMAKE_INSTALL_PREFIX=${PWD/fastdeploy-0.0.3
+-DENABLE_PADDLE_BACKEND=ON -DENABLE_TRT_BACKEND=ON -DWITH_GPU=ON -DTRT_DIRECTORY=/PATH/TO/TensorRT/
+make -j8
+make install
+```
+
+ 编译会得到 build/fastdeploy-0.0.3/。
+
+ * Python
+ * 位置:FastDeploy/python/
+
+```
+export TRT_DIRECTORY=/PATH/TO/TensorRT/ # 如果用TensorRT 需要填写TensorRT所在位置,并开启 ENABLE_TRT_BACKEND
+export ENABLE_TRT_BACKEND=ON
+export WITH_GPU=ON
+export ENABLE_PADDLE_BACKEND=ON
+export ENABLE_OPENVINO_BACKEND=ON
+export ENABLE_VISION=ON
+export ENABLE_ORT_BACKEND=ON
+python setup.py build
+python setup.py bdist_wheel
+cd dist
+pip install fastdeploy_gpu_python-版本号-cpxx-cpxxm-系统架构.whl
+```
+
+### 编写测试代码
+ * 创建位置: FastDeploy/examples/vision/classification/resnet/ (FastDeploy/示例目录/视觉模型/任务名称/模型名/)
+ * 创建目录结构
+
+```
+.
+├── cpp
+│ ├── CMakeLists.txt
+│ ├── infer.cc // C++ 版本测试代码
+│ └── README.md // C++版本使用文档
+├── python
+│ ├── infer.py // Python 版本测试代码
+│ └── README.md // Python版本使用文档
+└── README.md // ResNet 模型集成说明文档
+```
+
+* C++
+ * 编写CmakeLists文件、C++ 代码以及 README.md 内容请参考[cpp/](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-afcbe607b796509581f89e38b84190717f1eeda2df0419a2ac9034197ead5f96)。
+ * 编译 infer.cc
+ * 位置:FastDeploy/examples/vision/classification/resnet/cpp/
+
+```
+mkdir build & cd build
+cmake .. -DFASTDEPLOY_INSTALL_DIR=/PATH/TO/FastDeploy/build/fastdeploy-0.0.3/
+make
+```
+
+* Python
+ * Python 代码以及 README.md 内容请参考[python/](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-5a0d6be8c603a8b81454ac14c17fb93555288d9adf92bbe40454449309700135)。
+
+### 为代码添加注释
+为了方便用户理解代码,我们需要为新增代码添加注释,添加注释方法可参考如下示例。
+- C++ 代码
+您需要在resnet.h文件中为函数和变量增加注释,有如下三种注释方式,具体可参考[resnet.h](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-69128489e918f305c208476ba793d8167e77de2aa7cadf5dcbac30da448bd28e)。
+
+```C++
+/** \brief Predict for the input "im", the result will be saved in "result".
+*
+* \param[in] im Input image for inference.
+* \param[in] result Saving the inference result.
+* \param[in] topk The length of return values, e.g., if topk==2, the result will include the 2 most possible class label for input image.
+*/
+virtual bool Predict(cv::Mat* im, ClassifyResult* result, int topk = 1);
+
+/// Tuple of (width, height)
+std::vector size;
+/*! @brief Initialize for ResNet model, assign values to the global variables and call InitRuntime()
+*/
+bool Initialize();
+```
+- Python 代码
+你需要为resnet.py文件中的函数和变量增加适当的注释,示例如下,具体可参考[resnet.py](https://github.com/PaddlePaddle/FastDeploy/pull/347/files#diff-a4dc5ec2d450e91f1c03819bf314c238b37ac678df56d7dea3aab7feac10a157)。
+
+```python
+ def predict(self, input_image, topk=1):
+ """Classify an input image
+
+ :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
+ :param topk: (int)The topk result by the classify confidence score, default 1
+ :return: ClassifyResult
+ """
+ return self._model.predict(input_image, topk)
+```
+
+对于集成模型过程中的其他文件,您也可以对实现的细节添加适当的注释说明。
diff --git a/docs/api_docs/python/image_classification.md b/docs/api_docs/python/image_classification.md
index 46760eec3a..ad284ca66d 100755
--- a/docs/api_docs/python/image_classification.md
+++ b/docs/api_docs/python/image_classification.md
@@ -15,3 +15,11 @@
:members:
:inherited-members:
```
+
+## fastdeploy.vision.classification.ResNet
+
+```{eval-rst}
+.. autoclass:: fastdeploy.vision.classification.ResNet
+ :members:
+ :inherited-members:
+```
diff --git a/examples/vision/classification/resnet/README.md b/examples/vision/classification/resnet/README.md
new file mode 100644
index 0000000000..dd20562fb5
--- /dev/null
+++ b/examples/vision/classification/resnet/README.md
@@ -0,0 +1,53 @@
+# ResNet准备部署模型
+
+- ResNet部署实现来自[Torchvision](https://github.com/pytorch/vision/tree/v0.12.0)的代码,和[基于ImageNet2012的预训练模型](https://github.com/pytorch/vision/tree/v0.12.0)。
+
+ - (1)[官方库](https://github.com/pytorch/vision/tree/v0.12.0)提供的*.pt通过[导出ONNX模型](#导出ONNX模型)操作后,可进行部署;
+ - (2)自己数据训练的ResNet模型,按照[导出ONNX模型](#%E5%AF%BC%E5%87%BAONNX%E6%A8%A1%E5%9E%8B)操作后,参考[详细部署文档](#详细部署文档)完成部署。
+
+
+## 导出ONNX模型
+
+
+ 导入[Torchvision](https://github.com/pytorch/vision/tree/v0.12.0),加载预训练模型,并进行模型转换,具体转换步骤如下。
+
+ ```python
+ import torch
+ import torchvision.models as models
+
+ model = models.resnet50(pretrained=True)
+ batch_size = 1 #批处理大小
+ input_shape = (3, 224, 224) #输入数据,改成自己的输入shape
+ # #set the model to inference mode
+ model.eval()
+ x = torch.randn(batch_size, *input_shape) # 生成张量
+ export_onnx_file = "ResNet50.onnx" # 目的ONNX文件名
+ torch.onnx.export(model,
+ x,
+ export_onnx_file,
+ opset_version=12,
+ input_names=["input"], # 输入名
+ output_names=["output"], # 输出名
+ dynamic_axes={"input":{0:"batch_size"}, # 批处理变量
+ "output":{0:"batch_size"}})
+ ```
+
+## 下载预训练ONNX模型
+
+为了方便开发者的测试,下面提供了ResNet导出的各系列模型,开发者可直接下载使用。(下表中模型的精度来源于源官方库)
+| 模型 | 大小 | 精度 |
+|:---------------------------------------------------------------- |:----- |:----- |
+| [ResNet-18](https://bj.bcebos.com/paddlehub/fastdeploy/resnet18.onnx) | 45MB | |
+| [ResNet-34](https://bj.bcebos.com/paddlehub/fastdeploy/resnet34.onnx) | 84MB | |
+| [ResNet-50](https://bj.bcebos.com/paddlehub/fastdeploy/resnet50.onnx) | 98MB | |
+| [ResNet-101](https://bj.bcebos.com/paddlehub/fastdeploy/resnet101.onnx) | 170MB | |
+
+
+## 详细部署文档
+
+- [Python部署](python)
+- [C++部署](cpp)
+
+## 版本说明
+
+- 本版本文档和代码基于[Torchvision v0.12.0](https://github.com/pytorch/vision/tree/v0.12.0) 编写
diff --git a/examples/vision/classification/resnet/cpp/CMakeLists.txt b/examples/vision/classification/resnet/cpp/CMakeLists.txt
new file mode 100644
index 0000000000..93540a7e83
--- /dev/null
+++ b/examples/vision/classification/resnet/cpp/CMakeLists.txt
@@ -0,0 +1,14 @@
+PROJECT(infer_demo C CXX)
+CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
+
+# 指定下载解压后的fastdeploy库路径
+option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
+
+include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
+
+# 添加FastDeploy依赖头文件
+include_directories(${FASTDEPLOY_INCS})
+
+add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
+# 添加FastDeploy库依赖
+target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
diff --git a/examples/vision/classification/resnet/cpp/README.md b/examples/vision/classification/resnet/cpp/README.md
new file mode 100644
index 0000000000..eb3bff6f48
--- /dev/null
+++ b/examples/vision/classification/resnet/cpp/README.md
@@ -0,0 +1,77 @@
+# ResNet C++部署示例
+
+本目录下提供`infer.cc`快速完成ResNet系列模型在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。
+
+在部署前,需确认以下两个步骤
+
+- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/environment.md)
+- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/quick_start)
+
+以Linux上 ResNet50 推理为例,在本目录执行如下命令即可完成编译测试
+
+```bash
+#下载SDK,编译模型examples代码(SDK中包含了examples代码)
+wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.2.1.tgz
+tar xvf fastdeploy-linux-x64-gpu-0.2.1.tgz
+cd fastdeploy-linux-x64-gpu-0.2.1/examples/vision/classification/resnet/cpp
+mkdir build
+cd build
+cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../../../../../../../fastdeploy-linux-x64-gpu-0.2.1
+make -j
+
+# 下载ResNet模型文件和测试图片
+wget https://bj.bcebos.com/paddlehub/fastdeploy/resnet50.onnx
+wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
+
+
+# CPU推理
+./infer_demo resnet50.onnx ILSVRC2012_val_00000010.jpeg 0
+# GPU推理
+./infer_demo resnet50.onnx ILSVRC2012_val_00000010.jpeg 1
+# GPU上TensorRT推理
+./infer_demo resnet50.onnx ILSVRC2012_val_00000010.jpeg 2
+```
+
+以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
+- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/compile/how_to_use_sdk_on_windows.md)
+
+## ResNet C++接口
+
+### ResNet类
+
+```c++
+
+fastdeploy::vision::classification::ResNet(
+ const std::string& model_file,
+ const std::string& params_file = "",
+ const RuntimeOption& custom_option = RuntimeOption(),
+ const ModelFormat& model_format = ModelFormat::ONNX)
+```
+
+
+**参数**
+
+> * **model_file**(str): 模型文件路径
+> * **params_file**(str): 参数文件路径
+> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
+> * **model_format**(ModelFormat): 模型格式,默认为ONNX格式
+
+#### Predict函数
+
+> ```c++
+> ResNet::Predict(cv::Mat* im, ClassifyResult* result, int topk = 1)
+> ```
+>
+> 模型预测接口,输入图像直接输出检测结果。
+>
+> **参数**
+>
+> > * **im**: 输入图像,注意需为HWC,BGR格式
+> > * **result**: 分类结果,包括label_id,以及相应的置信度, ClassifyResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
+> > * **topk**(int):返回预测概率最高的topk个分类结果,默认为1
+
+
+- [模型介绍](../../)
+- [Python部署](../python)
+- [视觉模型预测结果](../../../../../docs/api/vision_results/)
+- [如何切换模型推理后端引擎](../../../../../docs/runtime/how_to_change_backend.md)
diff --git a/examples/vision/classification/resnet/cpp/infer.cc b/examples/vision/classification/resnet/cpp/infer.cc
new file mode 100644
index 0000000000..083c9de60a
--- /dev/null
+++ b/examples/vision/classification/resnet/cpp/infer.cc
@@ -0,0 +1,94 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+void CpuInfer(const std::string& model_file, const std::string& image_file) {
+ auto model = fastdeploy::vision::classification::ResNet(model_file);
+ if (!model.Initialized()) {
+ std::cerr << "Failed to initialize." << std::endl;
+ return;
+ }
+
+ auto im = cv::imread(image_file);
+
+ fastdeploy::vision::ClassifyResult res;
+ if (!model.Predict(&im, &res)) {
+ std::cerr << "Failed to predict." << std::endl;
+ return;
+ }
+ std::cout << res.Str() << std::endl;
+
+}
+
+void GpuInfer(const std::string& model_file, const std::string& image_file) {
+ auto option = fastdeploy::RuntimeOption();
+ option.UseGpu();
+ auto model = fastdeploy::vision::classification::ResNet(model_file, "", option);
+ if (!model.Initialized()) {
+ std::cerr << "Failed to initialize." << std::endl;
+ return;
+ }
+
+ auto im = cv::imread(image_file);
+
+ fastdeploy::vision::ClassifyResult res;
+ if (!model.Predict(&im, &res)) {
+ std::cerr << "Failed to predict." << std::endl;
+ return;
+ }
+ std::cout << res.Str() << std::endl;
+}
+
+void TrtInfer(const std::string& model_file, const std::string& image_file) {
+ auto option = fastdeploy::RuntimeOption();
+ option.UseGpu();
+ option.UseTrtBackend();
+ option.SetTrtInputShape("images", {1, 3, 224, 224});
+ auto model = fastdeploy::vision::classification::ResNet(model_file, "", option);
+ if (!model.Initialized()) {
+ std::cerr << "Failed to initialize." << std::endl;
+ return;
+ }
+
+ auto im = cv::imread(image_file);
+
+ fastdeploy::vision::ClassifyResult res;
+ if (!model.Predict(&im, &res)) {
+ std::cerr << "Failed to predict." << std::endl;
+ return;
+ }
+ std::cout << res.Str() << std::endl;
+}
+
+int main(int argc, char* argv[]) {
+ if (argc < 4) {
+ std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
+ "e.g ./infer_model ./resnet50.onnx ./test.jpeg 0"
+ << std::endl;
+ std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
+ "with gpu; 2: run with gpu and use tensorrt backend."
+ << std::endl;
+ return -1;
+ }
+
+ if (std::atoi(argv[3]) == 0) {
+ CpuInfer(argv[1], argv[2]);
+ } else if (std::atoi(argv[3]) == 1) {
+ GpuInfer(argv[1], argv[2]);
+ } else if (std::atoi(argv[3]) == 2) {
+ TrtInfer(argv[1], argv[2]);
+ }
+ return 0;
+}
diff --git a/examples/vision/classification/resnet/python/README.md b/examples/vision/classification/resnet/python/README.md
new file mode 100644
index 0000000000..6315ee06a6
--- /dev/null
+++ b/examples/vision/classification/resnet/python/README.md
@@ -0,0 +1,72 @@
+# ResNet模型 Python部署示例
+
+在部署前,需确认以下两个步骤
+
+- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/environment.md)
+- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/quick_start)
+
+本目录下提供`infer.py`快速完成ResNet50_vd在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
+
+```bash
+#下载部署示例代码
+git clone https://github.com/PaddlePaddle/FastDeploy.git
+cd FastDeploy/examples/vision/classification/resnet/python
+
+# 下载ResNet50_vd模型文件和测试图片
+wget https://bj.bcebos.com/paddlehub/fastdeploy/resnet50.onnx
+wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
+
+# CPU推理
+python infer.py --model resnet50.onnx --image ILSVRC2012_val_00000010.jpeg --device cpu --topk 1
+# GPU推理
+python infer.py --model resnet50.onnx --image ILSVRC2012_val_00000010.jpeg --device gpu --topk 1
+# GPU上使用TensorRT推理 (注意:TensorRT推理第一次运行,有序列化模型的操作,有一定耗时,需要耐心等待)
+python infer.py --model resnet50.onnx --image ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True --topk 1
+```
+
+运行完成后返回结果如下所示
+```bash
+ClassifyResult(
+label_ids: 332,
+scores: 0.825349,
+)
+```
+
+## ResNet Python接口
+
+```python
+fd.vision.classification.ResNet(model_file, params_file, runtime_option=None, model_format=ModelFormat.ONNX)
+```
+
+
+**参数**
+
+> * **model_file**(str): 模型文件路径
+> * **params_file**(str): 参数文件路径
+> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
+> * **model_format**(ModelFormat): 模型格式,默认为ONNX格式
+
+### predict函数
+
+> ```python
+> ResNet.predict(input_image, topk=1)
+> ```
+>
+> 模型预测结口,输入图像直接输出检测结果。
+>
+> **参数**
+>
+> > * **input_image**(np.ndarray): 输入数据,注意需为HWC,BGR格式
+> > * **topk**(int):返回预测概率最高的topk个分类结果,默认为1
+
+> **返回**
+>
+> > 返回`fastdeploy.vision.ClassifyResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
+
+
+## 其它文档
+
+- [ResNet 模型介绍](..)
+- [ResNet C++部署](../cpp)
+- [模型预测结果说明](../../../../../docs/api/vision_results/)
+- [如何切换模型推理后端引擎](../../../../../docs/runtime/how_to_change_backend.md)
diff --git a/examples/vision/classification/resnet/python/infer.py b/examples/vision/classification/resnet/python/infer.py
new file mode 100644
index 0000000000..b8b268f3ab
--- /dev/null
+++ b/examples/vision/classification/resnet/python/infer.py
@@ -0,0 +1,50 @@
+import fastdeploy as fd
+import cv2
+import os
+
+
+def parse_arguments():
+ import argparse
+ import ast
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model", required=True, help="Path of PaddleClas model.")
+ parser.add_argument(
+ "--image", type=str, required=True, help="Path of test image file.")
+ parser.add_argument(
+ "--topk", type=int, default=1, help="Return topk results.")
+ parser.add_argument(
+ "--device",
+ type=str,
+ default='cpu',
+ help="Type of inference device, support 'cpu' or 'gpu'.")
+ parser.add_argument(
+ "--use_trt",
+ type=ast.literal_eval,
+ default=False,
+ help="Wether to use tensorrt.")
+ return parser.parse_args()
+
+
+def build_option(args):
+ option = fd.RuntimeOption()
+
+ if args.device.lower() == "gpu":
+ option.use_gpu()
+
+ if args.use_trt:
+ option.use_trt_backend()
+ return option
+
+
+args = parse_arguments()
+
+# 配置runtime,加载模型
+runtime_option = build_option(args)
+
+model = fd.vision.classification.ResNet(
+ args.model, runtime_option=runtime_option)
+# 预测图片分类结果
+im = cv2.imread(args.image)
+result = model.predict(im.copy(), args.topk)
+print(result)
diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
index 51fc393e5b..b83fb0f3da 100755
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -17,6 +17,7 @@
#ifdef ENABLE_VISION
#include "fastdeploy/vision/classification/contrib/yolov5cls.h"
#include "fastdeploy/vision/classification/ppcls/model.h"
+#include "fastdeploy/vision/classification/contrib/resnet.h"
#include "fastdeploy/vision/detection/contrib/nanodet_plus.h"
#include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
#include "fastdeploy/vision/detection/contrib/yolor.h"
diff --git a/fastdeploy/vision/classification/classification_pybind.cc b/fastdeploy/vision/classification/classification_pybind.cc
index 497d692c3c..cae130de5e 100644
--- a/fastdeploy/vision/classification/classification_pybind.cc
+++ b/fastdeploy/vision/classification/classification_pybind.cc
@@ -18,11 +18,12 @@ namespace fastdeploy {
void BindYOLOv5Cls(pybind11::module& m);
void BindPaddleClas(pybind11::module& m);
-
+void BindResNet(pybind11::module& m);
void BindClassification(pybind11::module& m) {
auto classification_module =
m.def_submodule("classification", "Image classification models.");
BindYOLOv5Cls(classification_module);
BindPaddleClas(classification_module);
+ BindResNet(classification_module);
}
} // namespace fastdeploy
diff --git a/fastdeploy/vision/classification/contrib/resnet.cc b/fastdeploy/vision/classification/contrib/resnet.cc
new file mode 100644
index 0000000000..8050ceccbf
--- /dev/null
+++ b/fastdeploy/vision/classification/contrib/resnet.cc
@@ -0,0 +1,134 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/classification/contrib/resnet.h"
+#include "fastdeploy/vision/utils/utils.h"
+#include "fastdeploy/utils/perf.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace classification {
+
+ResNet::ResNet(const std::string& model_file,
+ const std::string& params_file,
+ const RuntimeOption& custom_option,
+ const ModelFormat& model_format) {
+ // In constructor, the 3 steps below are necessary.
+ // 1. set the Backend 2. set RuntimeOption 3. call Initialize()
+
+ if (model_format == ModelFormat::ONNX) {
+ valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
+ valid_gpu_backends = {Backend::ORT, Backend::TRT};
+ } else {
+ valid_cpu_backends = {Backend::PDINFER};
+ valid_gpu_backends = {Backend::PDINFER};
+ }
+ runtime_option = custom_option;
+ runtime_option.model_format = model_format;
+ runtime_option.model_file = model_file;
+ runtime_option.params_file = params_file;
+ initialized = Initialize();
+}
+
+bool ResNet::Initialize() {
+
+ // In this function, the 3 steps below are necessary.
+ // 1. assign values to the global variables 2. call InitRuntime()
+
+ size = {224, 224};
+ mean_vals = {0.485f, 0.456f, 0.406f};
+ std_vals = {0.229f, 0.224f, 0.225f};
+
+ if (!InitRuntime()) {
+ FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+ return false;
+ }
+ return true;
+}
+
+
+bool ResNet::Preprocess(Mat* mat, FDTensor* output) {
+
+// In this function, the preprocess need be implemented according to the original Repos,
+// The result of preprocess has to be saved in FDTensor variable, because the input of Infer() need to be std::vector.
+// 1. Resize 2. BGR2RGB 3. Normalize 4. HWC2CHW 5. Put the result into FDTensor variable.
+
+ if (mat->Height()!=size[0] || mat->Width()!=size[1]){
+ int interp = cv::INTER_LINEAR;
+ Resize::Run(mat, size[1], size[0], -1, -1, interp);
+ }
+
+ BGR2RGB::Run(mat);
+ Normalize::Run(mat, mean_vals, std_vals);
+
+ HWC2CHW::Run(mat);
+ Cast::Run(mat, "float");
+ mat->ShareWithTensor(output);
+ output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c
+ return true;
+}
+
+bool ResNet::Postprocess(FDTensor& infer_result,
+ ClassifyResult* result, int topk) {
+
+ // In this function, the postprocess need be implemented according to the original Repos,
+ // Finally the reslut of postprocess should be saved in ClassifyResult variable.
+ // 1. Softmax 2. Choose topk labels 3. Put the result into ClassifyResult variable.
+
+ int num_classes = infer_result.shape[1];
+ Softmax(infer_result, &infer_result);
+ const float* infer_result_buffer = reinterpret_cast(infer_result.Data());
+ topk = std::min(num_classes, topk);
+ result->label_ids =
+ utils::TopKIndices(infer_result_buffer, num_classes, topk);
+ result->scores.resize(topk);
+ for (int i = 0; i < topk; ++i) {
+ result->scores[i] = *(infer_result_buffer + result->label_ids[i]);
+ }
+ return true;
+}
+
+bool ResNet::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
+
+ // In this function, the Preprocess(), Infer(), and Postprocess() are called sequentially.
+
+ Mat mat(*im);
+ std::vector processed_data(1);
+ if (!Preprocess(&mat, &(processed_data[0]))) {
+ FDERROR << "Failed to preprocess input data while using model:"
+ << ModelName() << "." << std::endl;
+ return false;
+ }
+ processed_data[0].name = InputInfoOfRuntime(0).name;
+
+ std::vector output_tensors;
+ if (!Infer(processed_data, &output_tensors)) {
+ FDERROR << "Failed to inference while using model:" << ModelName() << "."
+ << std::endl;
+ return false;
+ }
+
+ if (!Postprocess(output_tensors[0], result, topk)) {
+ FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
+ << std::endl;
+ return false;
+ }
+
+ return true;
+}
+
+
+} // namespace classification
+} // namespace vision
+} // namespace fastdeploy
diff --git a/fastdeploy/vision/classification/contrib/resnet.h b/fastdeploy/vision/classification/contrib/resnet.h
new file mode 100644
index 0000000000..f766f1bf50
--- /dev/null
+++ b/fastdeploy/vision/classification/contrib/resnet.h
@@ -0,0 +1,74 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+// The namespace shoulde be
+// fastdeploy::vision::classification (fastdeploy::vision::${task})
+namespace fastdeploy {
+namespace vision {
+/** \brief All object classification model APIs are defined inside this namespace
+ *
+ */
+namespace classification {
+/*! @brief ResNet series model
+ */
+class FASTDEPLOY_DECL ResNet : public FastDeployModel {
+ public:
+ /** \brief Set path of model file and the configuration of runtime.
+ *
+ * \param[in] model_file Path of model file, e.g ./resnet50.onnx
+ * \param[in] params_file Path of parameter file, e.g ppyoloe/model.pdiparams, if the model format is ONNX, this parameter will be ignored
+ * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends"
+ * \param[in] model_format Model format of the loaded model, default is ONNX format
+ */
+ ResNet(const std::string& model_file,
+ const std::string& params_file = "",
+ const RuntimeOption& custom_option = RuntimeOption(),
+ const ModelFormat& model_format = ModelFormat::ONNX);
+
+ virtual std::string ModelName() const { return "ResNet"; }
+ /** \brief Predict for the input "im", the result will be saved in "result".
+ *
+ * \param[in] im Input image for inference.
+ * \param[in] result Saving the inference result.
+ * \param[in] topk The length of return values, e.g., if topk==2, the result will include the 2 most possible class label for input image.
+ */
+ virtual bool Predict(cv::Mat* im, ClassifyResult* result, int topk = 1);
+
+ /// Tuple of (width, height)
+ std::vector size;
+ /// Mean parameters for normalize
+ std::vector mean_vals;
+ /// Std parameters for normalize
+ std::vector std_vals;
+
+
+ private:
+ /*! @brief Initialize for ResNet model, assign values to the global variables and call InitRuntime()
+ */
+ bool Initialize();
+ /// PreProcessing for the input "mat", the result will be saved in "outputs".
+ bool Preprocess(Mat* mat, FDTensor* outputs);
+ /*! @brief PostProcessing for the input "infer_result", the result will be saved in "result".
+ */
+ bool Postprocess(FDTensor& infer_result, ClassifyResult* result,
+ int topk = 1);
+};
+} // namespace classification
+} // namespace vision
+} // namespace fastdeploy
diff --git a/fastdeploy/vision/classification/contrib/resnet_pybind.cc b/fastdeploy/vision/classification/contrib/resnet_pybind.cc
new file mode 100644
index 0000000000..7654d5bdca
--- /dev/null
+++ b/fastdeploy/vision/classification/contrib/resnet_pybind.cc
@@ -0,0 +1,40 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/pybind/main.h"
+// namespace should be `fastdeploy`
+namespace fastdeploy {
+// the name of Pybind function should be Bind${model_name}
+void BindResNet(pybind11::module& m) {
+ // the constructor and the predict funtion are necessary
+ // the constructor is used to initialize the python model class.
+ // the necessary public functions and variables like `size`, `mean_vals` should also be binded.
+ pybind11::class_(
+ m, "ResNet")
+ .def(pybind11::init())
+ .def("predict",
+ [](vision::classification::ResNet& self, pybind11::array& data,
+ int topk = 1) {
+ auto mat = PyArrayToCvMat(data);
+ vision::ClassifyResult res;
+ self.Predict(&mat, &res, topk);
+ return res;
+ })
+ .def_readwrite("size", &vision::classification::ResNet::size)
+ .def_readwrite("mean_vals", &vision::classification::ResNet::mean_vals)
+ .def_readwrite("std_vals", &vision::classification::ResNet::std_vals);
+}
+} // namespace fastdeploy
+
diff --git a/python/fastdeploy/vision/classification/__init__.py b/python/fastdeploy/vision/classification/__init__.py
index ceeaa024a7..0b426fab13 100644
--- a/python/fastdeploy/vision/classification/__init__.py
+++ b/python/fastdeploy/vision/classification/__init__.py
@@ -15,7 +15,7 @@
from .contrib.yolov5cls import YOLOv5Cls
from .ppcls import PaddleClasModel
-
+from .contrib.resnet import ResNet
PPLCNet = PaddleClasModel
PPLCNetv2 = PaddleClasModel
EfficientNet = PaddleClasModel
diff --git a/python/fastdeploy/vision/classification/contrib/resnet.py b/python/fastdeploy/vision/classification/contrib/resnet.py
new file mode 100644
index 0000000000..52f45933ba
--- /dev/null
+++ b/python/fastdeploy/vision/classification/contrib/resnet.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+from .... import FastDeployModel, ModelFormat
+from .... import c_lib_wrap as C
+
+
+class ResNet(FastDeployModel):
+ def __init__(self,
+ model_file,
+ params_file="",
+ runtime_option=None,
+ model_format=ModelFormat.ONNX):
+ """Load a image classification model exported by ResNet.
+
+ :param model_file: (str)Path of model file, e.g resnet/resnet50.onnx
+ :param params_file: (str)Path of parameters file, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
+ :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
+ :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model, default is ONNX
+ """
+
+ # call super() to initialize the backend_option
+ # the result of initialization will be saved in self._runtime_option
+ super(ResNet, self).__init__(runtime_option)
+
+ self._model = C.vision.classification.ResNet(
+ model_file, params_file, self._runtime_option, model_format)
+ # self.initialized shows the initialization of the model is successful or not
+
+ assert self.initialized, "ResNet initialize failed."
+
+ # Predict and return the inference result of "input_image".
+ def predict(self, input_image, topk=1):
+ """Classify an input image
+
+ :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
+ :param topk: (int)The topk result by the classify confidence score, default 1
+ :return: ClassifyResult
+ """
+ return self._model.predict(input_image, topk)
+
+ # Implement the setter and getter method for variables
+ @property
+ def size(self):
+ """
+ Returns the preprocess image size
+ """
+ return self._model.size
+
+ @property
+ def mean_vals(self):
+ """
+ Returns the mean value of normlization
+ """
+ return self._model.mean_vals
+
+ @property
+ def std_vals(self):
+ """
+ Returns the std value of normlization
+ """
+ return self._model.std_vals
+
+ @size.setter
+ def size(self, wh):
+ assert isinstance(wh, (list, tuple)),\
+ "The value to set `size` must be type of tuple or list."
+ assert len(wh) == 2,\
+ "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
+ len(wh))
+ self._model.size = wh
+
+ @mean_vals.setter
+ def mean_vals(self, value):
+ assert isinstance(
+ value, list), "The value to set `mean_vals` must be type of list."
+ self._model.mean_vals = value
+
+ @std_vals.setter
+ def std_vals(self, value):
+ assert isinstance(
+ value, list), "The value to set `std_vals` must be type of list."
+ self._model.std_vals = value