diff --git a/libs/spandrel/spandrel/architectures/HAT/arch/HAT.py b/libs/spandrel/spandrel/architectures/HAT/arch/HAT.py
index 1febea39..51982b00 100644
--- a/libs/spandrel/spandrel/architectures/HAT/arch/HAT.py
+++ b/libs/spandrel/spandrel/architectures/HAT/arch/HAT.py
@@ -376,46 +376,6 @@ def forward(self, x, x_size, rpi_sa, attn_mask):
         return x
 
 
-class PatchMerging(nn.Module):
-    r"""Patch Merging Layer.
-
-    Args:
-        input_resolution (tuple[int]): Resolution of input feature.
-        dim (int): Number of input channels.
-        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
-    """
-
-    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
-        super().__init__()
-        self.input_resolution = input_resolution
-        self.dim = dim
-        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
-        self.norm = norm_layer(4 * dim)
-
-    def forward(self, x):
-        """
-        x: b, h*w, c
-        """
-        h, w = self.input_resolution
-        b, seq_len, c = x.shape
-        assert seq_len == h * w, "input feature has wrong size"
-        assert h % 2 == 0 and w % 2 == 0, f"x size ({h}*{w}) are not even."
-
-        x = x.view(b, h, w, c)
-
-        x0 = x[:, 0::2, 0::2, :]  # b h/2 w/2 c
-        x1 = x[:, 1::2, 0::2, :]  # b h/2 w/2 c
-        x2 = x[:, 0::2, 1::2, :]  # b h/2 w/2 c
-        x3 = x[:, 1::2, 1::2, :]  # b h/2 w/2 c
-        x = torch.cat([x0, x1, x2, x3], -1)  # b h/2 w/2 4*c
-        x = x.view(b, -1, 4 * c)  # b h/2*w/2 4*c
-
-        x = self.norm(x)
-        x = self.reduction(x)
-
-        return x
-
-
 class OCAB(nn.Module):
     # overlapping cross-attention block
 
@@ -1132,14 +1092,6 @@ def calculate_mask(self, x_size):
 
         return attn_mask
 
-    @torch.jit.ignore  # type: ignore
-    def no_weight_decay(self):
-        return {"absolute_pos_embed"}
-
-    @torch.jit.ignore  # type: ignore
-    def no_weight_decay_keywords(self):
-        return {"relative_position_bias_table"}
-
     def check_image_size(self, x):
         return pad_to_multiple(x, self.window_size, mode="reflect")
 