Skip to content

Commit

Permalink
Revert "Revert "fix resolution problem for swin transformer and clip vit (PaddlePaddle#3021)""
Browse files Browse the repository at this point in the history

This reverts commit 174db43.
  • Loading branch information
psky1111 committed Nov 3, 2023
1 parent 174db43 commit 13a974d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
12 changes: 9 additions & 3 deletions ppcls/arch/backbone/legendary_models/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def pading_for_not_divisible(pixel_values,
function="split"):
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
if height % patch_size[0] == 0 and width % patch_size[1] == 0:
return pixel_values, (0, 0, 0, 0, 0, 0, 0, 0)
if function == "split":
pading_width = patch_size[1] - width % patch_size[1]
pading_height = patch_size[0] - height % patch_size[0]
Expand Down Expand Up @@ -407,7 +409,7 @@ def __init__(self,
act_layer=act_layer,
drop=drop)
H, W = self.input_resolution
attn_mask = paddle.zeros([1, H, W, 1])
attn_mask = None

self.register_buffer("attn_mask", attn_mask)

Expand Down Expand Up @@ -450,6 +452,9 @@ def forward(self, x, input_dimensions):
x, pad_values = pading_for_not_divisible(x, H, W, self.window_size,
"BHWC")
_, height_pad, width_pad, _ = x.shape

padding_state = pad_values[3] > 0 or pad_values[
5] > 0 # change variable name
# cyclic shift
if self.shift_size > 0:
shifted_x = RollWrapper.roll(
Expand All @@ -465,7 +470,9 @@ def forward(self, x, input_dimensions):
C]) # nW*B, window_size*window_size, C

# W-MSA/SW-MSA
# TODO: check whether the attention mask needs to be calculated again here
attn_mask = self.get_attn_mask(height_pad, width_pad, x.dtype)

attn_windows = self.attn(
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C

Expand All @@ -484,8 +491,7 @@ def forward(self, x, input_dimensions):
else:
x = shifted_x

was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
if padding_state:
x = x[:, :H, :W, :]
x = x.reshape([B, H * W, C])

Expand Down
2 changes: 2 additions & 0 deletions ppcls/arch/backbone/model_zoo/foundation_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def pading_for_not_divisible(pixel_values,
function="split"):
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
if height % patch_size[0] == 0 and width % patch_size[1] == 0:
return pixel_values, None
if function == "split":
pading_width = patch_size[1] - width % patch_size[1]
pading_height = patch_size[0] - height % patch_size[0]
Expand Down

0 comments on commit 13a974d

Please sign in to comment.