Skip to content

Commit

Permalink
add vit large configs and checkpoint (PaddlePaddle#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
GuoxiaWang committed Nov 11, 2022
1 parent 08300f8 commit 170f8c0
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 6 deletions.
4 changes: 2 additions & 2 deletions plsc/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def worker_init_fn(worker_id):
# set device
assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu"]
self.device = paddle.set_device(self.config["Global"]["device"])
logger.info('train with paddle {} and device {}'.format(
paddle.__version__, self.device))
logger.info('train with paddle {}, commit id {} and device {}'.format(
paddle.__version__, paddle.__git_commit__[:8], self.device))

class_num = config["Model"].get("class_num", None)
self.config["DataLoader"].update({"class_num": class_num})
Expand Down
10 changes: 6 additions & 4 deletions task/classification/vit/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ We provide more directly runnable configurations, see [ViT Configurations](./con

## Models

| Model | Phase | Dataset | Configs | GPUs | Img/sec | Top1 Acc | Pre-trained checkpoint | Fine-tuned checkpoint | Log |
| ------------ | -------- | ------------ | ------------------------------------------------------------ | --------- | ------- | -------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| ViT-B_16_224 | pretrain | ImageNet2012 | [config](./configs/ViT_base_patch16_224_in1k_1n8c_dp_fp16o2.yaml) | A100*N1C8 | 3583 | 0.75196 | [download](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet2012-ViT-B_16-224.pdparams) | - | [log](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet2012-ViT-B_16-224.log) |
| ViT-B_16_384 | finetune | ImageNet2012 | [config](./configs/ViT_base_patch16_384_ft_in1k_1n8c_dp_fp16o2.yaml) | A100*N1C8 | 719 | 0.77972 | [download](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet2012-ViT-B_16-224.pdparams) | [download](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet2012-ViT-B_16-384.pdparams) | [log](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet2012-ViT-B_16-384.log) |
| Model | Phase | Dataset | Configs | GPUs | Img/sec | Top1 Acc | Pre-trained checkpoint | Fine-tuned checkpoint | Log |
| ------------ | -------- | ------------ | ------------------------------------------------------------ | ---------- | ------- | -------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| ViT-B_16_224 | pretrain | ImageNet2012 | [config](./configs/ViT_base_patch16_224_in1k_1n8c_dp_fp16o2.yaml) | A100*N1C8 | 3583 | 0.75196 | [download](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet2012-ViT-B_16-224.pdparams) | - | [log](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet2012-ViT-B_16-224.log) |
| ViT-B_16_384 | finetune | ImageNet2012 | [config](./configs/ViT_base_patch16_384_ft_in1k_1n8c_dp_fp16o2.yaml) | A100*N1C8 | 719 | 0.77972 | [download](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet2012-ViT-B_16-224.pdparams) | [download](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet2012-ViT-B_16-384.pdparams) | [log](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet2012-ViT-B_16-384.log) |
| ViT-L_16_224 | pretrain | ImageNet21K | [config](./configs/ViT_large_patch16_224_in22k_4n32c_dp_fp16o2.yaml) | A100*N4C32 | 5256 | - | [download](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet21k-ViT-L_16-224.pdparams) | - | [log](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet21k-ViT-L_16-224.log) |
| ViT-L_16_384 | finetune | ImageNet2012 | [config](./configs/ViT_large_patch16_384_in1k_ft_4n32c_dp_fp16o2.yaml) | A100*N4C32 | 934 | 0.84926 | [download](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet21k-ViT-L_16-224.pdparams) | [download](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet21k%2Bimagenet2012-ViT-L_16-384.pdparams) | [log](https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet21k%2Bimagenet2012-ViT-L_16-384.log) |


## Citations
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# global configs
Global:
checkpoint: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
max_num_latest_checkpoint: 0
eval_during_train: True
eval_interval: 1
eval_unit: "epoch"
accum_steps: 1
epochs: 90
print_batch_step: 10
use_visualdl: False
seed: 2021

# FP16 setting
FP16:
level: O2
GradScaler:
init_loss_scaling: 65536.0

DistributedStrategy:
data_parallel: True

# model architecture
Model:
name: ViT_large_patch16_224
class_num: 18576
drop_rate: 0.1

# loss function config for traing/eval process
Loss:
Train:
- ViTCELoss:
weight: 1.0
epsilon: 0.0001
Eval:
- CELoss:
weight: 1.0

LRScheduler:
name: ViTLRScheduler
learning_rate: 1e-3
decay_type: linear
warmup_steps: 10000

Optimizer:
name: AdamW
betas: (0.9, 0.999)
epsilon: 1e-6
weight_decay: 0.15
exp_avg_force_fp32: True

# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ImageNet22K/
multi_label: True
class_num: 18576
cls_label_path: ./dataset/ImageNet22K/filter_train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
interpolation: bicubic
backend: pil
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: ''
- ToCHWImage:
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: True
shuffle: True
loader:
num_workers: 8
use_shared_memory: True

Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ImageNet22K/
cls_label_path: ./dataset/ImageNet22K/filter_val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
interpolation: bicubic
backend: pil
- CenterCropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: ''
- ToCHWImage:

sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: False
loader:
num_workers: 8
use_shared_memory: True

Metric:
Eval:
- TopkAcc:
topk: [1, 5]

Export:
export_type: paddle
input_shape: [None, 3, 224, 224]
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# global configs
Global:
checkpoint: null
finetune: True
pretrained_model: ./pretrained/ViT_large_patch16_224/latest
output_dir: ./output/
device: gpu
save_interval: 1
max_num_latest_checkpoint: 0
eval_during_train: True
eval_interval: 1
eval_unit: "epoch"
accum_steps: 1
epochs: 8
print_batch_step: 10
use_visualdl: False
seed: 2021

# FP16 setting
FP16:
level: O2
GradScaler:
init_loss_scaling: 65536.0

DistributedStrategy:
data_parallel: True

# model architecture
Model:
name: ViT_large_patch16_384
class_num: 1000
drop_rate: 0.1

# loss function config for traing/eval process
Loss:
Train:
- ViTCELoss:
type: softmax
weight: 1.0
Eval:
- CELoss:
weight: 1.0

LRScheduler:
name: ViTLRScheduler
learning_rate: 0.015
decay_type: cosine
warmup_steps: 500

Optimizer:
name: Momentum
momentum: 0.9
weight_decay: 0.0001
grad_clip:
name: ClipGradByGlobalNorm
clip_norm: 1.2

# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
class_num: 1000
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 384
scale: [0.05, 1.0]
interpolation: bilinear
backend: pil
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: ''
- ToCHWImage:

sampler:
name: DistributedBatchSampler
batch_size: 16 # total batchsize 512
drop_last: True
shuffle: True
loader:
num_workers: 8
use_shared_memory: True

Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 384
interpolation: bilinear
backend: pil
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: ''
- ToCHWImage:

sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: False
loader:
num_workers: 8
use_shared_memory: True

Metric:
Eval:
- TopkAcc:
topk: [1, 5]

Export:
export_type: paddle
input_shape: [None, 3, 384, 384]

0 comments on commit 170f8c0

Please sign in to comment.