forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tutorial of QAT for classification (PaddlePaddle#1716)
- Loading branch information
1 parent
99cd470
commit 100b7e1
Showing
3 changed files
with
432 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# 动态图量化训练 | ||
|
||
本示例以常用的MobileNetV1为例,介绍如何对动态图模型进行量化训练。 | ||
|
||
|
||
## 分类模型的量化训练流程 | ||
|
||
### 准备数据 | ||
|
||
在当前目录下创建``data``文件夹,将``ImageNet``数据集解压在``data``文件夹下,解压后``data/ILSVRC2012``文件夹下应包含以下文件: | ||
- ``'train'``文件夹,训练图片 | ||
- ``'train_list.txt'``文件 | ||
- ``'val'``文件夹,验证图片 | ||
- ``'val_list.txt'``文件 | ||
|
||
### 准备需要量化的模型 | ||
|
||
本示例直接使用[paddle vision](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/vision/models)内置的模型结构和预训练权重。通过以下命令查看支持的所有模型: | ||
|
||
``` | ||
python train.py --help | ||
``` | ||
|
||
### 训练命令 | ||
|
||
- MobileNetV1 | ||
|
||
我们使用普通的量化训练方法即可,启动命令如下: | ||
|
||
```bash | ||
# 单卡训练 | ||
python train.py --model=mobilenet_v1 | ||
# 多卡训练,以0到3号卡为例 | ||
python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --model=mobilenet_v1 | ||
``` | ||
|
||
### 量化结果 | ||
|
||
| 模型 | FP32模型准确率(Top1/Top5) | 量化方法 | 量化模型准确率(Top1/Top5) | | ||
| ----------- | --------------------------- | ------------ | --------------------------- | | ||
| MobileNetV1 | 70.99/89.65 | PACT在线量化 | 70.63/89.65 | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import argparse | ||
import six | ||
from inspect import isfunction | ||
from types import FunctionType | ||
from typing import Dict | ||
import paddle.vision.models as models | ||
|
||
# Maps a name exported by ``paddle.vision.models`` to the function behind it.
# An entry is registered only when the object is a plain function that
# declares a parameter called ``pretrained``.
SUPPORT_MODELS: Dict[str, FunctionType] = {}
for _name, _module in models.__dict__.items():
    if not isfunction(_module):
        continue
    _code = _module.__code__
    # ``co_varnames`` lists the parameters first and the function's local
    # variables afterwards. Slice it down to the declared positional and
    # keyword-only parameters so that a function which merely has a *local*
    # named ``pretrained`` is not registered by mistake.
    _params = _code.co_varnames[:_code.co_argcount + _code.co_kwonlyargcount]
    if 'pretrained' in _params:
        SUPPORT_MODELS[_name] = _module
|
||
|
||
def parse_args():
    """Parse the command line and print the resulting configuration.

    The parser is obtained from ``create_argparse()``. After parsing, every
    argument is printed as ``name: value`` in alphabetical order of the
    argument name, framed by two separator lines.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    args = create_argparse().parse_args()
    print("----------- Configuration Arguments -----------")
    # The built-in ``dict.items()`` replaces ``six.iteritems``: this module
    # already needs Python 3 (it uses variable annotations), so the Python 2
    # compatibility shim adds nothing here.
    for arg, value in sorted(vars(args).items()):
        print(f"{arg}: {value}")
    print("------------------------------------------------")
    return args
|
||
|
||
def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(token)``, which is ``True``
    for every non-empty string, so ``--flag False`` would silently enable the
    flag. This converter interprets the text instead.

    Args:
        value: The raw token (or an actual ``bool``, returned unchanged).

    Returns:
        bool: The parsed truth value.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, matching what ``bool("")`` used to give.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (e.g. True/False), got %r" % (value, ))


def create_argparse():
    """Build the argument parser for the quantization training script.

    Boolean options (``--use_gpu``, ``--use_pact``, ``--ce_test``,
    ``--onnx_format``) take an explicit value such as ``True`` or ``False``
    and are parsed with ``_str2bool`` so that ``False`` really disables them.

    Returns:
        argparse.ArgumentParser: A parser with all options registered; the
        caller is responsible for invoking ``parse_args`` on it.
    """
    parser = argparse.ArgumentParser("Quantization on ImageNet")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="Single Card Minibatch size.", )

    parser.add_argument(
        "--pretrained_model",
        type=str,
        default=None,
        help="Whether to use pretrained model.")

    parser.add_argument(
        "--use_gpu",
        type=_str2bool,
        default=True,
        help="Whether to use GPU or not.", )
    parser.add_argument(
        "--model", type=str, default="mobilenet_v1", help="The target model.")
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        help="The learning rate used to fine-tune pruned model.")
    parser.add_argument(
        "--lr_strategy",
        type=str,
        default="piecewise_decay",
        help="The learning rate decay strategy.")
    parser.add_argument(
        "--l2_decay", type=float, default=3e-5, help="The l2_decay parameter.")
    parser.add_argument(
        "--ls_epsilon", type=float, default=0.0, help="Label smooth epsilon.")
    parser.add_argument(
        "--use_pact",
        type=_str2bool,
        default=False,
        help="Whether to use PACT method.")
    parser.add_argument(
        "--ce_test",
        type=_str2bool,
        default=False,
        help="Whether to CE test.")
    parser.add_argument(
        "--onnx_format",
        type=_str2bool,
        default=False,
        help="Whether to export the quantized model with format of ONNX.")
    parser.add_argument(
        "--momentum_rate",
        type=float,
        default=0.9,
        help="The value of momentum_rate.")
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=10,
        help="The number of total epochs.")
    parser.add_argument(
        "--total_images",
        type=int,
        default=1281167,
        help="The number of total training images.")
    parser.add_argument(
        "--data",
        type=str,
        default="imagenet",
        help="Which data to use. 'cifar10' or 'imagenet'")
    parser.add_argument(
        "--log_period", type=int, default=10, help="Log period in batches.")
    parser.add_argument(
        "--infer_model",
        type=str,
        default="./infer_model/int8_infer",
        help="inference model saved directory.")

    parser.add_argument(
        "--checkpoints",
        type=str,
        default="./checkpoints",
        help="checkpoints directory.")

    # One or more epoch indices, e.g. ``--step_epochs 10 20 30``.
    parser.add_argument(
        "--step_epochs",
        nargs="+",
        type=int,
        default=[10, 20, 30],
        help="piecewise decay step")
    return parser
Oops, something went wrong.