-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Add MaskFormer #2789
[Feature] Add MaskFormer #2789
Conversation
paddleseg/transforms/transforms.py
Outdated
@@ -245,11 +245,18 @@ class ResizeByShort: | |||
short_size (int): The target size of short side. | |||
""" | |||
|
|||
def __init__(self, short_size): | |||
def __init__(self, short_size, max_size=2048): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是ResizeByShort的标准用法? 会改变原先ResizeByShort的逻辑。 如果不是标准用法,可以自己实现一个ResizeByShortWithMax
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
设置的max_size很大,就不会影响之前的逻辑。
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果参考的比较多,写上参考实现的链接
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加个空行吧, 和copyright混合了
paddleseg/core/train.py
Outdated
loss_list.append(coef_i * loss_i(logits, labels)) | ||
if targets is not None: | ||
loss_list.append(coef_i * loss_i(logits, targets)) | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
train.py里面的实现,我再想想。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
关于loss计算,马上会更新一个范式,会放到model里面实现特定的loss计算。所以这部分等那个pr合入后再改下。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok~
paddleseg/datasets/ade.py
Outdated
data['label'] = label | ||
return data | ||
data['label'] = label - 1 | ||
data['img'] = paddle.to_tensor(data['img']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修改为和之前一致
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此外也要添加tipc和验证预测模型
paddleseg/utils/utils.py
Outdated
@@ -138,6 +138,7 @@ def download_pretrained_model(pretrained_model): | |||
pretrained_model = download_file_and_uncompress( | |||
pretrained_model, | |||
savepath=_dir, | |||
cover=True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个会覆盖原先下载的预训练权重等,导致每次都下载,很费时
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果不添加cover,如果之前用户本地有model.params就会加载之前的,但是那个模型参数可能并不是这个模型的,就存在加载不上参数的问题,导致精度差异不符合预期。
hue_range: 18 | ||
hue_prob: 1.0 | ||
- type: RandomHorizontalFlip | ||
- type: Padding |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
前面已经有RandomPaddingCrop,不需要padding了吧。
如果四个yml文件只是model不一样,一个yml文件作为base,其中三个yml文件可以使__base__包含。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除,除了model还有训练数据处理不一样,目前已经修改为最大化继承。
num_classes: 150 | ||
backbone: | ||
type: SwinTransformer_base_patch4_window7_384_maskformer | ||
pretrained: https://bj.bcebos.com/paddleseg/paddleseg/dygraph/ade20k/maskformer_ade20k_swin_base/pretrain/model.pdparams |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pretrained还是针对整个maskformer模型,不是针对backbone啊
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不是,这个是只有swin的模型,训练日志里加载的参数数量也可以看出来
paddleseg/cvlibs/param_init.py
Outdated
|
||
|
||
def th_multihead_fill(layer, qkv_same_embed_dim=True): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个th开头的初始化函数,是所有模型都通用的还是maskformer专用的?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是可以用于所有模型里multihead attention算子的初始化
paddleseg/core/train.py
Outdated
@@ -176,6 +179,26 @@ def train(model, | |||
edges = None | |||
if 'edge' in data.keys(): | |||
edges = data['edge'].astype('int64') | |||
|
|||
targets = None | |||
if "instances" in data.keys(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个写法肯定是不行的,过于trick。 这个数据处理不可以放到dataset里面吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是对一个batch的数据进行处理,从而不能放入dataset,已经放入到损失当中。
paddleseg/core/train.py
Outdated
loss_list.append(coef_i * loss_i(logits, labels)) | ||
if targets is not None: | ||
loss_list.append(coef_i * loss_i(logits, targets)) | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
关于loss计算,马上会更新一个范式,会放到model里面实现特定的loss计算。所以这部分等那个pr合入后再改下。
paddleseg/datasets/maskedade.py
Outdated
import paddle.nn as nn | ||
|
||
from paddleseg.datasets import ADE20K | ||
from paddleseg.utils.download import download_file_and_uncompress |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
download_file_and_uncompress 没用到
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddleseg/datasets/maskedade.py
Outdated
@manager.DATASETS.add_component | ||
class MaskedADE20K(ADE20K): | ||
""" | ||
ADE20K dataset `http://sceneparsing.csail.mit.edu/`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
说明一下这个MaskedADE20K具体是什么
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
Models
Description
Add Maskformer.
目前已经使用PaddleSeg中已有的transforms,并在paddle2.4下环境测试精度mIOU为47.93%,满足复现目标46.7%:
在small 配置上验证,训练精度和评估精度均可以达到50.4,超过论文中精度的49.8:
完成tipc部署: