diff --git a/deploy/cpp/run_seg_cpu.sh b/deploy/cpp/run_seg_cpu.sh
index 82ceb6cca9..ea94915bab 100755
--- a/deploy/cpp/run_seg_cpu.sh
+++ b/deploy/cpp/run_seg_cpu.sh
@@ -29,7 +29,7 @@ make -j
 cd ..
 
 ./build/test_seg \
-    --model_dir=./bisenetv2_demo_model \
+    --model_dir=./stdc1seg_infer_model \
     --img_path=./cityscapes_demo.png \
-    --use_cpu=true \
+    --devices=CPU \
     --use_mkldnn=true
diff --git a/deploy/cpp/run_seg_gpu.sh b/deploy/cpp/run_seg_gpu.sh
index 9873260cc0..b80547ad6c 100755
--- a/deploy/cpp/run_seg_gpu.sh
+++ b/deploy/cpp/run_seg_gpu.sh
@@ -29,6 +29,6 @@ make -j
 cd ..
 
 ./build/test_seg \
-    --model_dir=./bisenetv2_demo_model \
+    --model_dir=./stdc1seg_infer_model \
     --img_path=./cityscapes_demo.png \
-    --use_cpu=false
+    --devices=GPU
diff --git a/deploy/cpp/run_seg_gpu_trt.sh b/deploy/cpp/run_seg_gpu_trt.sh
new file mode 100644
index 0000000000..9bebc9e4d4
--- /dev/null
+++ b/deploy/cpp/run_seg_gpu_trt.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+set +x
+set -e
+
+# set TENSORRT_ROOT
+TENSORRT_ROOT='/work/download/TensorRT-7.1.3.4/'
+
+WITH_MKL=ON
+WITH_GPU=ON
+USE_TENSORRT=ON
+DEMO_NAME=test_seg
+
+work_path=$(dirname $(readlink -f $0))
+LIB_DIR="${work_path}/paddle_inference"
+
+# compile
+mkdir -p build
+cd build
+rm -rf *
+
+cmake .. \
+  -DDEMO_NAME=${DEMO_NAME} \
+  -DWITH_MKL=${WITH_MKL} \
+  -DWITH_GPU=${WITH_GPU} \
+  -DUSE_TENSORRT=${USE_TENSORRT} \
+  -DWITH_STATIC_LIB=OFF \
+  -DPADDLE_LIB=${LIB_DIR} \
+  -DTENSORRT_ROOT=${TENSORRT_ROOT}
+
+make -j
+
+# run
+cd ..
+
+./build/test_seg \
+    --model_dir=./stdc1seg_infer_model \
+    --img_path=./cityscapes_demo.png \
+    --devices=GPU \
+    --use_trt=True \
+    --trt_precision=fp32
diff --git a/deploy/cpp/run_seg_gpu_trt_dynamic_shape.sh b/deploy/cpp/run_seg_gpu_trt_dynamic_shape.sh
new file mode 100644
index 0000000000..ebfea5e81b
--- /dev/null
+++ b/deploy/cpp/run_seg_gpu_trt_dynamic_shape.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+set +x
+set -e
+
+WITH_MKL=ON
+WITH_GPU=ON
+USE_TENSORRT=ON
+DEMO_NAME=test_seg
+
+work_path=$(dirname $(readlink -f $0))
+LIB_DIR="${work_path}/paddle_inference"
+
+# set TENSORRT_ROOT and dynamic_shape_path
+TENSORRT_ROOT='/work/download/TensorRT-7.1.3.4/'
+DYNAMIC_SHAPE_PATH='./dynamic_shape.pbtxt'
+TRT_PRECISION=fp32
+
+# compile
+mkdir -p build
+cd build
+rm -rf *
+
+cmake .. \
+  -DDEMO_NAME=${DEMO_NAME} \
+  -DWITH_MKL=${WITH_MKL} \
+  -DWITH_GPU=${WITH_GPU} \
+  -DUSE_TENSORRT=${USE_TENSORRT} \
+  -DWITH_STATIC_LIB=OFF \
+  -DPADDLE_LIB=${LIB_DIR} \
+  -DTENSORRT_ROOT=${TENSORRT_ROOT}
+
+make -j
+
+# run
+cd ..
+
+./build/test_seg \
+    --model_dir=./stdc1seg_infer_model \
+    --img_path=./cityscapes_demo.png \
+    --devices=GPU \
+    --use_trt=True \
+    --trt_precision=${TRT_PRECISION} \
+    --use_trt_dynamic_shape=True \
+    --dynamic_shape_path=${DYNAMIC_SHAPE_PATH}
diff --git a/deploy/cpp/src/test_seg.cc b/deploy/cpp/src/test_seg.cc
index 17d5eb7725..90acbf3639 100644
--- a/deploy/cpp/src/test_seg.cc
+++ b/deploy/cpp/src/test_seg.cc
@@ -16,8 +16,11 @@
 DEFINE_string(model_dir, "", "Directory of the inference model. "
                              "It constains deploy.yaml and infer models");
 DEFINE_string(img_path, "", "Path of the test image.");
-DEFINE_bool(use_cpu, false, "Wether use CPU. Default: use GPU.");
+DEFINE_string(devices, "GPU", "Use GPU or CPU device. Default: GPU");
 DEFINE_bool(use_trt, false, "Wether enable TensorRT when use GPU. Defualt: false.");
+DEFINE_string(trt_precision, "fp32", "The precision of TensorRT; supports fp32, fp16 and int8. Default: fp32");
+DEFINE_bool(use_trt_dynamic_shape, false, "Whether to enable dynamic shape when using GPU and TensorRT. Default: false.");
+DEFINE_string(dynamic_shape_path, "", "If dynamic_shape_path is set, it loads the dynamic shape info for TRT.");
 DEFINE_bool(use_mkldnn, false, "Wether enable MKLDNN when use CPU. Defualt: false.");
 DEFINE_string(save_dir, "", "Directory of the output image.");
 
@@ -60,20 +63,55 @@ std::shared_ptr<paddle_infer::Predictor> create_predictor(
       model_dir + "/" + yaml_config.params_file);
   infer_config.EnableMemoryOptim();
 
-  if (FLAGS_use_cpu) {
+  if (FLAGS_devices == "CPU") {
     LOG(INFO) << "Use CPU";
     if (FLAGS_use_mkldnn) {
-      // TODO(jc): fix the bug
-      //infer_config.EnableMKLDNN();
+      LOG(INFO) << "Use MKLDNN";
+      infer_config.EnableMKLDNN();
       infer_config.SetCpuMathLibraryNumThreads(5);
     }
-  } else {
+  } else if (FLAGS_devices == "GPU") {
     LOG(INFO) << "Use GPU";
     infer_config.EnableUseGpu(100, 0);
+
+    // TRT config
     if (FLAGS_use_trt) {
-      infer_config.EnableTensorRtEngine(1 << 20, 1, 3,
-        paddle_infer::PrecisionType::kFloat32, false, false);
+      LOG(INFO) << "Use TRT";
+      LOG(INFO) << "trt_precision:" << FLAGS_trt_precision;
+
+      // TRT precision
+      if (FLAGS_trt_precision == "fp32") {
+        infer_config.EnableTensorRtEngine(1 << 20, 1, 3,
+            paddle_infer::PrecisionType::kFloat32, false, false);
+      } else if (FLAGS_trt_precision == "fp16") {
+        infer_config.EnableTensorRtEngine(1 << 20, 1, 3,
+            paddle_infer::PrecisionType::kHalf, false, false);
+      } else if (FLAGS_trt_precision == "int8") {
+        infer_config.EnableTensorRtEngine(1 << 20, 1, 3,
+            paddle_infer::PrecisionType::kInt8, false, false);
+      } else {
+        LOG(FATAL) << "The trt_precision should be fp32, fp16 or int8.";
+      }
+
+      // TRT dynamic shape
+      if (FLAGS_use_trt_dynamic_shape) {
+        LOG(INFO) << "Enable TRT dynamic shape";
+        if (FLAGS_dynamic_shape_path.empty()) {
+          std::map<std::string, std::vector<int>> min_input_shape = {
+              {"image", {1, 3, 112, 112}}};
+          std::map<std::string, std::vector<int>> max_input_shape = {
+              {"image", {1, 3, 1024, 2048}}};
+          std::map<std::string, std::vector<int>> opt_input_shape = {
+              {"image", {1, 3, 512, 1024}}};
+          infer_config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
+                                              opt_input_shape);
+        } else {
+          infer_config.EnableTunedTensorRtDynamicShape(FLAGS_dynamic_shape_path, true);
+        }
+      }
     }
+  } else {
+    LOG(FATAL) << "The devices should be GPU or CPU";
   }
 
   auto predictor = paddle_infer::CreatePredictor(infer_config);
@@ -153,5 +191,5 @@ int main(int argc, char *argv[]) {
   cv::equalizeHist(out_gray_img, out_eq_img);
   cv::imwrite("out_img.jpg", out_eq_img);
 
-  LOG(INFO) << "Finish";
+  LOG(INFO) << "Finished. The result is saved in out_img.jpg";
 }
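Editor's note: the three `EnableTensorRtEngine` calls above differ only in the precision enum. Not part of this PR, but a possible tidy-up is to resolve `--trt_precision` once and call the engine a single time; a minimal sketch (assuming the same glog/gflags setup that `test_seg.cc` already uses):

```cpp
#include <map>
#include <string>
#include <glog/logging.h>
#include "paddle/include/paddle_inference_api.h"

// Sketch: map the --trt_precision string to a Paddle Inference precision enum.
paddle_infer::PrecisionType ResolveTrtPrecision(const std::string& precision) {
  static const std::map<std::string, paddle_infer::PrecisionType> kPrecisions = {
      {"fp32", paddle_infer::PrecisionType::kFloat32},
      {"fp16", paddle_infer::PrecisionType::kHalf},
      {"int8", paddle_infer::PrecisionType::kInt8}};
  auto it = kPrecisions.find(precision);
  if (it == kPrecisions.end()) {
    LOG(FATAL) << "The trt_precision should be fp32, fp16 or int8.";
  }
  return it->second;
}

// Usage inside create_predictor (same arguments as the demo uses):
//   infer_config.EnableTensorRtEngine(1 << 20, 1, 3,
//       ResolveTrtPrecision(FLAGS_trt_precision), false, false);
```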
+ +import argparse +import codecs +import os +import sys + +import yaml +import numpy as np +from paddle.inference import create_predictor, PrecisionType +from paddle.inference import Config as PredictConfig + +LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.join(LOCAL_PATH, '..', '..')) + +from paddleseg.utils import logger, get_image_list, progbar +from infer import DeployConfig +""" +Load images and run the model, it collects and saves dynamic shapes, +which are used in deployment with TRT. +""" + + +def parse_args(): + parser = argparse.ArgumentParser(description='Test') + parser.add_argument( + "--config", + help="The deploy config generated by exporting model.", + type=str, + required=True) + parser.add_argument( + '--image_path', + help='The directory or path or file list of the images to be predicted.', + type=str, + required=True) + + parser.add_argument( + '--dynamic_shape_path', + type=str, + default="./dynamic_shape.pbtxt", + help='The path to save dynamic shape.') + + return parser.parse_args() + + +def is_support_collecting(): + return hasattr(PredictConfig, "collect_shape_range_info") \ + and hasattr(PredictConfig, "enable_tuned_tensorrt_dynamic_shape") + + +def collect_dynamic_shape(args): + + if not is_support_collecting(): + logger.error("The Paddle does not support collecting dynamic shape, " \ + "please reinstall the PaddlePaddle (latest gpu version).") + + # prepare config + cfg = DeployConfig(args.config) + pred_cfg = PredictConfig(cfg.model, cfg.params) + pred_cfg.enable_use_gpu(1000, 0) + pred_cfg.collect_shape_range_info(args.dynamic_shape_path) + + # create predictor + predictor = create_predictor(pred_cfg) + input_names = predictor.get_input_names() + input_handle = predictor.get_input_handle(input_names[0]) + + # get images + img_path_list, _ = get_image_list(args.image_path) + if not isinstance(img_path_list, (list, tuple)): + img_path_list = [img_path_list] + logger.info(f"The num of images is {len(img_path_list)} \n") + + # collect + progbar_val = progbar.Progbar(target=len(img_path_list)) + for idx, img_path in enumerate(img_path_list): + data = np.array([cfg.transforms(img_path)[0]]) + input_handle.reshape(data.shape) + input_handle.copy_from_cpu(data) + + try: + predictor.run() + except: + logger.info( + "Fail to collect dynamic shape. Usually, the error is out of " + "GPU memory, for the model and image are too large.\n") + del predictor + if os.path.exists(args.dynamic_shape_path): + os.remove(args.dynamic_shape_path) + + progbar_val.update(idx + 1) + + logger.info(f"The dynamic shape is save in {args.dynamic_shape_path}") + + +if __name__ == '__main__': + args = parse_args() + collect_dynamic_shape(args) diff --git a/docs/deployment/inference/cpp_inference_cn.md b/docs/deployment/inference/cpp_inference_cn.md index 4efc998192..033ac536fe 100644 --- a/docs/deployment/inference/cpp_inference_cn.md +++ b/docs/deployment/inference/cpp_inference_cn.md @@ -7,25 +7,45 @@ * 准备模型和图片 * 编译、执行 -飞桨针对不同场景,提供了多个预测引擎部署模型(如下图),详细信息请参考[文档](https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html)。 +飞桨针对不同场景,提供了多个预测引擎部署模型(如下图),详细使用方法请参考[文档](https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html)。 ![inference_ecosystem](https://user-images.githubusercontent.com/52520497/130720374-26947102-93ec-41e2-8207-38081dcc27aa.png) 此外,PaddleX也提供了C++部署的示例和文档,具体参考[链接](https://github.com/PaddlePaddle/PaddleX/tree/develop/deploy/cpp)。 ## 2. 
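Editor's note: the helper above collects shape ranges from Python. For a C++-only pipeline, the Paddle Inference C++ `Config` is expected to expose a counterpart of the Python `collect_shape_range_info` used above; the following is a minimal, hedged sketch under that assumption (the method name `CollectShapeRangeInfo`, the model file names, and the preprocessing are assumptions for illustration, not part of this PR):

```cpp
#include <string>
#include <vector>
#include "paddle/include/paddle_inference_api.h"

// Sketch: run representative images through the predictor while Paddle
// Inference records the observed tensor shape ranges to shape_path.
void CollectShapes(const std::string& model_dir,
                   const std::vector<std::vector<float>>& images,  // preprocessed CHW data
                   const std::vector<std::vector<int>>& shapes,    // matching {1, 3, H, W}
                   const std::string& shape_path) {
  // Model file names are illustrative; the demo reads them from deploy.yaml.
  paddle_infer::Config config(model_dir + "/model.pdmodel",
                              model_dir + "/model.pdiparams");
  config.EnableUseGpu(1000, 0);
  // Assumed to be the C++ twin of Python's collect_shape_range_info.
  config.CollectShapeRangeInfo(shape_path);

  auto predictor = paddle_infer::CreatePredictor(config);
  auto input = predictor->GetInputHandle(predictor->GetInputNames()[0]);

  for (size_t i = 0; i < images.size(); ++i) {
    input->Reshape(shapes[i]);
    input->CopyFromCpu(images[i].data());
    predictor->Run();  // shape ranges are recorded as a side effect of running
  }
  // shape_path can now be passed to --dynamic_shape_path in test_seg.cc.
}
```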
diff --git a/docs/deployment/inference/cpp_inference_cn.md b/docs/deployment/inference/cpp_inference_cn.md
index 4efc998192..033ac536fe 100644
--- a/docs/deployment/inference/cpp_inference_cn.md
+++ b/docs/deployment/inference/cpp_inference_cn.md
@@ -7,25 +7,45 @@
 * Prepare the model and image
 * Compile and run
 
-PaddlePaddle provides multiple inference engines for deploying models in different scenarios (see the figure below). For detailed information, please refer to the [documentation](https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html).
+PaddlePaddle provides multiple inference engines for deploying models in different scenarios (see the figure below). For detailed usage instructions, please refer to the [documentation](https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html).
 
 ![inference_ecosystem](https://user-images.githubusercontent.com/52520497/130720374-26947102-93ec-41e2-8207-38081dcc27aa.png)
 
 In addition, PaddleX also provides C++ deployment examples and documentation; see this [link](https://github.com/PaddlePaddle/PaddleX/tree/develop/deploy/cpp).
 
 ## 2. Prepare the Environment
 
+### 2.1 Prepare the Basic Environment
-### Prepare the Paddle Inference C++ Library
+If you deploy a model with Paddle Inference on an X86 CPU, you can skip the CUDA, cuDNN, and TensorRT preparation below.
 
-You can download the Paddle Inference C++ library from this [link](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html).
+If you deploy a model with Paddle Inference on an Nvidia GPU, you need to install the required CUDA and cuDNN. In addition, Paddle Inference supports TensorRT acceleration on Nvidia GPUs, so download the TRT library as needed.
 
-Note that you must choose the exact version according to your machine's CUDA version, cuDNN version, MKLDNN or OpenBLAS, whether TensorRT is used, and so on. A library version >= 2.0.1 is recommended.
+Note: choose CUDA, cuDNN, and TensorRT versions that are supported by the Paddle Inference [C++ libraries](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html).
+
+For CUDA and cuDNN installation, refer to online tutorials or the official documentation ([CUDA](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/), [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/)).
+
+Here we provide two versions of the CUDA, cuDNN, and TensorRT packages for download.
+```
+wget https://paddle-inference-dist.bj.bcebos.com/tensorrt_test/cuda10.1-cudnn7.6-trt6.0.tar
+wget https://paddle-inference-dist.bj.bcebos.com/tensorrt_test/cuda10.2-cudnn8.0-trt7.1.tgz
+```
+
+After downloading and extracting, install CUDA and cuDNN as documented. For TensorRT, you only need to set the library path, for example:
+```
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/work/TensorRT-7.1.3.4/lib
+```
+
+### 2.2 Prepare the Paddle Inference C++ Library
+
+Paddle Inference provides several precompiled versions of the [C++ library](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html).
+
+The C++ libraries can be told apart by their names. Please choose the exact version according to your machine's operating system, CUDA version, cuDNN version, MKLDNN or OpenBLAS, whether TensorRT is used, and so on. (A library version >= 2.0 is recommended.)
 
 After downloading the `paddle_inference.tgz` archive, extract it and put the extracted paddle_inference directory under `PaddleSeg/deploy/cpp/`.
 
 If you need to compile the Paddle Inference C++ library yourself, refer to the [documentation](https://paddleinference.paddlepaddle.org.cn/user_guides/source_compile.html); this is not covered here.
 
-### Prepare OpenCV
+### 2.3 Prepare OpenCV
 
 This demo uses OpenCV to read images, so OpenCV needs to be prepared.
 
@@ -45,7 +65,7 @@ make install
 cd ../..
 ```
 
-### Install Yaml
+### 2.4 Install Yaml
 
 This demo uses Yaml to read the config file.
 
@@ -63,11 +83,11 @@ make install
 ```
 
 ## 3. Prepare the Model and Image
 
-Run the following commands under `PaddleSeg/deploy/cpp/` to download the [test model](https://paddleseg.bj.bcebos.com/dygraph/demo/bisenet_demo_model.tar.gz) for testing. If you need to test other models, refer to the [documentation](../../model_export.md) to export an inference model.
+Run the following commands under `PaddleSeg/deploy/cpp/` to download the [test model](https://paddleseg.bj.bcebos.com/dygraph/demo/stdc1seg_infer_model.tar.gz) for testing. If you need to test other models, refer to the [documentation](../../model_export.md) to export an inference model.
 
 ```
-wget https://paddleseg.bj.bcebos.com/dygraph/demo/bisenet_demo_model.tar.gz
-tar xzf bisenet_demo_model.tar.gz
+wget https://paddleseg.bj.bcebos.com/dygraph/demo/stdc1seg_infer_model.tar.gz
+tar xf stdc1seg_infer_model.tar.gz
 ```
 
 Download one [image](https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png) from the cityscapes validation set.
@@ -76,21 +96,77 @@ tar xzf bisenet_demo_model.tar.gz
 wget https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png
 ```
 
-## 4. Compile and Run
-
 Please check that the inference library, model, and image are placed under `PaddleSeg/deploy/cpp/` as follows.
 
 ```
 PaddleSeg/deploy/cpp
 |-- paddle_inference        # the inference library
-|-- bisenetv2_demo_model    # the model
+|-- stdc1seg_infer_model    # the model
 |-- cityscapes_demo.png     # the image
+...
 ```
 
-Run `sh run_seg_cpu.sh`; it compiles the demo and then runs prediction on an X86 CPU.
+## 4. Deployment on X86 CPU
+
+Run `sh run_seg_cpu.sh`; it compiles the demo and then runs prediction on an X86 CPU. The segmentation result is saved as "out_img.jpg" in the current directory.
+
+## 5. Deployment on Nvidia GPU
+
+Before deploying a model on an Nvidia GPU, we need to be clear about the deployment scenario and requirements, in particular whether the input image size changes across predictions.
+
+Definitions: fixed shape mode means the input image size is the same for every prediction; dynamic shape mode means the input image size may change from one prediction to the next.
+
+Paddle Inference supports two ways of deploying models on Nvidia GPUs (a condensed configuration sketch follows this list):
+* Naive mode: runs prediction with Paddle's own kernels; the same configuration supports both fixed shape mode and dynamic shape mode.
+* TRT mode: runs prediction with the integrated TensorRT, which is usually faster than Naive mode; fixed shape mode and dynamic shape mode require different configurations.
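To make the two modes concrete, here is a condensed sketch of the GPU branch this PR adds to `test_seg.cc`; the function wrapper and the `model_path`/`params_path`/`use_trt` names are only for illustration, and the memory-pool, workspace, batch, and subgraph values are simply the demo's settings, not general requirements.

```cpp
#include <memory>
#include <string>
#include "paddle/include/paddle_inference_api.h"

// Condensed from the GPU branch of create_predictor in test_seg.cc.
std::shared_ptr<paddle_infer::Predictor> CreateGpuPredictor(
    const std::string& model_path, const std::string& params_path, bool use_trt) {
  paddle_infer::Config config(model_path, params_path);
  config.EnableUseGpu(100, 0);  // Naive mode: Paddle GPU kernels; covers fixed and dynamic shape

  if (use_trt) {
    // TRT mode: offload supported subgraphs to TensorRT
    // (fixed shape configuration shown; see 5.3 for the dynamic shape additions).
    config.EnableTensorRtEngine(1 << 20 /* workspace */, 1 /* max batch */,
                                3 /* min subgraph size */,
                                paddle_infer::PrecisionType::kFloat32,
                                false /* use_static */, false /* use_calib_mode */);
  }
  return paddle_infer::CreatePredictor(config);
}
```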
+### 5.1 Deployment with Naive Mode
+
+To deploy a segmentation model with Naive mode (fixed shape mode or dynamic shape mode), run `sh run_seg_gpu.sh`.
+
+The script compiles the demo, loads the model and the image, and runs prediction; the result is saved as "out_img.jpg".
+
+### 5.2 Deployment with TRT Mode (Fixed Shape)
+
+To deploy a PaddleSeg segmentation model with TRT mode in fixed shape mode:
+* Open the `run_seg_gpu_trt.sh` script and set `TENSORRT_ROOT` to the path of the TensorRT library on your machine, e.g. `TENSORRT_ROOT='/work/TensorRT-7.1.3.4/'`.
+* Run `sh run_seg_gpu_trt.sh`.
+* The prediction result is saved as "out_img.jpg".
+
+PaddleSeg segmentation models usually accept inputs of arbitrary size, so the models contain dynamic-shape ops internally.
+As a result, deploying with TRT mode in fixed shape mode often fails. In that case, TRT mode with dynamic shape mode is recommended.
+
+### 5.3 Deployment with TRT Mode (Dynamic Shape)
+
+Paddle Inference offers several ways to deploy PaddleSeg segmentation models with TRT mode in dynamic shape mode. Here we recommend a fairly general approach with three main steps: prepare the inference model and sample images, collect the dynamic shapes offline, and deploy.
+
+* Prepare the inference model and sample images
+
+The inference model and sample images are used for the offline collection of dynamic shapes, so **the sample images must cover the largest and smallest image sizes expected at prediction time**.
+
+In the previous steps, we have already prepared the inference model and one test image.
+
+* Collect the dynamic shapes offline
+
+Please install PaddlePaddle and the PaddleSeg dependencies following the PaddleSeg [installation documentation](../../install_cn.md).
+
+Under `PaddleSeg/deploy/cpp`, run the following command.
+```
+python ../python/collect_dynamic_shape.py \
+    --config stdc1seg_infer_model/deploy.yaml \
+    --image_path ./cityscapes_demo.png \
+    --dynamic_shape_path ./dynamic_shape.pbtxt
+```
+
+Given the inference model's config file and the sample images, the script loads the model, reads the images, and collects and saves the dynamic shapes to `./dynamic_shape.pbtxt`.
+
+If there are multiple sample images, you can pass an image directory to `--image_path`.
+
+* Deploy
+
+Open the `run_seg_gpu_trt_dynamic_shape.sh` script, set `TENSORRT_ROOT` to the path of the TensorRT library on your machine, and set `DYNAMIC_SHAPE_PATH` to the dynamic shape file.
 
-Run `sh run_seg_gpu.sh`; it compiles the demo and then runs prediction on an Nvidia GPU.
+Run `sh run_seg_gpu_trt_dynamic_shape.sh`; the prediction result is saved as "out_img.jpg".
 
-The segmentation result is saved as "out_img.jpg" in the current directory, as shown below. Note that histogram equalization has been applied to the image for better visualization.
+The result is shown below; histogram equalization has been applied to the image for better visualization.
 
 ![out_img](https://user-images.githubusercontent.com/52520497/131456277-260352b5-4047-46d5-a38f-c50bbcfb6fd0.jpg)
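Editor's note: for reference, the dynamic shape handling described in sections 5.2 and 5.3 reduces to the following configuration, condensed from the TRT dynamic shape branch this PR adds to `test_seg.cc`. The helper function and variable names are only for illustration; it would be called after `EnableUseGpu` and `EnableTensorRtEngine`.

```cpp
#include <map>
#include <string>
#include <vector>
#include "paddle/include/paddle_inference_api.h"

// Condensed from test_seg.cc: either set min/max/opt shapes by hand,
// or load the dynamic_shape.pbtxt collected offline.
void ConfigureTrtDynamicShape(paddle_infer::Config& config,
                              const std::string& dynamic_shape_path) {
  if (dynamic_shape_path.empty()) {
    // Hand-written shape ranges; "image" is the input name of the seg model.
    std::map<std::string, std::vector<int>> min_shape = {{"image", {1, 3, 112, 112}}};
    std::map<std::string, std::vector<int>> max_shape = {{"image", {1, 3, 1024, 2048}}};
    std::map<std::string, std::vector<int>> opt_shape = {{"image", {1, 3, 512, 1024}}};
    config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
  } else {
    // Shape ranges collected offline by collect_dynamic_shape.py; the second
    // argument allows building TRT engines at runtime for unseen shapes.
    config.EnableTunedTensorRtDynamicShape(dynamic_shape_path, true);
  }
}
```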