diff --git a/ppcls/arch/backbone/legendary_models/swin_transformer.py b/ppcls/arch/backbone/legendary_models/swin_transformer.py index f49f6c6627..4534df6f43 100644 --- a/ppcls/arch/backbone/legendary_models/swin_transformer.py +++ b/ppcls/arch/backbone/legendary_models/swin_transformer.py @@ -147,6 +147,8 @@ def pading_for_not_divisible(pixel_values, function="split"): if isinstance(patch_size, int): patch_size = (patch_size, patch_size) + if height % patch_size[0] == 0 and width % patch_size[1] == 0: + return pixel_values, (0, 0, 0, 0, 0, 0, 0, 0) if function == "split": pading_width = patch_size[1] - width % patch_size[1] pading_height = patch_size[0] - height % patch_size[0] @@ -407,7 +409,7 @@ def __init__(self, act_layer=act_layer, drop=drop) H, W = self.input_resolution - attn_mask = paddle.zeros([1, H, W, 1]) + attn_mask = None self.register_buffer("attn_mask", attn_mask) @@ -450,6 +452,9 @@ def forward(self, x, input_dimensions): x, pad_values = pading_for_not_divisible(x, H, W, self.window_size, "BHWC") _, height_pad, width_pad, _ = x.shape + + padding_state = pad_values[3] > 0 or pad_values[ + 5] > 0 # change variable name # cyclic shift if self.shift_size > 0: shifted_x = RollWrapper.roll( @@ -465,7 +470,9 @@ def forward(self, x, input_dimensions): C]) # nW*B, window_size*window_size, C # W-MSA/SW-MSA + #check did it need to calculate again attn_mask = self.get_attn_mask(height_pad, width_pad, x.dtype) + attn_windows = self.attn( x_windows, mask=attn_mask) # nW*B, window_size*window_size, C @@ -484,8 +491,7 @@ def forward(self, x, input_dimensions): else: x = shifted_x - was_padded = pad_values[3] > 0 or pad_values[5] > 0 - if was_padded: + if padding_state: x = x[:, :H, :W, :] x = x.reshape([B, H * W, C]) diff --git a/ppcls/arch/backbone/model_zoo/foundation_vit.py b/ppcls/arch/backbone/model_zoo/foundation_vit.py index 46a9e517a0..bddd0d68be 100644 --- a/ppcls/arch/backbone/model_zoo/foundation_vit.py +++ b/ppcls/arch/backbone/model_zoo/foundation_vit.py @@ -114,6 +114,8 @@ def pading_for_not_divisible(pixel_values, function="split"): if isinstance(patch_size, int): patch_size = (patch_size, patch_size) + if height % patch_size[0] == 0 and width % patch_size[1] == 0: + return pixel_values, None if function == "split": pading_width = patch_size[1] - width % patch_size[1] pading_height = patch_size[0] - height % patch_size[0]