Skip to content

Commit

Permalink
fix conflict
Browse files Browse the repository at this point in the history
Remove the conflicting checkpoint-related function.
  • Loading branch information
psky1111 committed Oct 18, 2023
1 parent 63c3e52 commit d7a7dad
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 33 deletions.
34 changes: 3 additions & 31 deletions ppcls/arch/backbone/legendary_models/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ def __init__(self,
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio

self.check_condition()
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
Expand All @@ -407,39 +406,12 @@ def __init__(self,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
"""
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = paddle.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.reshape(
[-1, self.window_size * self.window_size])
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)

huns = -100.0 * paddle.ones_like(attn_mask)
attn_mask = huns * (attn_mask != 0).astype("float32")
else:
attn_mask = None
"""
H, W = self.input_resolution
attn_mask = paddle.zeros([1, H, W, 1])

self.register_buffer("attn_mask", attn_mask)

def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0:
# calculate attention mask for shifted window multihead self attention
Expand Down Expand Up @@ -467,7 +439,7 @@ def get_attn_mask(self, height, width, dtype):
else:
attn_mask = None
return attn_mask

def forward(self, x, input_dimensions):
H, W = input_dimensions
B, L, C = x.shape
Expand Down Expand Up @@ -1051,4 +1023,4 @@ def SwinTransformer_large_patch4_window12_384(
use_ssld=use_ssld,
use_imagenet22k_pretrained=use_imagenet22k_pretrained,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
return model
return model
2 changes: 0 additions & 2 deletions ppcls/loss/afdloss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import paddle.nn.functional as F
import paddle
import numpy as np
import matplotlib.pyplot as plt
import cv2
import warnings
warnings.filterwarnings('ignore')

Expand Down

0 comments on commit d7a7dad

Please sign in to comment.