Skip to content

Commit

Permalink
Add tutorial of QAT for classification (PaddlePaddle#1716)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghaoshuang authored Apr 6, 2023
1 parent 99cd470 commit 100b7e1
Show file tree
Hide file tree
Showing 3 changed files with 432 additions and 0 deletions.
41 changes: 41 additions & 0 deletions example/quantization/qat/classification/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# 动态图量化训练

本示例介绍如何对动态图模型进行量化训练,示例以常用的MobileNetV1,介绍如何对其进行量化训练。


## 分类模型的量化训练流程

### 准备数据

在当前目录下创建``data``文件夹,将``ImageNet``数据集解压在``data``文件夹下,解压后``data/ILSVRC2012``文件夹下应包含以下文件:
- ``'train'``文件夹,训练图片
- ``'train_list.txt'``文件
- ``'val'``文件夹,验证图片
- ``'val_list.txt'``文件

### 准备需要量化的模型

本示例直接使用[paddle vision](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/vision/models)内置的模型结构和预训练权重。通过以下命令查看支持的所有模型:

```
python train.py --help
```

### 训练命令

- MobileNetV1

我们使用普通的量化训练方法即可,启动命令如下:

```bash
# 单卡训练
python train.py --model=mobilenet_v1
# 多卡训练,以0到3号卡为例
python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --model=mobilenet_v1
```

### 量化结果

| 模型 | FP32模型准确率(Top1/Top5) | 量化方法 | 量化模型准确率(Top1/Top5) |
| ----------- | --------------------------- | ------------ | --------------------------- |
| MobileNetV1 | 70.99/89.65 | PACT在线量化 | 70.63/89.65 |
111 changes: 111 additions & 0 deletions example/quantization/qat/classification/args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import argparse
import six
from inspect import isfunction
from types import FunctionType
from typing import Dict
import paddle.vision.models as models

SUPPORT_MODELS: Dict[str, FunctionType] = {}
for _name, _module in models.__dict__.items():
if isfunction(_module) and 'pretrained' in _module.__code__.co_varnames:
SUPPORT_MODELS[_name] = _module


def parse_args():
parser = create_argparse()
args = parser.parse_args()
print("----------- Configuration Arguments -----------")
for arg, value in sorted(six.iteritems(vars(args))):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
return args


def create_argparse():
parser = argparse.ArgumentParser("Quantization on ImageNet")
parser.add_argument(
"--batch_size",
type=int,
default=128,
help="Single Card Minibatch size.", )

parser.add_argument(
"--pretrained_model",
type=str,
default=None,
help="Whether to use pretrained model.")

parser.add_argument(
"--use_gpu",
type=bool,
default=True,
help="Whether to use GPU or not.", )
parser.add_argument(
"--model", type=str, default="mobilenet_v1", help="The target model.")
parser.add_argument(
"--lr",
type=float,
default=0.0001,
help="The learning rate used to fine-tune pruned model.")
parser.add_argument(
"--lr_strategy",
type=str,
default="piecewise_decay",
help="The learning rate decay strategy.")
parser.add_argument(
"--l2_decay", type=float, default=3e-5, help="The l2_decay parameter.")
parser.add_argument(
"--ls_epsilon", type=float, default=0.0, help="Label smooth epsilon.")
parser.add_argument(
"--use_pact",
type=bool,
default=False,
help="Whether to use PACT method.")
parser.add_argument(
"--ce_test", type=bool, default=False, help="Whether to CE test.")
parser.add_argument(
"--onnx_format",
type=bool,
default=False,
help="Whether to export the quantized model with format of ONNX.")
parser.add_argument(
"--momentum_rate",
type=float,
default=0.9,
help="The value of momentum_rate.")
parser.add_argument(
"--num_epochs",
type=int,
default=10,
help="The number of total epochs.")
parser.add_argument(
"--total_images",
type=int,
default=1281167,
help="The number of total training images.")
parser.add_argument(
"--data",
type=str,
default="imagenet",
help="Which data to use. 'cifar10' or 'imagenet'")
parser.add_argument(
"--log_period", type=int, default=10, help="Log period in batches.")
parser.add_argument(
"--infer_model",
type=str,
default="./infer_model/int8_infer",
help="inference model saved directory.")

parser.add_argument(
"--checkpoints",
type=str,
default="./checkpoints",
help="checkpoints directory.")

parser.add_argument(
"--step_epochs",
nargs="+",
type=int,
default=[10, 20, 30],
help="piecewise decay step")
return parser
Loading

0 comments on commit 100b7e1

Please sign in to comment.