diff --git a/docs/zh_CN/models/LUPerson/solider.md b/docs/zh_CN/models/LUPerson/solider.md
new file mode 100644
index 0000000000..3e50ca0224
--- /dev/null
+++ b/docs/zh_CN/models/LUPerson/solider.md
@@ -0,0 +1,27 @@
+# SOLIDER
+
+-----
+## Contents
+
+- [1. Introduction](#1)
+- [2. Alignment logs and models](#2)
+
+
+<a name="1"></a>
+## 1. Introduction
+
+SOLIDER is a semantic controllable self-supervised learning framework that learns general human representations from large amounts of unlabeled human images, so as to benefit downstream human-centric tasks as much as possible. Unlike existing self-supervised learning methods, it exploits prior knowledge contained in human images to build pseudo semantic labels and injects more semantic information into the learned representations. Meanwhile, different downstream tasks usually require different proportions of semantic and appearance information, and a single learned representation cannot satisfy all of these needs. To address this, SOLIDER introduces a conditional network with a semantic controller, which can adapt the representation to the requirements of different downstream tasks. [Paper link](https://arxiv.org/abs/2303.17602).
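+
+In this PR the semantic controller boils down to a learned, per-stage feature modulation: the per-sample weight w is expanded to the vector [w, 1 - w], mapped by two linear layers to a channel-wise scale and bias, and applied to each stage output. The sketch below illustrates the idea as implemented in `forward_features` of `ppcls/arch/backbone/variant_models/swin_transformer_variant.py`; the batch size, token count, channel width and the value 0.2 are illustrative only.
+
+```python
+import paddle
+import paddle.nn as nn
+
+# illustrative stage output: [batch, tokens, channels]
+x = paddle.randn([4, 3136, 96])
+# desired semantic ratio, one scalar per sample, expanded to [w, 1 - w]
+w = paddle.full([4, 1], 0.2)
+cond = paddle.concat([w, 1 - w], axis=-1)      # [4, 2] conditioning vector
+
+embed_w = nn.Linear(2, 96)                     # conditioning -> channel-wise scale
+embed_b = nn.Linear(2, 96)                     # conditioning -> channel-wise bias
+softplus = nn.Softplus()                       # keeps the scale positive
+
+scale = softplus(embed_w(cond)).unsqueeze(1)   # [4, 1, 96]
+bias = embed_b(cond).unsqueeze(1)              # [4, 1, 96]
+x = x * scale + bias                           # semantically modulated features
+```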
+
+
+<a name="2"></a>
+## 2. Alignment logs and models
+
+| model | weight | log |
+| ----------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
+| swin_tiny_patch4_window7_224 | https://paddleclas.bj.bcebos.com/models/SOLIDER/SwinTransformer_tiny_patch4_window7_224_pretrained.pdparams | link: https://pan.baidu.com/s/1W5zUFboMMhXETy4HEWbM3Q?pwd=45nx (extraction code: 45nx) |
+| swin_small_patch4_window7_224 | https://paddleclas.bj.bcebos.com/models/SOLIDER/SwinTransformer_small_patch4_window7_224_pretrained.pdparams | link: https://pan.baidu.com/s/1sqcUdfv6FyhW9_QgxBUPWA?pwd=letv (extraction code: letv) |
+| swin_base_patch4_window7_224 | https://paddleclas.bj.bcebos.com/models/SOLIDER/SwinTransformer_base_patch4_window7_224_pretrained.pdparams | link: https://pan.baidu.com/s/1S2TgDxDRa72C_3FrP8duiA?pwd=u3d2 (extraction code: u3d2) |
+
+[1]: The weights above are pretrained on the LUPerson dataset.
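+
+The converted weights can be loaded directly through the backbone constructors added in this PR; passing `pretrained=True` downloads the corresponding `.pdparams` file listed above. A minimal loading sketch (the dummy input and the final print are for illustration only):
+
+```python
+import paddle
+from ppcls.arch.backbone import SwinTransformer_tiny_patch4_window7_224_SOLIDER
+
+# build the SOLIDER tiny backbone and load the converted LUPerson weights
+model = SwinTransformer_tiny_patch4_window7_224_SOLIDER(pretrained=True)
+model.eval()
+
+# dummy 224x224 image, just to check that the forward pass runs
+x = paddle.randn([1, 3, 224, 224])
+with paddle.no_grad():
+    out = model(x)
+print(out.shape)  # [1, class_num], 1000 by default
+```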
+
+
diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py
index 435fc4389b..cb02226295 100644
--- a/ppcls/arch/backbone/__init__.py
+++ b/ppcls/arch/backbone/__init__.py
@@ -91,6 +91,7 @@
from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu
from .variant_models.efficientnet_variant import EfficientNetB3_watermark
from .variant_models.foundation_vit_variant import CLIP_large_patch14_224_aesthetic
+from .variant_models.swin_transformer_variant import SwinTransformer_tiny_patch4_window7_224_SOLIDER, SwinTransformer_small_patch4_window7_224_SOLIDER, SwinTransformer_base_patch4_window7_224_SOLIDER
from .model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200
from .model_zoo.wideresnet import WideResNet
from .model_zoo.uniformer import UniFormer_small, UniFormer_small_plus, UniFormer_small_plus_dim64, UniFormer_base, UniFormer_base_ls
diff --git a/ppcls/arch/backbone/legendary_models/swin_transformer.py b/ppcls/arch/backbone/legendary_models/swin_transformer.py
index 9d2c1b88dd..a177462d1e 100644
--- a/ppcls/arch/backbone/legendary_models/swin_transformer.py
+++ b/ppcls/arch/backbone/legendary_models/swin_transformer.py
@@ -359,12 +359,8 @@ def __init__(self,
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+ self.check_condition()
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
@@ -412,6 +408,13 @@ def __init__(self,
self.register_buffer("attn_mask", attn_mask)
+ def check_condition(self):
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in range [0, window_size)"
+
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
@@ -835,7 +838,8 @@ def _load_pretrained(pretrained,
model_url,
use_ssld=False,
use_imagenet22k_pretrained=False,
- use_imagenet22kto1k_pretrained=False):
+ use_imagenet22kto1k_pretrained=False,
+ **kwargs):
if pretrained is False:
pass
elif pretrained is True:
@@ -988,4 +992,4 @@ def SwinTransformer_large_patch4_window12_384(
use_ssld=use_ssld,
use_imagenet22k_pretrained=use_imagenet22k_pretrained,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
- return model
+ return model
\ No newline at end of file
diff --git a/ppcls/arch/backbone/variant_models/__init__.py b/ppcls/arch/backbone/variant_models/__init__.py
index 80f7a7e9fd..4a27162089 100644
--- a/ppcls/arch/backbone/variant_models/__init__.py
+++ b/ppcls/arch/backbone/variant_models/__init__.py
@@ -2,3 +2,4 @@
from .vgg_variant import VGG19Sigmoid
from .pp_lcnet_variant import PPLCNet_x2_5_Tanh
from .pp_lcnetv2_variant import PPLCNetV2_base_ShiTu
+from .swin_transformer_variant import SwinTransformer_base_patch4_window7_224_SOLIDER, SwinTransformer_small_patch4_window7_224_SOLIDER, SwinTransformer_tiny_patch4_window7_224_SOLIDER
diff --git a/ppcls/arch/backbone/variant_models/swin_transformer_variant.py b/ppcls/arch/backbone/variant_models/swin_transformer_variant.py
new file mode 100644
index 0000000000..1e6632f6e8
--- /dev/null
+++ b/ppcls/arch/backbone/variant_models/swin_transformer_variant.py
@@ -0,0 +1,355 @@
+import numpy as np
+import paddle
+import paddle.nn as nn
+from ..legendary_models.swin_transformer import SwinTransformer, _load_pretrained, \
+ PatchEmbed, BasicLayer, SwinTransformerBlock
+
+MODEL_URLS_SOLIDER = {
+ "SwinTransformer_tiny_patch4_window7_224_SOLIDER":
+ 'https://paddleclas.bj.bcebos.com/models/SOLIDER/SwinTransformer_tiny_patch4_window7_224_pretrained.pdparams',
+ "SwinTransformer_small_patch4_window7_224_SOLIDER":
+ 'https://paddleclas.bj.bcebos.com/models/SOLIDER/SwinTransformer_small_patch4_window7_224_pretrained.pdparams',
+ "SwinTransformer_base_patch4_window7_224_SOLIDER":
+ 'https://paddleclas.bj.bcebos.com/models/SOLIDER/SwinTransformer_base_patch4_window7_224_pretrained.pdparams'
+}
+
+__all__ = list(MODEL_URLS_SOLIDER.keys())
+
+
+class PatchEmbed_SOLIDER(PatchEmbed):
+ def forward(self, x):
+ x = self.proj(x)
+ out_size = (x.shape[2], x.shape[3])
+ x = x.flatten(2).transpose([0, 2, 1]) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x, out_size
+
+
+class SwinTransformerBlock_SOLIDER(SwinTransformerBlock):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Layer, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm):
+ super(SwinTransformerBlock_SOLIDER, self).__init__(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=shift_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ self.check_condition()
+
+ def check_condition(self):
+ if min(self.input_resolution) < self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in range [0, window_size)"
+
+
+class BasicLayer_SOLIDER(BasicLayer):
+ def __init__(self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False):
+
+ super(BasicLayer_SOLIDER, self).__init__(
+ dim=dim,
+ input_resolution=input_resolution,
+ depth=depth,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path,
+ norm_layer=norm_layer,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint
+ )
+ # build blocks
+ self.blocks = nn.LayerList([
+ SwinTransformerBlock_SOLIDER(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i]
+ if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer) for i in range(depth)
+ ])
+
+ def forward(self, x):
+ for blk in self.blocks:
+ x = blk(x)
+
+ if self.downsample is not None:
+ x_down = self.downsample(x)
+ return x_down, x
+ else:
+ return x, x
+
+
+class PatchMerging_SOLIDER(nn.Layer):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+        self.sampler = nn.Unfold(kernel_sizes=2, paddings=0, strides=2)
+ self.norm = norm_layer(4 * dim)
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, "x size ({}*{}) is not even.".format(
+            H, W)
+
+ x = x.reshape([B, H, W, C]).transpose([0, 3, 1, 2])
+
+ x = self.sampler(x)
+ x = x.transpose([0, 2, 1])
+ x = self.norm(x)
+ x = self.reduction(x)
+ return x
+
+
+class SwinTransformer_SOLIDER(SwinTransformer):
+ def __init__(self,
+ embed_dim=96,
+ img_size=224,
+ patch_size=4,
+ in_chans=3,
+ class_num=1000,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ out_indices=(0, 1, 2, 3),
+ semantic_weight=1.0,
+ use_checkpoint=False,
+ **kwargs):
+ super(SwinTransformer_SOLIDER, self).__init__()
+ self.num_classes = num_classes = class_num
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+ # stochastic depth
+ dpr = np.linspace(0, drop_path_rate,
+ sum(depths)).tolist() # stochastic depth decay rule
+ self.patch_embed = PatchEmbed_SOLIDER(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+        patches_resolution = self.patch_embed.patches_resolution
+ self.out_indices = out_indices
+ # build layers
+ self.layers = nn.LayerList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer_SOLIDER(
+ dim=int(embed_dim * 2 ** i_layer),
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
+ patches_resolution[1] // (2 ** i_layer)),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging_SOLIDER
+ if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ self.num_features_s = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ for i in out_indices:
+ layer = norm_layer(self.num_features_s[i])
+ layer_name = f'norm{i}'
+ self.add_sublayer(layer_name, layer)
+ self.avgpool = nn.AdaptiveAvgPool2D(1)
+
+ # semantic embedding
+ self.semantic_weight = semantic_weight
+ if self.semantic_weight >= 0:
+ self.semantic_embed_w = nn.LayerList()
+ self.semantic_embed_b = nn.LayerList()
+ for i in range(len(depths)):
+ if i >= len(depths) - 1:
+ i = len(depths) - 2
+ semantic_embed_w = nn.Linear(2, self.num_features_s[i + 1])
+ semantic_embed_b = nn.Linear(2, self.num_features_s[i + 1])
+ self._init_weights(semantic_embed_w)
+ self._init_weights(semantic_embed_b)
+ self.semantic_embed_w.append(semantic_embed_w)
+ self.semantic_embed_b.append(semantic_embed_b)
+ self.softplus = nn.Softplus()
+ self.head = nn.Linear(
+ self.num_features,
+ num_classes) if self.num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x, semantic_weight=None):
+ if self.semantic_weight >= 0 and semantic_weight is None:
+ w = paddle.ones((x.shape[0], 1)) * self.semantic_weight
+ w = paddle.concat([w, 1 - w], axis=-1)
+            semantic_weight = w  # keep on the default device so CPU inference also works
+ x, hw_shape = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+ outs = []
+
+ for i, layer in enumerate(self.layers):
+ x, out = layer(x)
+ if self.semantic_weight >= 0:
+ sw = self.semantic_embed_w[i](semantic_weight).unsqueeze(1)
+ sb = self.semantic_embed_b[i](semantic_weight).unsqueeze(1)
+ x = x * self.softplus(sw) + sb
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ out = norm_layer(out)
+ out = out.reshape([-1, *hw_shape,
+ self.num_features_s[i]]).transpose([0, 3, 1, 2])
+ hw_shape = [item // 2 for item in hw_shape]
+ outs.append(out)
+
+        x = self.avgpool(outs[-1])  # B C 1 1
+ x = paddle.flatten(x, 1)
+
+ return x
+
+
+def SwinTransformer_tiny_patch4_window7_224_SOLIDER(
+ pretrained=False,
+ **kwargs):
+ model = SwinTransformer_SOLIDER(
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ drop_path_rate=0.2, # if imagenet22k or imagenet22kto1k, set drop_path_rate=0.1
+ **kwargs)
+ _load_pretrained(
+ pretrained,
+ model=model,
+ model_url=MODEL_URLS_SOLIDER["SwinTransformer_tiny_patch4_window7_224_SOLIDER"],
+ **kwargs)
+ return model
+
+
+def SwinTransformer_small_patch4_window7_224_SOLIDER(
+ pretrained=False,
+ **kwargs):
+ model = SwinTransformer_SOLIDER(
+ embed_dim=96,
+ depths=[2, 2, 18, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ drop_path_rate=0.3, # if imagenet22k or imagenet22kto1k, set drop_path_rate=0.2
+ **kwargs)
+ _load_pretrained(
+ pretrained,
+ model=model,
+ model_url=MODEL_URLS_SOLIDER["SwinTransformer_small_patch4_window7_224_SOLIDER"],
+ **kwargs)
+ return model
+
+
+def SwinTransformer_base_patch4_window7_224_SOLIDER(
+ pretrained=False,
+ **kwargs):
+ model = SwinTransformer_SOLIDER(
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=7,
+ drop_path_rate=0.5, # if imagenet22k or imagenet22kto1k, set drop_path_rate=0.2
+ **kwargs)
+ _load_pretrained(
+ pretrained,
+ model=model,
+ model_url=MODEL_URLS_SOLIDER["SwinTransformer_base_patch4_window7_224_SOLIDER"],
+ **kwargs)
+ return model