diff --git a/configs/mask_rtdetr/README.md b/configs/mask_rtdetr/README.md
new file mode 100644
index 00000000000..df95fb4d7b5
--- /dev/null
+++ b/configs/mask_rtdetr/README.md
@@ -0,0 +1,217 @@
+English | [简体中文](README_cn.md)
+
+# Mask RT-DETR
+
+## Table of Contents
+- [Introduction](#introduction)
+- [Model Zoo](#model-zoo)
+- [Getting Started](#getting-started)
+- [More Usage](#more-usage)
+
+## Introduction
+Mask RT-DETR is an instance segmentation version of [RT-DETR](../rtdetr/README.md).
+
+## Model Zoo
+| Model | Epoch | Backbone | Input shape | Box AP | Mask AP | Params(M) | FLOPs(G) | T4 TensorRT FP16(FPS) | Pretrained Model | config |
+|:--------------:|:-----:|:--------:|:-----------:|:------:|:-------:|:---------:|:--------:|:---------------------:|:--------------------------------------------------------------------------------------:|:-------------------------------------------:|
+| Mask-RT-DETR-L | 6x | HGNetv2 | 640 | 51.2 | 45.7 | 32 | 120 | 90 | [model](https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams) | [config](mask_rtdetr_hgnetv2_l_6x_coco.yml) |
+
+
+## Getting Started
+
+### Datasets and Metrics
+
+The PaddleDetection team provides the **COCO dataset**; download it, then decompress and place it under `PaddleDetection/dataset/`:
+
+```
+wget https://bj.bcebos.com/v1/paddledet/data/coco.tar
+tar -xvf coco.tar
+```
+
+**Note:**
+ - For the format of COCO-style datasets, please refer to [format-data](https://cocodataset.org/#format-data) and [format-results](https://cocodataset.org/#format-results).
+ - For the COCO evaluation metric, please refer to [detection-eval](https://cocodataset.org/#detection-eval), and install [cocoapi](https://github.com/cocodataset/cocoapi) first.
+
+### Custom dataset
+
+1. For annotating a custom dataset, please refer to [DetAnnoTools](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/docs/tutorials/data/DetAnnoTools_en.md);
+
+2. For training preparation on a custom dataset, please refer to [PrepareDataSet](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/docs/tutorials/data/PrepareDetDataSet_en.md).
+
+
+### Training
+
+Train Mask RT-DETR with the following commands:
+
+```bash
+# training on a single GPU
+export CUDA_VISIBLE_DEVICES=0
+python tools/train.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml --amp --eval
+
+# training on multiple GPUs
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml --amp --eval
+```
+**Notes:**
+- If you need to evaluate while training, add `--eval`.
+- Mask RT-DETR supports mixed precision training; add `--amp` to enable it.
+- PaddleDetection supports multi-machine distributed training; refer to the [DistributedTraining tutorial](../../docs/tutorials/DistributedTraining_en.md).
+
+
+### Evaluation
+
+Evaluate Mask RT-DETR on the COCO val2017 dataset on a single GPU with the following command:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams
+```
+
+For evaluation on the COCO test-dev2017 dataset, please download test-dev2017 from [COCO dataset download](https://cocodataset.org/#download), decompress it into the COCO dataset directory, and configure `EvalDataset` as in `configs/ppyolo/ppyolo_test.yml`.
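+
+A minimal `EvalDataset` override for test-dev2017, modeled on `configs/ppyolo/ppyolo_test.yml`, might look like the sketch below; the `anno_path` and `dataset_dir` values assume the standard COCO layout under `dataset/coco`, so adjust them to your local setup:
+
+```yaml
+# sketch: evaluate on test-dev2017 images listed in the image-info annotation file
+EvalDataset:
+  !ImageFolder
+    anno_path: annotations/image_info_test-dev2017.json
+    dataset_dir: dataset/coco
+```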
+
+### Inference
+
+Run inference on images on a single GPU with the following commands. Use `--infer_img` to infer a single image and `--infer_dir` to infer all images in a directory.
+
+
+```bash
+# inference on a single image
+CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams --infer_img=demo/000000014439_640x640.jpg
+
+# inference on all images in the directory
+CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams --infer_dir=demo
+```
+
+### Exporting models
+
+For deployment on GPU or for speed testing, the model must first be exported to an inference model with `tools/export_model.py`.
+
+**To export Mask RT-DETR for Paddle Inference without TensorRT**, use the following command:
+
+```bash
+python tools/export_model.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams
+```
+
+**To export Mask RT-DETR for Paddle Inference with TensorRT** for better performance, use the following command with the extra `-o trt=True` setting:
+
+```bash
+python tools/export_model.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams trt=True
+```
+
+If you want to export the Mask RT-DETR model to **ONNX format**, refer to the [PaddleDetection Model Export as ONNX Format Tutorial](../../deploy/EXPORT_ONNX_MODEL_en.md) and run the following commands:
+
+```bash
+
+# export inference model
+python tools/export_model.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml --output_dir=output_inference -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams trt=True
+
+# install paddle2onnx
+pip install paddle2onnx
+
+# convert to onnx
+paddle2onnx --model_dir output_inference/mask_rtdetr_hgnetv2_l_6x_coco --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 16 --save_file mask_rtdetr_hgnetv2_l_6x_coco.onnx
+```
+
+**Notes:** The exported ONNX model currently only supports batch_size=1.
+
+### Speed testing
+
+**Using ONNX and TensorRT** to test speed, run the following commands:
+
+```bash
+# export inference model with trt=True
+python tools/export_model.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml --output_dir=output_inference -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams trt=True
+
+# convert to onnx
+paddle2onnx --model_dir output_inference/mask_rtdetr_hgnetv2_l_6x_coco --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 16 --save_file mask_rtdetr_hgnetv2_l_6x_coco.onnx
+
+```
+
+Fix the `im_shape` and `scale_factor` inputs of the previously exported ONNX model to constant values with the following script:
+```python
+# onnx_edit.py
+
+import copy
+
+import onnx
+
+if __name__ == '__main__':
+    model_path = './mask_rtdetr_hgnetv2_l_6x_coco.onnx'
+    model = onnx.load_model(model_path)
+
+    im_shape = onnx.helper.make_tensor(
+        name='im_shape',
+        data_type=onnx.helper.TensorProto.FLOAT,
+        dims=[1, 2],
+        vals=[640, 640])
+    scale_factor = onnx.helper.make_tensor(
+        name='scale_factor',
+        data_type=onnx.helper.TensorProto.FLOAT,
+        dims=[1, 2],
+        vals=[1, 1])
+
+    new_model = copy.deepcopy(model)
+
+    for input in model.graph.input:
+        if input.name == 'im_shape':
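+            # strip `im_shape` from the graph inputs and re-register it
+            # as the constant initializer defined above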
+            new_model.graph.input.remove(input)
+            new_model.graph.initializer.append(im_shape)
+
+        if input.name == 'scale_factor':
+            new_model.graph.input.remove(input)
+            new_model.graph.initializer.append(scale_factor)
+
+    onnx.checker.check_model(new_model, full_check=True)
+    onnx.save_model(new_model, model_path)
+```
+
+Simplify the ONNX model using onnxsim:
+
+```shell
+pip install onnxsim
+onnxsim mask_rtdetr_hgnetv2_l_6x_coco.onnx mask_rtdetr_hgnetv2_l_6x_coco.onnx --overwrite-input-shape "image:1,3,640,640"
+```
+
+```shell
+# trt inference using fp16 and batch_size=1
+trtexec --onnx=./mask_rtdetr_hgnetv2_l_6x_coco.onnx --saveEngine=./mask_rtdetr_hgnetv2_l_6x_coco.engine --workspace=4096 --avgRuns=1000 --fp16
+```
+
+
+### Deployment
+
+Mask RT-DETR can be deployed through the following approaches:
+ - Paddle Inference [Python](../../deploy/python) & [C++](../../deploy/cpp)
+
+Next, we introduce how to deploy Mask RT-DETR models with Paddle Inference in TensorRT FP16 mode.
+
+First, refer to the [Paddle Inference Docs](https://www.paddlepaddle.org.cn/inference/master/user_guides/download_lib.html#python) and download and install the packages matching your CUDA, cuDNN and TensorRT versions.
+
+Then, export Mask RT-DETR for Paddle Inference **with TensorRT** using the following command:
+
+```bash
+python tools/export_model.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml --output_dir=output_inference -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams trt=True
+```
+
+Finally, run inference in TensorRT FP16 mode:
+
+```bash
+# inference on a single image
+CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/mask_rtdetr_hgnetv2_l_6x_coco --image_file=demo/000000014439_640x640.jpg --device=gpu --run_mode=trt_fp16
+
+# inference on all images in the directory
+CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/mask_rtdetr_hgnetv2_l_6x_coco --image_dir=demo/ --device=gpu --run_mode=trt_fp16
+
+```
+
+**Notes:**
+- TensorRT optimizes the network for the current hardware platform according to its definition, generates an inference engine, and serializes it into a file. This inference engine is only applicable to the current software and hardware platform. If your platform has not changed, you can set `use_static=True` in [enable_tensorrt_engine](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/python/infer.py#L660). The generated serialized file will then be saved in the `output_inference` folder and loaded the next time TensorRT is executed.
+
+
+## More Usage
+
+### Setting the Model Weight Saving Metric
+When using the COCO-style evaluation metric for instance segmentation tasks, if the inference results contain both `box` and `mask`, the default metric for saving the best model weights is `box ap`. You can change the saving metric to `mask ap` by adding the `target_metrics` parameter to the configuration file and setting it to `mask`. For a concrete configuration example, please refer to [mask_rtdetr_r50vd.yml](./_base_/mask_rtdetr_r50vd.yml).
+
+### Setting the Number of Queries
+- The number of queries in post-processing can be adjusted by modifying the `num_top_queries` parameter of `DETRPostProcess`. The default value of `num_top_queries` for Mask RT-DETR post-processing is 100.
+- The number of queries in the Query Selection part can be adjusted by modifying the `num_queries` parameter of `MaskRTDETR`. The default value of `num_queries` for Mask RT-DETR is 300. Since the classification loss of Mask RT-DETR aligns mask IoU and confidence, we can set `num_queries` to 100 and directly load the weights trained with `num_queries` set to 300. This improves inference speed with little loss of accuracy.
\ No newline at end of file
diff --git a/configs/mask_rtdetr/README_cn.md b/configs/mask_rtdetr/README_cn.md
new file mode 100644
index 00000000000..4f8a97af1cd
--- /dev/null
+++ b/configs/mask_rtdetr/README_cn.md
@@ -0,0 +1,216 @@
+简体中文 | [English](README.md)
+
+# Mask RT-DETR
+
+## 内容
+- [简介](#简介)
+- [模型库](#模型库)
+- [使用说明](#使用说明)
+- [更多用法](#更多用法)
+
+## 简介
+Mask RT-DETR是[RT-DETR](../rtdetr/README.md)的实例分割版本。
+
+## 模型库
+| Model | Epoch | Backbone | Input shape | Box AP | Mask AP | Params(M) | FLOPs(G) | T4 TensorRT FP16(FPS) | Pretrained Model | config |
+|:--------------:|:-----:|:--------:|:-----------:|:------:|:-------:|:---------:|:--------:|:---------------------:|:--------------------------------------------------------------------------------------:|:-------------------------------------------:|
+| Mask-RT-DETR-L | 6x | HGNetv2 | 640 | 51.2 | 45.7 | 32 | 120 | 90 | [model](https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams) | [config](mask_rtdetr_hgnetv2_l_6x_coco.yml) |
+
+
+## 使用说明
+
+### 数据集和评价指标
+
+下载PaddleDetection团队提供的**COCO数据**，并解压放置于`PaddleDetection/dataset/`下:
+
+```
+wget https://bj.bcebos.com/v1/paddledet/data/coco.tar
+tar -xvf coco.tar
+```
+
+**注意:**
+ - COCO风格格式，请参考 [format-data](https://cocodataset.org/#format-data) 和 [format-results](https://cocodataset.org/#format-results)。
+ - COCO风格评测指标，请参考 [detection-eval](https://cocodataset.org/#detection-eval)，并首先安装 [cocoapi](https://github.com/cocodataset/cocoapi)。
+
+### 自定义数据集
+
+1.自定义数据集的标注制作，请参考 [DetAnnoTools](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/docs/tutorials/data/DetAnnoTools.md)；
+2.自定义数据集的训练准备，请参考 [PrepareDataSet](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/docs/tutorials/data/PrepareDetDataSet.md)。
+
+
+### 训练
+
+请执行以下指令训练Mask RT-DETR：
+
+```bash
+# 单卡GPU上训练
+export CUDA_VISIBLE_DEVICES=0
+python tools/train.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml --amp --eval
+
+# 多卡GPU上训练
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml --amp --eval
+```
+**注意:**
+- 如果需要边训练边评估，请添加`--eval`。
+- Mask RT-DETR支持混合精度训练，请添加`--amp`。
+- PaddleDetection支持多机训练，可以参考[多机训练教程](../../docs/tutorials/DistributedTraining_cn.md)。
+
+### 评估
+
+执行以下命令在单个GPU上评估COCO val2017数据集：
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams
+```
+
+在COCO test-dev2017上评估，请先从[COCO数据集下载页面](https://cocodataset.org/#download)下载COCO test-dev2017数据集，然后解压到COCO数据集文件夹，并像`configs/ppyolo/ppyolo_test.yml`一样配置`EvalDataset`。
+
+### 推理
+
+使用以下命令在单张GPU上预测图片，使用`--infer_img`推理单张图片，使用`--infer_dir`推理文件夹中的所有图片。
+
+
+```bash
+# 推理单张图片
+CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams --infer_img=demo/000000014439_640x640.jpg
+
+# 推理文件夹中的所有图片
+CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams --infer_dir=demo
+```
+
+### 模型导出
+
+Mask RT-DETR在GPU上部署或者速度测试需要通过`tools/export_model.py`导出模型。
+
+当你**使用Paddle Inference但不使用TensorRT**时，运行以下命令导出模型：
+
+```bash
+python tools/export_model.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams
+```
+
+当你**使用Paddle Inference且使用TensorRT**时，需要指定`-o trt=True`来导出模型：
+
+```bash
+python tools/export_model.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams trt=True
+```
+
+如果你想将Mask RT-DETR模型导出为**ONNX格式**，参考
+[PaddleDetection模型导出为ONNX格式教程](../../deploy/EXPORT_ONNX_MODEL.md)，运行以下命令：
+
+```bash
+
+# 导出推理模型
+python tools/export_model.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml --output_dir=output_inference -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams trt=True
+
+# 安装paddle2onnx
+pip install paddle2onnx
+
+# 转换成onnx格式
+paddle2onnx --model_dir output_inference/mask_rtdetr_hgnetv2_l_6x_coco --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 16 --save_file mask_rtdetr_hgnetv2_l_6x_coco.onnx
+```
+
+**注意:** ONNX模型目前只支持batch_size=1。
+
+### 速度测试
+
+**使用 ONNX 和 TensorRT** 进行测速，执行以下命令：
+
+```bash
+# 导出模型
+python tools/export_model.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml --output_dir=output_inference -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams trt=True
+
+# 转化成ONNX格式
+paddle2onnx --model_dir output_inference/mask_rtdetr_hgnetv2_l_6x_coco --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 16 --save_file mask_rtdetr_hgnetv2_l_6x_coco.onnx
+
+```
+
+将先前导出的ONNX模型的`im_shape`和`scale_factor`两个输入固定为常量，代码如下:
+```python
+# onnx_edit.py
+
+import copy
+
+import onnx
+
+if __name__ == '__main__':
+    model_path = './mask_rtdetr_hgnetv2_l_6x_coco.onnx'
+    model = onnx.load_model(model_path)
+
+    im_shape = onnx.helper.make_tensor(
+        name='im_shape',
+        data_type=onnx.helper.TensorProto.FLOAT,
+        dims=[1, 2],
+        vals=[640, 640])
+    scale_factor = onnx.helper.make_tensor(
+        name='scale_factor',
+        data_type=onnx.helper.TensorProto.FLOAT,
+        dims=[1, 2],
+        vals=[1, 1])
+
+    new_model = copy.deepcopy(model)
+
+    for input in model.graph.input:
+        if input.name == 'im_shape':
+            new_model.graph.input.remove(input)
+            new_model.graph.initializer.append(im_shape)
+
+        if input.name == 'scale_factor':
+            new_model.graph.input.remove(input)
+            new_model.graph.initializer.append(scale_factor)
+
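+    # 校验修改后的模型，并覆盖保存到原路径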
+    onnx.checker.check_model(new_model, full_check=True)
+    onnx.save_model(new_model, model_path)
+```
+
+使用onnxsim简化onnx模型:
+```shell
+pip install onnxsim
+onnxsim mask_rtdetr_hgnetv2_l_6x_coco.onnx mask_rtdetr_hgnetv2_l_6x_coco.onnx --overwrite-input-shape "image:1,3,640,640"
+```
+
+```shell
+# 测试速度，半精度，batch_size=1
+trtexec --onnx=./mask_rtdetr_hgnetv2_l_6x_coco.onnx --saveEngine=./mask_rtdetr_hgnetv2_l_6x_coco.engine --workspace=4096 --avgRuns=1000 --fp16
+```
+
+
+### 部署
+
+Mask RT-DETR可以使用以下方式进行部署：
+ - Paddle Inference [Python](../../deploy/python) & [C++](../../deploy/cpp)
+
+接下来，我们将介绍Mask RT-DETR如何使用Paddle Inference在TensorRT FP16模式下部署。
+
+首先，参考[Paddle Inference文档](https://www.paddlepaddle.org.cn/inference/master/user_guides/download_lib.html#python)，下载并安装与你的CUDA、cuDNN和TensorRT版本相应的wheel包。
+
+然后，运行以下命令导出模型：
+
+```bash
+python tools/export_model.py -c configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml --output_dir=output_inference -o weights=https://paddledet.bj.bcebos.com/models/mask_rtdetr_hgnetv2_l_6x_coco.pdparams trt=True
+```
+
+最后，使用TensorRT FP16进行推理：
+
+```bash
+# 推理单张图片
+CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/mask_rtdetr_hgnetv2_l_6x_coco --image_file=demo/000000014439_640x640.jpg --device=gpu --run_mode=trt_fp16
+
+# 推理文件夹下的所有图片
+CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/mask_rtdetr_hgnetv2_l_6x_coco --image_dir=demo/ --device=gpu --run_mode=trt_fp16
+
+```
+
+**注意:**
+- TensorRT会根据网络的定义，执行针对当前硬件平台的优化，生成推理引擎并序列化为文件。该推理引擎只适用于当前软硬件平台。如果你的软硬件平台没有发生变化，你可以设置[enable_tensorrt_engine](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/python/infer.py#L660)的参数`use_static=True`，这样生成的序列化文件将会保存在`output_inference`文件夹下，下次执行TensorRT时将加载保存的序列化文件。
+
+
+## 更多用法
+
+### 模型权重保存指标的设定
+在实例分割任务中，如果使用COCO风格的评测指标，且推理结果中同时包含`box`和`mask`，那么默认的模型权重保存指标是`box ap`。如果想要将保存指标改为`mask ap`，可以在配置文件中添加`target_metrics`参数，并将其赋值为`mask`。具体的配置文件示例可以参考[mask_rtdetr_r50vd.yml](./_base_/mask_rtdetr_r50vd.yml)。
+
+### Query数量的设定
+
+- 后处理中Query数量的设定可以通过修改`DETRPostProcess`的`num_top_queries`参数来调整。Mask RT-DETR后处理的`num_top_queries`默认值为100。
+- Query选择部分Query数量的设定可以通过修改`MaskRTDETR`的`num_queries`参数来调整。Mask RT-DETR的`num_queries`默认值为300。由于Mask RT-DETR的分类损失考虑了mask IoU和置信度的一致性，我们可以将Mask RT-DETR的`num_queries`设置为100，并直接加载在`num_queries`为300时训练的权重。这样可以在精度损失不大的情况下，提高推理速度。
diff --git a/configs/mask_rtdetr/_base_/mask_rtdetr_r50vd.yml b/configs/mask_rtdetr/_base_/mask_rtdetr_r50vd.yml
new file mode 100644
index 00000000000..435335fc10c
--- /dev/null
+++ b/configs/mask_rtdetr/_base_/mask_rtdetr_r50vd.yml
@@ -0,0 +1,78 @@
+architecture: DETR
+with_mask: True
+target_metrics: mask
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams
+norm_type: sync_bn
+use_ema: True
+ema_decay: 0.9999
+ema_decay_type: "exponential"
+ema_filter_no_grad: True
+hidden_dim: 256
+use_focal_loss: True
+eval_size: [640, 640]
+num_prototypes: 32
+
+
+DETR:
+  backbone: ResNet
+  neck: MaskHybridEncoder
+  transformer: MaskRTDETR
+  detr_head: MaskDINOHead
+  post_process: DETRPostProcess
+
+ResNet:
+  # index 0 stands for res2
+  depth: 50
+  variant: d
+  norm_type: bn
+  freeze_at: 0
+  return_idx: [0, 1, 2, 3]
+  lr_mult_list: [0.1, 0.1, 0.1, 0.1]
+  num_stages: 4
+  freeze_stem_only: True
+
+MaskHybridEncoder:
+  hidden_dim: 256
+  use_encoder_idx: [3]
+  num_encoder_layers: 1
+  encoder_layer:
+    name: TransformerLayer
+    d_model: 256
+    nhead: 8
+    dim_feedforward: 1024
+    dropout: 0.
+ activation: 'gelu' + expansion: 1.0 + mask_feat_channels: [64, 64] + + +MaskRTDETR: + num_queries: 300 + position_embed_type: sine + feat_strides: [8, 16, 32] + num_levels: 3 + nhead: 8 + num_decoder_layers: 6 + dim_feedforward: 1024 + dropout: 0.0 + activation: relu + num_denoising: 100 + label_noise_ratio: 0.5 + box_noise_scale: 1.0 + learnt_init_query: False + mask_enhanced: True + +MaskDINOHead: + loss: + name: MaskDINOLoss + loss_coeff: {class: 4, bbox: 5, giou: 2, mask: 5, dice: 5} + aux_loss: True + use_vfl: True + vfl_iou_type: 'mask' + matcher: + name: HungarianMatcher + matcher_coeff: {class: 4, bbox: 5, giou: 2, mask: 5, dice: 5} + +DETRPostProcess: + num_top_queries: 100 + mask_stride: 4 diff --git a/configs/mask_rtdetr/_base_/mask_rtdetr_reader.yml b/configs/mask_rtdetr/_base_/mask_rtdetr_reader.yml new file mode 100644 index 00000000000..4ec92a7dcd3 --- /dev/null +++ b/configs/mask_rtdetr/_base_/mask_rtdetr_reader.yml @@ -0,0 +1,44 @@ +worker_num: 4 +TrainReader: + sample_transforms: + - Decode: {} + - Poly2Mask: {del_poly: True} + - RandomDistort: {prob: 0.8} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomCrop: {prob: 0.8} + - RandomFlip: {} + batch_transforms: + - BatchRandomResize: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + batch_size: 4 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: true + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 1 # must be 1 + shuffle: false + drop_last: false + + +TestReader: + inputs_def: + image_shape: [3, 640, 640] + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 1 + shuffle: false + drop_last: false diff --git a/configs/mask_rtdetr/_base_/optimizer_6x.yml b/configs/mask_rtdetr/_base_/optimizer_6x.yml new file mode 100644 index 00000000000..5abe2f75a2c --- /dev/null +++ b/configs/mask_rtdetr/_base_/optimizer_6x.yml @@ -0,0 +1,19 @@ +epoch: 72 + +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 1.0 + milestones: [100] + use_warmup: true + - !LinearWarmup + start_factor: 0.001 + steps: 2000 + +OptimizerBuilder: + clip_grad_by_norm: 0.1 + regularizer: false + optimizer: + type: AdamW + weight_decay: 0.0001 diff --git a/configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml b/configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml new file mode 100644 index 00000000000..08dda15d7a4 --- /dev/null +++ b/configs/mask_rtdetr/mask_rtdetr_hgnetv2_l_6x_coco.yml @@ -0,0 +1,24 @@ +_BASE_: [ + '../datasets/coco_instance.yml', + '../runtime.yml', + '_base_/optimizer_6x.yml', + '_base_/mask_rtdetr_r50vd.yml', + '_base_/mask_rtdetr_reader.yml', +] + +weights: output/mask_rtdetr_hgnetv2_l_6x_coco/model_final +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/PPHGNetV2_L_ssld_pretrained.pdparams +find_unused_parameters: True +log_iter: 200 +save_dir: output/mask_rtdetr_hgnetv2_l_6x_coco + +DETR: + backbone: PPHGNetV2 + +PPHGNetV2: + arch: 'L' + return_idx: [0, 1, 2, 3] + freeze_stem_only: True + 
freeze_at: 0 + freeze_norm: True + lr_mult_list: [0., 0.05, 0.05, 0.05, 0.05] diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index f1ea70243a3..113aff0803e 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -1248,6 +1248,14 @@ def __call__(self, samples, context=None): if num_gt > 0: pad_gt_areas[:num_gt, 0] = sample['gt_areas'] sample['gt_areas'] = pad_gt_areas + # gt_segm + if 'gt_segm' in sample: + pad_gt_segm = np.zeros( + (num_max_boxes, *sample['gt_segm'].shape[-2:]), + dtype=np.uint8) + if num_gt > 0: + pad_gt_segm[:num_gt] = sample['gt_segm'] + sample['gt_segm'] = pad_gt_segm.astype(np.float32) return samples diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 76a5fdd0ada..3bd4320575e 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -2354,6 +2354,18 @@ def apply(self, sample, context=None): sample['gt_keypoint'] = self.apply_keypoint(sample['gt_keypoint'], offsets) + if 'gt_segm' in sample and len(sample['gt_segm']) > 0: + masks = [ + cv2.copyMakeBorder( + gt_segm, + offset_y, h - (offset_y + im_h), + offset_x, w - (offset_x + im_w), + borderType=cv2.BORDER_CONSTANT, + value=0) + for gt_segm in sample['gt_segm'] + ] + sample['gt_segm'] = np.asarray(masks, dtype=np.uint8) + return sample diff --git a/ppdet/engine/callbacks.py b/ppdet/engine/callbacks.py index ac2f330f519..861b9c53e99 100644 --- a/ppdet/engine/callbacks.py +++ b/ppdet/engine/callbacks.py @@ -195,6 +195,9 @@ def on_epoch_end(self, status): key = 'keypoint' else: key = 'mask' + + key = self.model.cfg.get('target_metrics', key) + if key not in map_res: logger.warning("Evaluation results empty, this may be due to " \ "training iterations being too few or not " \ diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index daaa39a6294..0a6d68305b4 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -312,6 +312,9 @@ def _dump_infer_config(config, path, image_shape, model): infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch] arch_state = True + if infer_arch == 'DETR' and config.get('with_mask', False): + infer_cfg['mask'] = True + if not arch_state: logger.error( 'Architecture: {} is not supported for exporting model now.\n'. 
diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py
index 0d126c08fc3..e2b2dc2da0e 100644
--- a/ppdet/modeling/heads/__init__.py
+++ b/ppdet/modeling/heads/__init__.py
@@ -71,4 +71,4 @@
 from .sparse_roi_head import *
 from .petr_head import *
 from .vitpose_head import *
-from .clrnet_head import *
\ No newline at end of file
+from .clrnet_head import *
diff --git a/ppdet/modeling/heads/detr_head.py b/ppdet/modeling/heads/detr_head.py
index d3c093fbc40..50ed14ae613 100644
--- a/ppdet/modeling/heads/detr_head.py
+++ b/ppdet/modeling/heads/detr_head.py
@@ -521,6 +521,9 @@ def forward(self, out_transformer, body_feats, inputs=None):
             out_masks = paddle.concat(
                 [enc_out_masks.unsqueeze(0), dec_out_masks])
 
+        inputs['gt_segm'] = [gt_segm.astype(out_masks.dtype)
+                             for gt_segm in inputs['gt_segm']]
+
         return self.loss(
             out_bboxes,
             out_logits,
diff --git a/ppdet/modeling/losses/detr_loss.py b/ppdet/modeling/losses/detr_loss.py
index d635337bcc5..6712f928d53 100644
--- a/ppdet/modeling/losses/detr_loss.py
+++ b/ppdet/modeling/losses/detr_loss.py
@@ -46,6 +46,7 @@ def __init__(self,
                  aux_loss=True,
                  use_focal_loss=False,
                  use_vfl=False,
+                 vfl_iou_type='bbox',
                  use_uni_match=False,
                  uni_match_ind=0):
         r"""
@@ -65,6 +66,7 @@ def __init__(self,
         self.aux_loss = aux_loss
         self.use_focal_loss = use_focal_loss
         self.use_vfl = use_vfl
+        self.vfl_iou_type = vfl_iou_type
         self.use_uni_match = use_uni_match
         self.uni_match_ind = uni_match_ind
 
@@ -329,11 +331,41 @@ def _get_prediction_loss(self,
                 _, target_score = self._get_src_target_assign(
                     logits[-1].detach(), gt_score, match_indices)
             elif sum(len(a) for a in gt_bbox) > 0:
-                src_bbox, target_bbox = self._get_src_target_assign(
-                    boxes.detach(), gt_bbox, match_indices)
-                iou_score = bbox_iou(
-                    bbox_cxcywh_to_xyxy(src_bbox).split(4, -1),
-                    bbox_cxcywh_to_xyxy(target_bbox).split(4, -1))
+                if self.vfl_iou_type == 'bbox':
+                    src_bbox, target_bbox = self._get_src_target_assign(
+                        boxes.detach(), gt_bbox, match_indices)
+                    iou_score = bbox_iou(
+                        bbox_cxcywh_to_xyxy(src_bbox).split(4, -1),
+                        bbox_cxcywh_to_xyxy(target_bbox).split(4, -1))
+                elif self.vfl_iou_type == 'mask':
+                    assert masks is not None and gt_mask is not None, \
+                        'Make sure the input has `mask` and `gt_mask`'
+                    assert sum(len(a) for a in gt_mask) > 0
+                    src_mask, target_mask = self._get_src_target_assign(
+                        masks.detach(), gt_mask, match_indices)
+                    src_mask = F.interpolate(
+                        src_mask.unsqueeze(0),
+                        scale_factor=2,
+                        mode='bilinear',
+                        align_corners=False).squeeze(0)
+                    target_mask = F.interpolate(
+                        target_mask.unsqueeze(0),
+                        size=src_mask.shape[-2:],
+                        mode='bilinear',
+                        align_corners=False).squeeze(0)
+                    src_mask = src_mask.flatten(1)
+                    src_mask = F.sigmoid(src_mask)
+                    src_mask = paddle.where(
+                        src_mask > 0.5, 1., 0.).astype(masks.dtype)
+                    target_mask = target_mask.flatten(1)
+                    target_mask = paddle.where(
+                        target_mask > 0.5, 1., 0.).astype(masks.dtype)
+                    inter = (src_mask * target_mask).sum(1)
+                    union = src_mask.sum(1) + target_mask.sum(1) - inter
+                    iou_score = (inter + 1e-2) / (union + 1e-2)
+                    iou_score = iou_score.unsqueeze(-1)
+                else:
+                    iou_score = None
             else:
                 iou_score = None
         else:
@@ -501,11 +533,13 @@ def __init__(self,
                  },
                  aux_loss=True,
                  use_focal_loss=False,
+                 use_vfl=False,
+                 vfl_iou_type='bbox',
                  num_sample_points=12544,
                  oversample_ratio=3.0,
                  important_sample_ratio=0.75):
         super(MaskDINOLoss, self).__init__(num_classes, matcher, loss_coeff,
-                                           aux_loss, use_focal_loss)
+                                           aux_loss, use_focal_loss, use_vfl, vfl_iou_type)
         assert oversample_ratio >= 1
         assert important_sample_ratio <=
1 and important_sample_ratio >= 0 @@ -628,4 +662,4 @@ def _get_point_coords_by_uncertainty(self, masks): paddle.rand([num_masks, self.num_random_points, 2]) ], axis=1) - return sample_points + return sample_points \ No newline at end of file diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index efde830b1fe..298ffbceb16 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -455,6 +455,7 @@ def __init__(self, dual_groups=0, use_focal_loss=False, with_mask=False, + mask_stride=4, mask_threshold=0.5, use_avg_mask_score=False, bbox_decode_type='origin'): @@ -467,19 +468,20 @@ def __init__(self, self.dual_groups = dual_groups self.use_focal_loss = use_focal_loss self.with_mask = with_mask + self.mask_stride = mask_stride self.mask_threshold = mask_threshold self.use_avg_mask_score = use_avg_mask_score self.bbox_decode_type = bbox_decode_type - def _mask_postprocess(self, mask_pred, score_pred, index): - mask_score = F.sigmoid(paddle.gather_nd(mask_pred, index)) + def _mask_postprocess(self, mask_pred, score_pred): + mask_score = F.sigmoid(mask_pred) mask_pred = (mask_score > self.mask_threshold).astype(mask_score.dtype) if self.use_avg_mask_score: avg_mask_score = (mask_pred * mask_score).sum([-2, -1]) / ( mask_pred.sum([-2, -1]) + 1e-6) score_pred *= avg_mask_score - return mask_pred[0].astype('int32'), score_pred + return mask_pred.flatten(0, 1).astype('int32'), score_pred def __call__(self, head_out, im_shape, scale_factor, pad_shape): """ @@ -545,21 +547,27 @@ def __call__(self, head_out, im_shape, scale_factor, pad_shape): mask_pred = None if self.with_mask: assert masks is not None - masks = F.interpolate( - masks, scale_factor=4, mode="bilinear", align_corners=False) - # TODO: Support prediction with bs>1. - # remove padding for input image - h, w = im_shape.astype('int32')[0] - masks = masks[..., :h, :w] + assert masks.shape[0] == 1 + masks = paddle.gather_nd(masks, index) + if self.bbox_decode_type == 'pad': + masks = F.interpolate( + masks, + scale_factor=self.mask_stride, + mode="bilinear", + align_corners=False) + # TODO: Support prediction with bs>1. + # remove padding for input image + h, w = im_shape.astype('int32')[0] + masks = masks[..., :h, :w] # get pred_mask in the original resolution. img_h = img_h[0].astype('int32') img_w = img_w[0].astype('int32') masks = F.interpolate( masks, - size=(img_h, img_w), + size=[img_h, img_w], mode="bilinear", align_corners=False) - mask_pred, scores = self._mask_postprocess(masks, scores, index) + mask_pred, scores = self._mask_postprocess(masks, scores) bbox_pred = paddle.concat( [ @@ -798,4 +806,4 @@ def __call__(self, head_out): bbox_num = paddle.to_tensor( bbox_pred.shape[1], dtype='int32').tile([bbox_pred.shape[0]]) bbox_pred = bbox_pred.reshape([-1, bbox_pred.shape[-1]]) - return bbox_pred, bbox_num + return bbox_pred, bbox_num \ No newline at end of file diff --git a/ppdet/modeling/transformers/__init__.py b/ppdet/modeling/transformers/__init__.py index 33a12402656..5eac4f110de 100644 --- a/ppdet/modeling/transformers/__init__.py +++ b/ppdet/modeling/transformers/__init__.py @@ -22,6 +22,7 @@ from . import mask_dino_transformer from . import rtdetr_transformer from . import hybrid_encoder +from . 
import mask_rtdetr_transformer from .detr_transformer import * from .utils import * @@ -34,3 +35,4 @@ from .mask_dino_transformer import * from .rtdetr_transformer import * from .hybrid_encoder import * +from .mask_rtdetr_transformer import * diff --git a/ppdet/modeling/transformers/hybrid_encoder.py b/ppdet/modeling/transformers/hybrid_encoder.py index 5694803ebe0..9038e845c03 100644 --- a/ppdet/modeling/transformers/hybrid_encoder.py +++ b/ppdet/modeling/transformers/hybrid_encoder.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import numpy as np import paddle import paddle.nn as nn import paddle.nn.functional as F @@ -26,7 +26,7 @@ from paddle import ParamAttr from paddle.regularizer import L2Decay -__all__ = ['HybridEncoder'] +__all__ = ['HybridEncoder', 'MaskHybridEncoder'] class CSPRepLayer(nn.Layer): @@ -299,3 +299,138 @@ def out_shape(self): channels=self.hidden_dim, stride=self.feat_strides[idx]) for idx in range(len(self.in_channels)) ] + + +class MaskFeatFPN(nn.Layer): + def __init__(self, + in_channels=[256, 256, 256], + fpn_strides=[32, 16, 8], + feat_channels=256, + dropout_ratio=0.0, + out_channels=256, + align_corners=False, + act='swish'): + super(MaskFeatFPN, self).__init__() + assert len(in_channels) == len(fpn_strides) + reorder_index = np.argsort(fpn_strides, axis=0) + in_channels = [in_channels[i] for i in reorder_index] + fpn_strides = [fpn_strides[i] for i in reorder_index] + assert min(fpn_strides) == fpn_strides[0] + self.reorder_index = reorder_index + self.fpn_strides = fpn_strides + self.dropout_ratio = dropout_ratio + self.align_corners = align_corners + if self.dropout_ratio > 0: + self.dropout = nn.Dropout2D(dropout_ratio) + + self.scale_heads = nn.LayerList() + for i in range(len(fpn_strides)): + head_length = max( + 1, int(np.log2(fpn_strides[i]) - np.log2(fpn_strides[0]))) + scale_head = [] + for k in range(head_length): + in_c = in_channels[i] if k == 0 else feat_channels + scale_head.append( + nn.Sequential( + BaseConv(in_c, feat_channels, 3, 1, act=act)) + ) + if fpn_strides[i] != fpn_strides[0]: + scale_head.append( + nn.Upsample( + scale_factor=2, + mode='bilinear', + align_corners=align_corners)) + + self.scale_heads.append(nn.Sequential(*scale_head)) + + self.output_conv = BaseConv( + feat_channels, out_channels, 3, 1, act=act) + + def forward(self, inputs): + x = [inputs[i] for i in self.reorder_index] + + output = self.scale_heads[0](x[0]) + for i in range(1, len(self.fpn_strides)): + output = output + F.interpolate( + self.scale_heads[i](x[i]), + size=output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + if self.dropout_ratio > 0: + output = self.dropout(output) + output = self.output_conv(output) + return output + + +@register +@serializable +class MaskHybridEncoder(HybridEncoder): + __shared__ = ['depth_mult', 'act', 'trt', 'eval_size', 'num_prototypes'] + __inject__ = ['encoder_layer'] + + def __init__(self, + in_channels=[256, 512, 1024, 2048], + feat_strides=[4, 8, 16, 32], + hidden_dim=256, + use_encoder_idx=[3], + num_encoder_layers=1, + encoder_layer='TransformerLayer', + num_prototypes=32, + pe_temperature=10000, + expansion=1.0, + depth_mult=1.0, + mask_feat_channels=[64, 64], + act='silu', + trt=False, + eval_size=None): + assert len(in_channels) == len(feat_strides) + x4_feat_dim = in_channels.pop(0) + x4_feat_stride = feat_strides.pop(0) + use_encoder_idx = 
[i - 1 for i in use_encoder_idx] + assert x4_feat_stride == 4 + + super(MaskHybridEncoder, self).__init__( + in_channels=in_channels, + feat_strides=feat_strides, + hidden_dim=hidden_dim, + use_encoder_idx=use_encoder_idx, + num_encoder_layers=num_encoder_layers, + encoder_layer=encoder_layer, + pe_temperature=pe_temperature, + expansion=expansion, + depth_mult=depth_mult, + act=act, + trt=trt, + eval_size=eval_size) + + self.mask_feat_head = MaskFeatFPN( + [hidden_dim] * len(feat_strides), + feat_strides, + feat_channels=mask_feat_channels[0], + out_channels=mask_feat_channels[1], + act=act) + self.enc_mask_lateral = BaseConv( + x4_feat_dim, mask_feat_channels[1], 3, 1, act=act) + self.enc_mask_output = nn.Sequential( + BaseConv( + mask_feat_channels[1], + mask_feat_channels[1], 3, 1, act=act), + nn.Conv2D(mask_feat_channels[1], num_prototypes, 1)) + + def forward(self, feats, for_mot=False, is_teacher=False): + x4_feat = feats.pop(0) + + enc_feats = super(MaskHybridEncoder, self).forward( + feats, for_mot=for_mot, is_teacher=is_teacher) + + mask_feat = self.mask_feat_head(enc_feats) + mask_feat = F.interpolate( + mask_feat, + scale_factor=2, + mode='bilinear', + align_corners=False) + mask_feat += self.enc_mask_lateral(x4_feat) + mask_feat = self.enc_mask_output(mask_feat) + + return enc_feats, mask_feat diff --git a/ppdet/modeling/transformers/mask_rtdetr_transformer.py b/ppdet/modeling/transformers/mask_rtdetr_transformer.py new file mode 100644 index 00000000000..ebdac9c1f03 --- /dev/null +++ b/ppdet/modeling/transformers/mask_rtdetr_transformer.py @@ -0,0 +1,454 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.regularizer import L2Decay + +from ppdet.core.workspace import register +from .rtdetr_transformer import TransformerDecoderLayer +from .utils import (_get_clones, inverse_sigmoid, get_denoising_training_group, + mask_to_box_coordinate) +from ..heads.detr_head import MLP +from ..initializer import (linear_init_, constant_, xavier_uniform_, bias_init_with_prob) + +__all__ = ['MaskRTDETR'] + + +def _get_pred_class_and_mask(query_embed, + mask_feat, + dec_norm, + score_head, + mask_query_head): + out_query = dec_norm(query_embed) + out_logits = score_head(out_query) + mask_query_embed = mask_query_head(out_query) + batch_size, mask_dim, _ = mask_query_embed.shape + _, _, mask_h, mask_w = mask_feat.shape + out_mask = paddle.bmm( + mask_query_embed, mask_feat.flatten(2)).reshape( + [batch_size, mask_dim, mask_h, mask_w]) + return out_logits, out_mask + + +class MaskTransformerDecoder(nn.Layer): + def __init__(self, + hidden_dim, + decoder_layer, + num_layers, + eval_idx=-1): + super(MaskTransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.eval_idx = eval_idx if eval_idx >= 0 \ + else num_layers + eval_idx + + def forward(self, + tgt, + ref_points_unact, + memory, + memory_spatial_shapes, + memory_level_start_index, + mask_feat, + bbox_head, + score_head, + query_pos_head, + mask_query_head, + dec_norm, + attn_mask=None, + memory_mask=None, + query_pos_head_inv_sig=False): + output = tgt + dec_out_bboxes = [] + dec_out_logits = [] + dec_out_masks = [] + ref_points_detach = F.sigmoid(ref_points_unact) + for i, layer in enumerate(self.layers): + ref_points_input = ref_points_detach.unsqueeze(2) + if not query_pos_head_inv_sig: + query_pos_embed = query_pos_head(ref_points_detach) + else: + query_pos_embed = query_pos_head( + inverse_sigmoid(ref_points_detach)) + + output = layer(output, ref_points_input, memory, + memory_spatial_shapes, memory_level_start_index, + attn_mask, memory_mask, query_pos_embed) + + inter_ref_bbox = F.sigmoid(bbox_head(output) + + inverse_sigmoid(ref_points_detach)) + + if self.training: + logits_, masks_ = _get_pred_class_and_mask( + output, mask_feat, dec_norm, + score_head, mask_query_head) + dec_out_logits.append(logits_) + dec_out_masks.append(masks_) + if i == 0: + dec_out_bboxes.append(inter_ref_bbox) + else: + dec_out_bboxes.append( + F.sigmoid(bbox_head(output) + + inverse_sigmoid(ref_points))) + elif i == self.eval_idx: + logits_, masks_ = _get_pred_class_and_mask( + output, mask_feat, dec_norm, + score_head, mask_query_head) + dec_out_logits.append(logits_) + dec_out_masks.append(masks_) + dec_out_bboxes.append(inter_ref_bbox) + return (paddle.stack(dec_out_bboxes), + paddle.stack(dec_out_logits), + paddle.stack(dec_out_masks)) + + ref_points = inter_ref_bbox + ref_points_detach = inter_ref_bbox.detach( + ) if self.training else inter_ref_bbox + + return (paddle.stack(dec_out_bboxes), + paddle.stack(dec_out_logits), + paddle.stack(dec_out_masks)) + + +@register +class MaskRTDETR(nn.Layer): + __shared__ = ['num_classes', 'hidden_dim', 'eval_size', 'num_prototypes'] + + def __init__(self, + num_classes=80, + hidden_dim=256, + num_queries=300, + position_embed_type='sine', + backbone_feat_channels=[512, 1024, 2048], + 
feat_strides=[8, 16, 32], + num_prototypes=32, + num_levels=3, + num_decoder_points=4, + nhead=8, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0., + activation="relu", + num_denoising=100, + label_noise_ratio=0.4, + box_noise_scale=0.4, + learnt_init_query=False, + query_pos_head_inv_sig=False, + mask_enhanced=True, + eval_size=None, + eval_idx=-1, + eps=1e-2): + super(MaskRTDETR, self).__init__() + assert position_embed_type in ['sine', 'learned'], \ + f'ValueError: position_embed_type not supported {position_embed_type}!' + assert len(backbone_feat_channels) <= num_levels + assert len(feat_strides) == len(backbone_feat_channels) + for _ in range(num_levels - len(feat_strides)): + feat_strides.append(feat_strides[-1] * 2) + + self.hidden_dim = hidden_dim + self.nhead = nhead + self.feat_strides = feat_strides + self.num_levels = num_levels + self.num_classes = num_classes + self.num_queries = num_queries + self.eps = eps + self.num_decoder_layers = num_decoder_layers + self.mask_enhanced = mask_enhanced + self.eval_size = eval_size + + # backbone feature projection + self._build_input_proj_layer(backbone_feat_channels) + + # Transformer module + decoder_layer = TransformerDecoderLayer( + hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels, + num_decoder_points) + self.decoder = MaskTransformerDecoder(hidden_dim, decoder_layer, + num_decoder_layers, eval_idx) + + # denoising part + self.denoising_class_embed = nn.Embedding( + num_classes, + hidden_dim, + weight_attr=ParamAttr(initializer=nn.initializer.Normal())) + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + + # decoder embedding + self.learnt_init_query = learnt_init_query + if learnt_init_query: + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + self.query_pos_head = MLP(4, 2 * hidden_dim, + hidden_dim, num_layers=2) + self.query_pos_head_inv_sig = query_pos_head_inv_sig + + # mask embedding + self.mask_query_head = MLP(hidden_dim, hidden_dim, + num_prototypes, num_layers=3) + + # encoder head + self.enc_output = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm( + hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))) + + # decoder norm layer + self.dec_norm = nn.LayerNorm( + hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + # shared prediction head + self.score_head = nn.Linear(hidden_dim, num_classes) + self.bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3) + + self._reset_parameters() + + def _reset_parameters(self): + # class and bbox head init + bias_cls = bias_init_with_prob(0.01) + linear_init_(self.score_head) + constant_(self.score_head.bias, bias_cls) + constant_(self.bbox_head.layers[-1].weight) + constant_(self.bbox_head.layers[-1].bias) + + linear_init_(self.enc_output[0]) + xavier_uniform_(self.enc_output[0].weight) + if self.learnt_init_query: + xavier_uniform_(self.tgt_embed.weight) + xavier_uniform_(self.query_pos_head.layers[0].weight) + xavier_uniform_(self.query_pos_head.layers[1].weight) + for l in self.input_proj: + xavier_uniform_(l[0].weight) + + # init encoder output anchors and valid_mask + if self.eval_size: + self.anchors, self.valid_mask = self._generate_anchors() + + @classmethod + def from_config(cls, cfg, input_shape): + return {'backbone_feat_channels': [i.channels for i in input_shape], + 'feat_strides': [i.stride for i in input_shape]} + + def 
_build_input_proj_layer(self, backbone_feat_channels): + self.input_proj = nn.LayerList() + for in_channels in backbone_feat_channels: + self.input_proj.append( + nn.Sequential( + ('conv', nn.Conv2D( + in_channels, + self.hidden_dim, + kernel_size=1, + bias_attr=False)), + ('norm', nn.BatchNorm2D( + self.hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))))) + in_channels = backbone_feat_channels[-1] + for _ in range(self.num_levels - len(backbone_feat_channels)): + self.input_proj.append( + nn.Sequential( + ('conv', nn.Conv2D( + in_channels, + self.hidden_dim, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False)), + ('norm', nn.BatchNorm2D( + self.hidden_dim, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0)))))) + in_channels = self.hidden_dim + + def _get_encoder_input(self, feats): + # get projection features + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + if self.num_levels > len(proj_feats): + len_srcs = len(proj_feats) + for i in range(len_srcs, self.num_levels): + if i == len_srcs: + proj_feats.append(self.input_proj[i](feats[-1])) + else: + proj_feats.append(self.input_proj[i](proj_feats[-1])) + + # get encoder inputs + feat_flatten = [] + spatial_shapes = [] + level_start_index = [0, ] + for i, feat in enumerate(proj_feats): + _, _, h, w = feat.shape + # [b, c, h, w] -> [b, h*w, c] + feat_flatten.append(feat.flatten(2).transpose([0, 2, 1])) + # [num_levels, 2] + spatial_shapes.append([h, w]) + # [l], start index of each level + level_start_index.append(h * w + level_start_index[-1]) + + # [b, l, c] + feat_flatten = paddle.concat(feat_flatten, 1) + level_start_index.pop() + return feat_flatten, spatial_shapes, level_start_index + + def forward(self, feats, pad_mask=None, gt_meta=None, is_teacher=False): + enc_feats, mask_feat = feats + # input projection and embedding + (memory, spatial_shapes, + level_start_index) = self._get_encoder_input(enc_feats) + + # prepare denoising training + if self.training: + denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \ + get_denoising_training_group(gt_meta, + self.num_classes, + self.num_queries, + self.denoising_class_embed.weight, + self.num_denoising, + self.label_noise_ratio, + self.box_noise_scale) + else: + denoising_class, denoising_bbox_unact,\ + attn_mask, dn_meta = None, None, None, None + + target, init_ref_points_unact, enc_out, init_out = \ + self._get_decoder_input( + memory, mask_feat, spatial_shapes, + denoising_class, denoising_bbox_unact, is_teacher) + + # decoder + out_bboxes, out_logits, out_masks = self.decoder( + target, + init_ref_points_unact, + memory, + spatial_shapes, + level_start_index, + mask_feat, + self.bbox_head, + self.score_head, + self.query_pos_head, + self.mask_query_head, + self.dec_norm, + attn_mask=attn_mask, + memory_mask=None, + query_pos_head_inv_sig=self.query_pos_head_inv_sig) + + return out_logits, out_bboxes, out_masks, enc_out, init_out, dn_meta + + def _generate_anchors(self, + spatial_shapes=None, + grid_size=0.05, + dtype=paddle.float32): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.eval_size[0] / s), int(self.eval_size[1] / s)] + for s in self.feat_strides + ] + anchors = [] + for lvl, (h, w) in enumerate(spatial_shapes): + grid_y, grid_x = paddle.meshgrid( + paddle.arange( + end=h, dtype=dtype), + paddle.arange( + end=w, dtype=dtype)) + grid_xy = paddle.stack([grid_x, grid_y], -1) + + valid_WH = paddle.to_tensor([h, 
w]).astype(dtype) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH + wh = paddle.ones_like(grid_xy) * grid_size * (2.0 ** lvl) + anchors.append( + paddle.concat([grid_xy, wh], -1).reshape([-1, h * w, 4])) + + anchors = paddle.concat(anchors, 1) + valid_mask = ((anchors > self.eps) * + (anchors < 1 - self.eps)).all(-1, keepdim=True) + anchors = paddle.log(anchors / (1 - anchors)) + anchors = paddle.where(valid_mask, anchors, + paddle.to_tensor(float("inf"))) + return anchors, valid_mask + + def _get_decoder_input(self, + memory, + mask_feat, + spatial_shapes, + denoising_class=None, + denoising_bbox_unact=None, + is_teacher=False): + bs, _, _ = memory.shape + # prepare input for decoder + if self.training or self.eval_size is None or is_teacher: + anchors, valid_mask = self._generate_anchors(spatial_shapes) + else: + anchors, valid_mask = self.anchors, self.valid_mask + memory = paddle.where(valid_mask, memory, paddle.to_tensor(0.)) + output_memory = self.enc_output(memory) + + enc_logits_unact = self.score_head(output_memory) + enc_bboxes_unact = self.bbox_head(output_memory) + anchors + + # get topk index + _, topk_ind = paddle.topk( + enc_logits_unact.max(-1), self.num_queries, axis=1) + batch_ind = paddle.arange(end=bs).astype(topk_ind.dtype) + batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_queries]) + topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1) + + # extract content and position query embedding + target = paddle.gather_nd(output_memory, topk_ind) + reference_points_unact = paddle.gather_nd(enc_bboxes_unact, + topk_ind) # unsigmoided. + # get encoder output: {logits, bboxes, masks} + enc_out_logits, enc_out_masks = _get_pred_class_and_mask( + target, mask_feat, self.dec_norm, + self.score_head, self.mask_query_head) + enc_out_bboxes = F.sigmoid(reference_points_unact) + enc_out = (enc_out_logits, enc_out_bboxes, enc_out_masks) + + # concat denoising query + if self.learnt_init_query: + target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1]) + else: + target = target.detach() + if denoising_class is not None: + target = paddle.concat([denoising_class, target], 1) + if self.mask_enhanced: + # use mask-enhanced anchor box initialization + reference_points = mask_to_box_coordinate( + enc_out_masks > 0, normalize=True, format="xywh") + reference_points_unact = inverse_sigmoid(reference_points) + if denoising_bbox_unact is not None: + reference_points_unact = paddle.concat( + [denoising_bbox_unact, reference_points_unact], 1) + + # direct prediction from the matching and denoising part in the beginning + if self.training and denoising_class is not None: + init_out_logits, init_out_masks = _get_pred_class_and_mask( + target, mask_feat, self.dec_norm, + self.score_head, self.mask_query_head) + init_out_bboxes = F.sigmoid(reference_points_unact) + init_out = (init_out_logits, init_out_bboxes, init_out_masks) + else: + init_out = None + + return target, reference_points_unact.detach(), enc_out, init_out diff --git a/ppdet/modeling/transformers/utils.py b/ppdet/modeling/transformers/utils.py index a6f211a78f2..fa60b711172 100644 --- a/ppdet/modeling/transformers/utils.py +++ b/ppdet/modeling/transformers/utils.py @@ -372,8 +372,6 @@ def mask_to_box_coordinate(mask, """ assert mask.ndim == 4 assert format in ["xyxy", "xywh"] - if mask.sum() == 0: - return paddle.zeros([mask.shape[0], mask.shape[1], 4], dtype=dtype) h, w = mask.shape[-2:] y, x = paddle.meshgrid( @@ -391,6 +389,7 @@ def mask_to_box_coordinate(mask, y_min = paddle.where(mask, y_mask, 
paddle.to_tensor(1e8)).flatten(-2).min(-1) out_bbox = paddle.stack([x_min, y_min, x_max, y_max], axis=-1) + out_bbox *= mask.any(axis=[2, 3]).unsqueeze(2) if normalize: out_bbox /= paddle.to_tensor([w, h, w, h]).astype(dtype)