
[Feature] Add MaskFormer #2789

Merged (35 commits, Mar 1, 2023)
Commits (35):
- 4790c5d scratch_1 (shiyutang, Nov 8, 2022)
- 41fcd07 scratch_2 (shiyutang, Nov 10, 2022)
- 35ba0d4 forward_align_finished (shiyutang, Nov 11, 2022)
- d2bf25b update (shiyutang, Nov 15, 2022)
- 933b367 backwardaligned_dot2 (shiyutang, Nov 16, 2022)
- 360f1db backward_blankin (shiyutang, Nov 16, 2022)
- db3a3e2 train_divergy (shiyutang, Nov 17, 2022)
- 7a90e9b fix_optimizer (shiyutang, Nov 18, 2022)
- 0abed97 fix_metric (shiyutang, Nov 21, 2022)
- 0e466d4 align_train_load+metric467 (shiyutang, Nov 21, 2022)
- 080bac2 newest (shiyutang, Nov 22, 2022)
- 7de03d8 paddle2.2.2 (shiyutang, Nov 23, 2022)
- cc31da5 update_init_aug (shiyutang, Nov 29, 2022)
- 65d36ee rm_redundant (shiyutang, Nov 30, 2022)
- a75a0c3 update (shiyutang, Nov 30, 2022)
- ea94836 update (shiyutang, Nov 30, 2022)
- e05f390 update (shiyutang, Dec 1, 2022)
- 6e25d85 update_init (shiyutang, Dec 1, 2022)
- 24f8aed train_with2.4and new aug (shiyutang, Dec 1, 2022)
- 4847dc7 validate_on2.4 (shiyutang, Dec 5, 2022)
- a3d151a update_init (shiyutang, Dec 13, 2022)
- d6fad22 add_multiple_backbone+initfree (shiyutang, Dec 19, 2022)
- 8738173 Merge branch 'develop' into paddle2.2.2 (shiyutang, Jan 12, 2023)
- 14933b9 fix_format (shiyutang, Jan 12, 2023)
- 72c42c2 Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSeg i… (shiyutang, Jan 12, 2023)
- 3d54334 fix_by_comment_test_train_ok_no_acc (shiyutang, Jan 13, 2023)
- b86a511 Merge branch 'paddle2.2.2' of https://github.com/shiyutang/PaddleSeg … (shiyutang, Jan 13, 2023)
- dbc43c0 valid_47.6 (shiyutang, Jan 13, 2023)
- d7479be validate_train_47.93 (shiyutang, Jan 19, 2023)
- c403cad valid_train_small_50.4 (shiyutang, Jan 28, 2023)
- f658d73 fix_by_comment (shiyutang, Jan 30, 2023)
- 85c36a3 maskformer_tipc (shiyutang, Jan 31, 2023)
- 094d1a2 fix_large_yml (shiyutang, Feb 1, 2023)
- f16a816 compact_train_loss (shiyutang, Mar 1, 2023)
- fb9668a Merge branch 'develop' into paddle2.2.2 (shiyutang, Mar 1, 2023)
configs/_base_/ade20k.yml (1 change: 0 additions, 1 deletion)
@@ -26,7 +26,6 @@ val_dataset:
- type: Normalize
mode: val


optimizer:
type: SGD
momentum: 0.9
configs/maskformer/README.md (17 changes: 17 additions, 0 deletions)
@@ -0,0 +1,17 @@
# Per-Pixel Classification is Not All You Need for Semantic Segmentation

## Reference

> Cheng, Bowen, Alex Schwing, and Alexander Kirillov. "Per-pixel classification is not all you need for semantic segmentation." Advances in Neural Information Processing Systems 34 (2021): 17864-17875.

## Performance

### ADE20k

| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|Maskformer|SwinTransformer|512x512|160000|47.6|-|-|[model](https://bj.bcebos.com/paddleseg/dygraph/ade20k/maskformer_ade20k_swin_tiny/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/ade20k/maskformer_ade20k_swin_tiny/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=e59773eaad87f677837add5ff110441e)|

* MaskFormer supports multiple backbones, including the tiny, small, base, and large variants of Swin Transformer. Because training takes a long time, accuracy results are not provided for every variant.

* MaskFormer-Base and MaskFormer-Large should be evaluated with multi-scale and flip augmentation by default.
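The multi-scale + flip evaluation mentioned above can be sketched as follows. This is a simplified numpy illustration, not PaddleSeg's actual implementation; `predict_fn` is a hypothetical callable that maps a `(C_in, H, W)` image to a `(C, H, W)` class-probability map, and resizing is done with nearest-neighbour indexing for brevity.

```python
import numpy as np

def ms_flip_predict(predict_fn, image, scales=(0.75, 1.0, 1.25), flip=True):
    """Average class-probability maps over several scales and a horizontal flip.

    predict_fn(image) -> (C, H, W) probability map at the input resolution.
    Nearest-neighbour rescaling keeps the sketch dependency-free.
    """
    c = predict_fn(image).shape[0]
    h, w = image.shape[-2], image.shape[-1]
    acc = np.zeros((c, h, w), dtype=np.float64)
    n = 0
    for s in scales:
        sh, sw = max(1, int(h * s)), max(1, int(w * s))
        # nearest-neighbour resize of the image to the current scale
        ys = np.arange(sh) * h // sh
        xs = np.arange(sw) * w // sw
        scaled = image[..., ys[:, None], xs[None, :]]
        for img in (scaled, scaled[..., ::-1]) if flip else (scaled,):
            prob = predict_fn(img)
            if img is not scaled:  # un-flip the flipped prediction
                prob = prob[..., ::-1]
            # resize the probability map back to the original (h, w)
            yb = np.arange(h) * sh // h
            xb = np.arange(w) * sw // w
            acc += prob[:, yb[:, None], xb[None, :]]
            n += 1
    return acc / n
```

In PaddleSeg itself this behaviour is typically enabled through the evaluation script's augmented-evaluation options rather than hand-rolled as above.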
configs/maskformer/maskformer_ade20k_swin_base.yml (73 changes: 73 additions, 0 deletions)
@@ -0,0 +1,73 @@
batch_size: 4
iters: 160000

train_dataset:
type: ADE20K
dataset_root: data/ADEChallengeData2016/
transforms:
- type: ResizeByShort
short_size: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280, 1344]
max_size: 2560
- type: RandomPaddingCrop
crop_size: [640, 640]
- type: RandomDistort
brightness_range: 0.125
brightness_prob: 1.0
contrast_range: 0.5
contrast_prob: 1.0
saturation_range: 0.5
saturation_prob: 1.0
hue_range: 18
hue_prob: 1.0
- type: RandomHorizontalFlip
to_mask: True
size_divisibility: 640
normalize: True


val_dataset:
type: ADE20K
dataset_root: data/ADEChallengeData2016/
transforms:
- type: ResizeByShort
short_size: 512
mode: val
to_mask: True
normalize: True

model:
type: MaskFormer
num_classes: 150
backbone:
type: SwinTransformer_base_patch4_window7_384_maskformer
pretrained: https://bj.bcebos.com/paddleseg/paddleseg/dygraph/ade20k/maskformer_ade20k_swin_base/pretrain/model.pdparams

optimizer:
type: AdamW
backbone_lr_mult: 1.0
weight_decay: 0.01

gradient_clipper:
enabled: True
clip_value: 0.01

lr_scheduler:
type: PolynomialDecay
warmup_iters: 1500
warmup_start_lr: 6.0e-11
learning_rate: 6.0e-05
end_lr: 0
power: 0.9

loss:
types:
- type: MaskFormerLoss
num_classes: 150
eos_coef: 0.1
coef: [1]

export:
transforms:
- type: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
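The `lr_scheduler` section above combines a 1500-iteration linear warmup with polynomial decay. A sketch of the resulting learning-rate curve is below; exactly how Paddle's scheduler composes the warmup with the decay horizon is an implementation detail, so treat this as an approximation of the config's intent rather than Paddle's code.

```python
def maskformer_lr(step, base_lr=6.0e-05, end_lr=0.0, power=0.9,
                  total_iters=160000, warmup_iters=1500,
                  warmup_start_lr=6.0e-11):
    """Linear warmup to base_lr, then polynomial decay toward end_lr."""
    if step < warmup_iters:
        # linear interpolation from warmup_start_lr to base_lr
        t = step / warmup_iters
        return warmup_start_lr + (base_lr - warmup_start_lr) * t
    # polynomial decay over the full training horizon
    frac = min(step, total_iters) / total_iters
    return (base_lr - end_lr) * (1 - frac) ** power + end_lr
```

With these config values the rate climbs from 6.0e-11 to roughly 6.0e-05 over the first 1500 iterations, then decays smoothly to 0 at iteration 160000.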
configs/maskformer/maskformer_ade20k_swin_large.yml (73 changes: 73 additions, 0 deletions)
@@ -0,0 +1,73 @@
batch_size: 4
iters: 160000

train_dataset:
type: ADE20K
dataset_root: data/ADEChallengeData2016/
transforms:
- type: ResizeByShort
short_size: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280, 1344]
max_size: 2560
- type: RandomPaddingCrop
crop_size: [640, 640]
- type: RandomDistort
brightness_range: 0.125
brightness_prob: 1.0
contrast_range: 0.5
contrast_prob: 1.0
saturation_range: 0.5
saturation_prob: 1.0
hue_range: 18
hue_prob: 1.0
- type: RandomHorizontalFlip
to_mask: True
size_divisibility: 640
normalize: True


val_dataset:
type: ADE20K
dataset_root: data/ADEChallengeData2016/
transforms:
- type: ResizeByShort
short_size: 512
mode: val
to_mask: True
normalize: True

model:
type: MaskFormer
num_classes: 150
backbone:
type: SwinTransformer_large_patch4_window7_384_maskformer
pretrained: https://bj.bcebos.com/paddleseg/paddleseg/dygraph/ade20k/maskformer_ade20k_swin_large/pretrain/model.pdparams

optimizer:
type: AdamW
backbone_lr_mult: 1.0
weight_decay: 0.01

gradient_clipper:
enabled: True
clip_value: 0.01

lr_scheduler:
type: PolynomialDecay
warmup_iters: 1500
warmup_start_lr: 6.0e-11
learning_rate: 6.0e-05
end_lr: 0
power: 0.9

loss:
types:
- type: MaskFormerLoss
num_classes: 150
eos_coef: 0.1
coef: [1]

export:
transforms:
- type: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
configs/maskformer/maskformer_ade20k_swin_small.yml (73 changes: 73 additions, 0 deletions)
@@ -0,0 +1,73 @@
batch_size: 4
iters: 160000

train_dataset:
type: ADE20K
dataset_root: data/ADEChallengeData2016/
transforms:
- type: ResizeByShort
short_size: [256, 307, 358, 409, 460, 512, 563, 614, 665, 716, 768, 819, 870, 921, 972, 1024]
max_size: 2048
- type: RandomPaddingCrop
crop_size: [512, 512]
- type: RandomDistort
brightness_range: 0.125
brightness_prob: 1.0
contrast_range: 0.5
contrast_prob: 1.0
saturation_range: 0.5
saturation_prob: 1.0
hue_range: 18
hue_prob: 1.0
- type: RandomHorizontalFlip
to_mask: True
size_divisibility: 512
normalize: True


val_dataset:
type: ADE20K
dataset_root: data/ADEChallengeData2016/
transforms:
- type: ResizeByShort
short_size: 512
mode: val
to_mask: True
normalize: True

model:
type: MaskFormer
num_classes: 150
backbone:
type: SwinTransformer_small_patch4_window7_224_maskformer
pretrained: https://bj.bcebos.com/paddleseg/paddleseg/dygraph/ade20k/maskformer_ade20k_swin_small/pretrain/model.pdparams

optimizer:
type: AdamW
backbone_lr_mult: 1.0
weight_decay: 0.01

gradient_clipper:
enabled: True
clip_value: 0.01

lr_scheduler:
type: PolynomialDecay
warmup_iters: 1500
warmup_start_lr: 6.0e-11
learning_rate: 6.0e-05
end_lr: 0
power: 0.9

loss:
types:
- type: MaskFormerLoss
num_classes: 150
eos_coef: 0.1
coef: [1]

export:
transforms:
- type: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
configs/maskformer/maskformer_ade20k_swin_tiny.yml (73 changes: 73 additions, 0 deletions)
@@ -0,0 +1,73 @@
batch_size: 4
iters: 160000

train_dataset:
type: ADE20K
dataset_root: data/ADEChallengeData2016/
transforms:
- type: ResizeByShort
short_size: [256, 307, 358, 409, 460, 512, 563, 614, 665, 716, 768, 819, 870, 921, 972, 1024]
max_size: 2048
- type: RandomPaddingCrop
crop_size: [512, 512]
- type: RandomDistort
brightness_range: 0.125
brightness_prob: 1.0
contrast_range: 0.5
contrast_prob: 1.0
saturation_range: 0.5
saturation_prob: 1.0
hue_range: 18
hue_prob: 1.0
- type: RandomHorizontalFlip
to_mask: True
size_divisibility: 512
normalize: True


val_dataset:
type: ADE20K
dataset_root: data/ADEChallengeData2016/
transforms:
- type: ResizeByShort
short_size: 512
mode: val
to_mask: True
normalize: True

model:
type: MaskFormer
num_classes: 150
backbone:
type: SwinTransformer_tiny_patch4_window7_224_maskformer
pretrained: https://bj.bcebos.com/paddleseg/paddleseg/dygraph/ade20k/maskformer_ade20k_swin_tiny/pretrain/model.pdparams

optimizer:
type: AdamW
backbone_lr_mult: 1.0
weight_decay: 0.01

gradient_clipper:
enabled: True
clip_value: 0.01

lr_scheduler:
type: PolynomialDecay
warmup_iters: 1500
warmup_start_lr: 6.0e-11
learning_rate: 6.0e-05
end_lr: 0
power: 0.9

loss:
types:
- type: MaskFormerLoss
num_classes: 150
eos_coef: 0.1
coef: [1]

export:
transforms:
- type: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
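Each config above enables a `gradient_clipper` with `clip_value: 0.01`, which suggests element-wise clipping of every gradient into `[-0.01, 0.01]` before the optimizer step. The numpy sketch below illustrates that behaviour; the actual Paddle mechanism backing this config key (e.g. a clip-by-value gradient clipper attached to the optimizer) is an assumption here.

```python
import numpy as np

def clip_grads_by_value(grads, clip_value=0.01):
    """Element-wise clip of each gradient tensor to [-clip_value, clip_value]."""
    return [np.clip(g, -clip_value, clip_value) for g in grads]
```

Value clipping like this caps individual gradient entries, which helps stabilize transformer training early on, at the cost of distorting the gradient direction when many entries saturate.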
paddleseg/core/train.py (33 changes: 29 additions, 4 deletions)
@@ -34,7 +34,7 @@ def check_logits_losses(logits_list, losses):
.format(len_logits, len_losses))


def loss_computation(logits_list, labels, edges, losses):
def loss_computation(logits_list, labels, edges, losses, targets=None):
check_logits_losses(logits_list, losses)
loss_list = []
for i in range(len(logits_list)):
@@ -52,7 +52,10 @@ def loss_computation(logits_list, labels, edges, losses):
loss_list.append(coef_i *
loss_i(logits_list[0], logits_list[1].detach()))
else:
loss_list.append(coef_i * loss_i(logits, labels))
                if targets is not None:
                    loss_list.append(coef_i * loss_i(logits, targets))
                else:
                    loss_list.append(coef_i * loss_i(logits, labels))
    return loss_list

Review thread on this change:

> Collaborator: Let me think a bit more about the implementation in train.py.

> Collaborator: A new paradigm for loss computation will be released soon; model-specific loss computation will be implemented inside the model. Please revise this part after that PR is merged.

> Collaborator (author): ok~


@@ -176,6 +179,26 @@ def train(model,
edges = None
if 'edge' in data.keys():
edges = data['edge'].astype('int64')

                targets = None
                if "instances" in data.keys():
                    targets = []
                    # split targets in the batch
                    for target_per_image_idx in range(batch_size):
                        gt_masks = data['instances']['gt_masks'][
                            target_per_image_idx, ...]
                        padded_masks = paddle.zeros(
                            (gt_masks.shape[0], gt_masks.shape[-2],
                             gt_masks.shape[-1]),
                            dtype=gt_masks.dtype)
                        padded_masks[:, :gt_masks.shape[1], :gt_masks.shape[
                            2]] = gt_masks
                        targets.append({
                            "labels": data['instances']['gt_classes'][
                                target_per_image_idx, ...],
                            "masks": padded_masks
                        })

Review thread on this change:

> Collaborator: This approach really will not do; it is too hacky. Can this data processing be moved into the dataset instead?

> Collaborator (author): This processes an entire batch of data, so it cannot go into the dataset; it has been moved into the loss.
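The padding step in the loop above copies each image's variable-sized instance masks into a fixed zero-filled canvas so that all targets in a batch share one spatial shape. A standalone numpy equivalent (the function name is hypothetical, chosen only for this illustration):

```python
import numpy as np

def pad_instance_masks(gt_masks, out_h, out_w):
    """Zero-pad (N, h, w) instance masks to (N, out_h, out_w), top-left aligned."""
    n, h, w = gt_masks.shape
    padded = np.zeros((n, out_h, out_w), dtype=gt_masks.dtype)
    padded[:, :h, :w] = gt_masks
    return padded
```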

if hasattr(model, 'data_format') and model.data_format == 'NHWC':
images = images.transpose((0, 2, 3, 1))

@@ -193,7 +216,8 @@ def train(model,
logits_list=logits_list,
labels=labels,
edges=edges,
losses=losses)
losses=losses,
targets=targets)
loss = sum(loss_list)

scaled = scaler.scale(loss) # scale the loss
@@ -208,7 +232,8 @@ def train(model,
logits_list=logits_list,
labels=labels,
edges=edges,
losses=losses)
losses=losses,
targets=targets)
loss = sum(loss_list)
loss.backward()
optimizer.step()