diff --git a/README.md b/README.md index 3c3b292c..5e7b8623 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar - [KBNet](https://github.com/zhangyi-3/KBNet) | [Models](https://mycuhk-my.sharepoint.com/personal/1155135732_link_cuhk_edu_hk/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F1155135732%5Flink%5Fcuhk%5Fedu%5Fhk%2FDocuments%2Fshare%2FKBNet%2FDenoising%2Fpretrained%5Fmodels) - [NAFNet](https://github.com/megvii-research/NAFNet) | [Models](https://github.com/megvii-research/NAFNet#results-and-pre-trained-models) - [Restormer](https://github.com/swz30/Restormer) | [Models](https://github.com/swz30/Restormer/releases/tag/v1.0) +- [FFTformer](https://github.com/kkkls/FFTformer) | [Models](https://github.com/kkkls/FFTformer/releases/tag/pretrain_model) #### DeJPEG diff --git a/src/spandrel/__helpers/main_registry.py b/src/spandrel/__helpers/main_registry.py index a0afcc0f..a1880f2b 100644 --- a/src/spandrel/__helpers/main_registry.py +++ b/src/spandrel/__helpers/main_registry.py @@ -18,6 +18,7 @@ Compact, DDColor, FeMaSR, + FFTformer, KBNet, LaMa, MMRealSR, @@ -71,6 +72,7 @@ ArchSupport.from_architecture(RealCUGAN.RealCUGANArch()), ArchSupport.from_architecture(DDColor.DDColorArch()), ArchSupport.from_architecture(SAFMN.SAFMNArch()), + ArchSupport.from_architecture(FFTformer.FFTformerArch()), ArchSupport.from_architecture(NAFNet.NAFNetArch()), ArchSupport.from_architecture(Restormer.RestormerArch()), ArchSupport.from_architecture(ESRGAN.ESRGANArch()), diff --git a/src/spandrel/architectures/FFTformer/__init__.py b/src/spandrel/architectures/FFTformer/__init__.py new file mode 100644 index 00000000..3824aade --- /dev/null +++ b/src/spandrel/architectures/FFTformer/__init__.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .arch.fftformer_arch import FFTformer + + +class FFTformerArch(Architecture[FFTformer]): + def __init__(self) -> None: + super().__init__( + id="FFTformer", + detect=KeyCondition.has_all( + "patch_embed.proj.weight", + "encoder_level1.0.norm2.body.weight", + "encoder_level1.0.norm2.body.bias", + "encoder_level1.0.ffn.fft", + "encoder_level1.0.ffn.project_in.weight", + "encoder_level1.0.ffn.dwconv.weight", + "encoder_level1.0.ffn.project_out.weight", + "down1_2.body.1.weight", + "encoder_level2.0.ffn.fft", + "down2_3.body.1.weight", + "encoder_level3.0.ffn.fft", + "decoder_level3.0.attn.to_hidden.weight", + "decoder_level3.0.attn.norm.body.weight", + "up3_2.body.1.weight", + "reduce_chan_level2.weight", + "decoder_level2.0.attn.to_hidden.weight", + "up2_1.body.1.weight", + "decoder_level1.0.attn.to_hidden.weight", + "refinement.0.norm1.body.weight", + "refinement.0.attn.to_hidden.weight", + "refinement.0.ffn.fft", + "fuse2.att_channel.norm2.body.weight", + "fuse2.att_channel.ffn.fft", + "fuse2.conv.weight", + "fuse1.att_channel.norm2.body.weight", + "fuse1.att_channel.ffn.fft", + "fuse1.conv.weight", + "output.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[FFTformer]: + inp_channels = 3 + out_channels = 3 + dim = 48 + num_blocks = [6, 6, 12] + num_refinement_blocks = 4 + ffn_expansion_factor = 3 + bias = False + + inp_channels = state_dict["patch_embed.proj.weight"].shape[1] + out_channels = 
state_dict["output.weight"].shape[0] + dim = state_dict["patch_embed.proj.weight"].shape[0] + + num_blocks[0] = get_seq_len(state_dict, "encoder_level1") + num_blocks[1] = get_seq_len(state_dict, "encoder_level2") + num_blocks[2] = get_seq_len(state_dict, "encoder_level3") + + num_refinement_blocks = get_seq_len(state_dict, "refinement") + + # hidden_dim = int(dim * ffn_expansion_factor) + hidden_dim = state_dict["encoder_level1.0.ffn.project_out.weight"].shape[1] + ffn_expansion_factor = hidden_dim / dim + + bias = "encoder_level1.0.ffn.project_in.bias" in state_dict + + model = FFTformer( + inp_channels=inp_channels, + out_channels=out_channels, + dim=dim, + num_blocks=num_blocks, + num_refinement_blocks=num_refinement_blocks, + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[f"{dim}dim"], + supports_half=False, # TODO: verify + supports_bfloat16=True, + scale=1, + input_channels=inp_channels, + output_channels=out_channels, + size_requirements=SizeRequirements(multiple_of=32), + ) diff --git a/src/spandrel/architectures/FFTformer/arch/LICENSE b/src/spandrel/architectures/FFTformer/arch/LICENSE new file mode 100644 index 00000000..28eb0525 --- /dev/null +++ b/src/spandrel/architectures/FFTformer/arch/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 kkkls + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/src/spandrel/architectures/FFTformer/arch/fftformer_arch.py b/src/spandrel/architectures/FFTformer/arch/fftformer_arch.py new file mode 100644 index 00000000..67fd6287 --- /dev/null +++ b/src/spandrel/architectures/FFTformer/arch/fftformer_arch.py @@ -0,0 +1,410 @@ +import numbers + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from spandrel.util import store_hyperparameters + + +def to_3d(x): + return rearrange(x, "b c h w -> b (h w) c") + + +def to_4d(x, h, w): + return rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) # type: ignore + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma + 1e-5) * self.weight + + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) # type: ignore + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super().__init__() + if LayerNorm_type == "BiasFree": + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + +class DFFN(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super().__init__() + + hidden_features = int(dim * ffn_expansion_factor) + + self.patch_size = 8 + + self.dim = dim + self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_features * 2, + bias=bias, + ) + + self.fft = nn.Parameter( + torch.ones( + (hidden_features * 2, 1, 1, self.patch_size, self.patch_size // 2 + 1) + ) + ) + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x_patch = rearrange( + x, + "b c (h patch1) (w patch2) -> b c h w patch1 patch2", + patch1=self.patch_size, + patch2=self.patch_size, + ) + x_patch_fft = torch.fft.rfft2(x_patch.float()) + x_patch_fft = x_patch_fft * self.fft + x_patch = torch.fft.irfft2(x_patch_fft, s=(self.patch_size, self.patch_size)) + x = rearrange( + x_patch, + "b c h w patch1 patch2 -> b c (h patch1) (w patch2)", + patch1=self.patch_size, + patch2=self.patch_size, + ) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +class FSAS(nn.Module): + def __init__(self, dim, bias): + super().__init__() + + self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias) + self.to_hidden_dw = nn.Conv2d( + dim * 6, + dim * 6, + kernel_size=3, + 
stride=1, + padding=1, + groups=dim * 6, + bias=bias, + ) + + self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias) + + self.norm = LayerNorm(dim * 2, LayerNorm_type="WithBias") + + self.patch_size = 8 + + def forward(self, x): + hidden = self.to_hidden(x) + + q, k, v = self.to_hidden_dw(hidden).chunk(3, dim=1) + + q_patch = rearrange( + q, + "b c (h patch1) (w patch2) -> b c h w patch1 patch2", + patch1=self.patch_size, + patch2=self.patch_size, + ) + k_patch = rearrange( + k, + "b c (h patch1) (w patch2) -> b c h w patch1 patch2", + patch1=self.patch_size, + patch2=self.patch_size, + ) + q_fft = torch.fft.rfft2(q_patch.float()) + k_fft = torch.fft.rfft2(k_patch.float()) + + out = q_fft * k_fft + out = torch.fft.irfft2(out, s=(self.patch_size, self.patch_size)) + out = rearrange( + out, + "b c h w patch1 patch2 -> b c (h patch1) (w patch2)", + patch1=self.patch_size, + patch2=self.patch_size, + ) + + out = self.norm(out) + + output = v * out + output = self.project_out(output) + + return output + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + ffn_expansion_factor=2.66, + bias=False, + LayerNorm_type="WithBias", + att=False, + ): + super().__init__() + + self.att = att + if self.att: + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = FSAS(dim, bias) + + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = DFFN(dim, ffn_expansion_factor, bias) + + def forward(self, x): + if self.att: + x = x + self.attn(self.norm1(x)) + + x = x + self.ffn(self.norm2(x)) + + return x + + +class Fuse(nn.Module): + def __init__(self, n_feat): + super().__init__() + self.n_feat = n_feat + self.att_channel = TransformerBlock(dim=n_feat * 2) + + self.conv = nn.Conv2d(n_feat * 2, n_feat * 2, 1, 1, 0) + self.conv2 = nn.Conv2d(n_feat * 2, n_feat * 2, 1, 1, 0) + + def forward(self, enc, dnc): + x = self.conv(torch.cat((enc, dnc), dim=1)) + x = self.att_channel(x) + x = self.conv2(x) + e, d = torch.split(x, [self.n_feat, self.n_feat], dim=1) + output = e + d + + return output + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super().__init__() + + self.proj = nn.Conv2d( + in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias + ) + + def forward(self, x): + x = self.proj(x) + + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super().__init__() + + self.body = nn.Sequential( + nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=False), + nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False), + ) + + def forward(self, x): + return self.body(x) + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super().__init__() + + self.body = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False), + ) + + def forward(self, x): + return self.body(x) + + +########################################################################## +##---------- FFTformer ----------------------- +@store_hyperparameters() +class FFTformer(nn.Module): + hyperparameters = {} + + def __init__( + self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[6, 6, 12], + num_refinement_blocks=4, + 
ffn_expansion_factor=3, + bias=False, + ): + super().__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential( + *[ + TransformerBlock( + dim=dim, ffn_expansion_factor=ffn_expansion_factor, bias=bias + ) + for _ in range(num_blocks[0]) + ] + ) + + self.down1_2 = Downsample(dim) + self.encoder_level2 = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim * 2**1), + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + ) + for _ in range(num_blocks[1]) + ] + ) + + self.down2_3 = Downsample(int(dim * 2**1)) + self.encoder_level3 = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim * 2**2), + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + ) + for _ in range(num_blocks[2]) + ] + ) + + self.decoder_level3 = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim * 2**2), + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + att=True, + ) + for _ in range(num_blocks[2]) + ] + ) + + self.up3_2 = Upsample(int(dim * 2**2)) + self.reduce_chan_level2 = nn.Conv2d( + int(dim * 2**2), int(dim * 2**1), kernel_size=1, bias=bias + ) + self.decoder_level2 = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim * 2**1), + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + att=True, + ) + for _ in range(num_blocks[1]) + ] + ) + + self.up2_1 = Upsample(int(dim * 2**1)) + + self.decoder_level1 = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim), + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + att=True, + ) + for _ in range(num_blocks[0]) + ] + ) + + self.refinement = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim), + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + att=True, + ) + for _ in range(num_refinement_blocks) + ] + ) + + self.fuse2 = Fuse(dim * 2) + self.fuse1 = Fuse(dim) + self.output = nn.Conv2d( + int(dim), out_channels, kernel_size=3, stride=1, padding=1, bias=bias + ) + + def forward(self, inp_img): + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + out_dec_level3 = self.decoder_level3(out_enc_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + + inp_dec_level2 = self.fuse2(inp_dec_level2, out_enc_level2) + + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + + inp_dec_level1 = self.fuse1(inp_dec_level1, out_enc_level1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + out_dec_level1 = self.output(out_dec_level1) + inp_img + + return out_dec_level1 diff --git a/tests/__snapshots__/test_FFTformer.ambr b/tests/__snapshots__/test_FFTformer.ambr new file mode 100644 index 00000000..539cbf47 --- /dev/null +++ b/tests/__snapshots__/test_FFTformer.ambr @@ -0,0 +1,20 @@ +# serializer version: 1 +# name: test_fftformer_GoPro + ImageModelDescriptor( + architecture=FFTformerArch( + id='FFTformer', + name='FFTformer', + ), + input_channels=3, + output_channels=3, + purpose='Restoration', + scale=1, + size_requirements=SizeRequirements(minimum=0, multiple_of=32, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + '48dim', + ]), + tiling=, + ) +# --- diff --git a/tests/test_FFTformer.py b/tests/test_FFTformer.py new file mode 100644 index 00000000..f150cdf9 --- /dev/null +++ 
b/tests/test_FFTformer.py @@ -0,0 +1,37 @@ +from spandrel.architectures.FFTformer import FFTformer, FFTformerArch + +from .util import ( + ModelFile, + assert_loads_correctly, + assert_size_requirements, + disallowed_props, +) + + +def test_load(): + assert_loads_correctly( + FFTformerArch(), + lambda: FFTformer(), + lambda: FFTformer(dim=64, inp_channels=4, out_channels=1), + lambda: FFTformer(num_blocks=[3, 5, 7], ffn_expansion_factor=2), + lambda: FFTformer(num_refinement_blocks=2), + lambda: FFTformer(bias=True), + ) + + +def test_size_requirements(): + file = ModelFile.from_url( + "https://github.com/kkkls/FFTformer/releases/download/pretrain_model/fftformer_GoPro.pth", + name="fftformer_GoPro.pth", + ) + assert_size_requirements(file.load_model()) + + +def test_fftformer_GoPro(snapshot): + file = ModelFile.from_url( + "https://github.com/kkkls/FFTformer/releases/download/pretrain_model/fftformer_GoPro.pth", + name="fftformer_GoPro.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, FFTformer)
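# --- Usage sketch (illustrative, not part of the patch) ------------------------
# How a checkpoint covered by this PR could be loaded through spandrel's public
# API once the registration above is in place. The local file name is an
# assumption; ModelLoader, ImageModelDescriptor, and descriptor __call__ are
# existing spandrel APIs.
import torch
from spandrel import ImageModelDescriptor, ModelLoader

model = ModelLoader().load_from_file("fftformer_GoPro.pth")  # hypothetical local path
assert isinstance(model, ImageModelDescriptor)
assert model.architecture.id == "FFTformer"

model.eval()
with torch.no_grad():
    # 1x3xHxW input in [0, 1]; H and W must be a multiple of 32 per the
    # SizeRequirements declared in load() above (or use tiling/padding).
    image = torch.rand(1, 3, 256, 256)
    restored = model(image)

print(restored.shape)  # torch.Size([1, 3, 256, 256]) -- scale is 1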