support pact demo load checkpoints (PaddlePaddle#414)
baiyfbupt committed Aug 10, 2020
1 parent 7c8ba91 commit a45431c
Showing 2 changed files with 27 additions and 9 deletions.
4 changes: 2 additions & 2 deletions demo/quant/pact_quant_aware/README.md
@@ -159,7 +159,7 @@ compiled_train_prog = compiled_train_prog.with_data_parallel(

Regular quantization:
```
-python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --checkpoint_dir ./output/MobileNetV3_large_x1_0 --num_epochs 30 --lr 0.0001 --use_pact False
+python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --num_epochs 30 --lr 0.0001 --use_pact False
```

@@ -179,7 +179,7 @@ python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/Mob

Quantization-aware training with PACT:
```
-python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --checkpoint_dir ./output/MobileNetV3_large_x1_0 --num_epochs 30 --lr 0.0001 --use_pact True --batch_size 128 --lr_strategy=piecewise_decay --step_epochs 20 --l2_decay 1e-5
+python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --num_epochs 30 --lr 0.0001 --use_pact True --batch_size 128 --lr_strategy=piecewise_decay --step_epochs 20 --l2_decay 1e-5
```

The output is:
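Although the README commands above no longer pass --checkpoint_dir, the flag survives in train.py as a resume switch, paired with the new --checkpoint_epoch. A hypothetical resume of the PACT run after epoch 10, assuming that epoch's checkpoint was written under the default output_dir (the path and epoch number are illustrative, not taken from the commit):

```
python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --checkpoint_dir ./output/MobileNetV3_large_x1_0/10 --checkpoint_epoch 10 --num_epochs 30 --lr 0.0001 --use_pact True --batch_size 128 --lr_strategy=piecewise_decay --step_epochs 20 --l2_decay 1e-5
```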
32 changes: 25 additions & 7 deletions demo/quant/pact_quant_aware/train.py
@@ -53,8 +53,12 @@
"Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10,
"Log period in batches.")
add_arg('checkpoint_dir', str, "output",
"checkpoint save dir")
add_arg('checkpoint_dir', str, None,
"checkpoint dir")
add_arg('checkpoint_epoch', int, None,
"checkpoint epoch")
add_arg('output_dir', str, "output/MobileNetV3_large_x1_0",
"model save dir")
add_arg('use_pact', bool, True,
"Whether to use PACT or not.")

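The old checkpoint_dir argument did double duty as the save location; the new scheme splits the roles: output_dir is where checkpoints get written, while checkpoint_dir plus checkpoint_epoch name a previously saved checkpoint to resume from. A minimal sketch of that contract (validate_resume_args is an illustrative helper, not part of the commit):

```
def validate_resume_args(args):
    # output_dir: destination for per-epoch and best_model checkpoints.
    # checkpoint_dir + checkpoint_epoch: optional pair identifying a
    # checkpoint to resume from; the training loop below asserts that
    # they are supplied together.
    if args.checkpoint_dir is not None and args.checkpoint_epoch is None:
        raise ValueError(
            "checkpoint_epoch must be set when checkpoint_dir is given")
```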
@@ -244,6 +248,7 @@ def train(epoch, compiled_train_prog):
                compiled_train_prog,
                feed=train_feeder.feed(data),
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
+
            end_time = time.time()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
@@ -279,24 +284,37 @@ def train(epoch, compiled_train_prog):
    # train loop
    best_acc1 = 0.0
    best_epoch = 0
-    for i in range(args.num_epochs):
+
+    start_epoch = 0
+    if args.checkpoint_dir is not None:
+        ckpt_path = args.checkpoint_dir
+        assert args.checkpoint_epoch is not None, "checkpoint_epoch must be set"
+        start_epoch = args.checkpoint_epoch
+        fluid.io.load_persistables(
+            exe, dirname=args.checkpoint_dir, main_program=val_program)
+        start_step = start_epoch * int(
+            math.ceil(float(args.total_images) / args.batch_size))
+        v = fluid.global_scope().find_var('@LR_DECAY_COUNTER@').get_tensor()
+        v.set(np.array([start_step]).astype(np.float32), place)
+
+    for i in range(start_epoch, args.num_epochs):
        train(i, compiled_train_prog)
        acc1 = test(i, val_program)
        fluid.io.save_persistables(
            exe,
-            dirname=os.path.join(args.checkpoint_dir, str(i)),
+            dirname=os.path.join(args.output_dir, str(i)),
            main_program=val_program)
        if acc1 > best_acc1:
            best_acc1 = acc1
            best_epoch = i
            fluid.io.save_persistables(
                exe,
-                dirname=os.path.join(args.checkpoint_dir, 'best_model'),
+                dirname=os.path.join(args.output_dir, 'best_model'),
                main_program=val_program)
-    if os.path.exists(os.path.join(args.checkpoint_dir, 'best_model')):
+    if os.path.exists(os.path.join(args.output_dir, 'best_model')):
        fluid.io.load_persistables(
            exe,
-            dirname=os.path.join(args.checkpoint_dir, 'best_model'),
+            dirname=os.path.join(args.output_dir, 'best_model'),
            main_program=val_program)
    # 3. Freeze the graph after training by adjusting the quantize
    # operators' order for the inference.
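The subtle part of the resume block is its last two added lines: fluid.io.load_persistables restores the weights, but the piecewise_decay schedule tracks its position in a persistable counter variable named '@LR_DECAY_COUNTER@', which that load does not reseed (presumably because the counter is created by the optimizer in the train program, not in val_program). The commit therefore recomputes the global step by hand. A minimal sketch of that step factored into a helper (restore_lr_counter is illustrative, not a function in the commit):

```
import math

import numpy as np


def restore_lr_counter(scope, place, start_epoch, total_images, batch_size):
    # The schedule's position is the number of batches already run:
    # ceil(total_images / batch_size) batches per epoch, times the
    # number of completed epochs.
    steps_per_epoch = int(math.ceil(float(total_images) / batch_size))
    start_step = start_epoch * steps_per_epoch
    # '@LR_DECAY_COUNTER@' is the persistable counter that
    # piecewise_decay reads; overwriting it makes the learning rate
    # resume at the right boundary instead of restarting from step 0.
    counter = scope.find_var('@LR_DECAY_COUNTER@').get_tensor()
    counter.set(np.array([start_step]).astype(np.float32), place)
```

Called as restore_lr_counter(fluid.global_scope(), place, args.checkpoint_epoch, args.total_images, args.batch_size), this reproduces what the diff does inline.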
