diff --git a/README.md b/README.md
index 322f16ad..c7598d8c 100644
--- a/README.md
+++ b/README.md
@@ -99,6 +99,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar
 - [SCUNet](https://github.com/cszn/SCUNet) | [GAN Model](https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth) | [PSNR Model](https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth)
 - [Uformer](https://github.com/ZhendongWang6/Uformer) | [Denoise SIDD Model](https://mailustceducn-my.sharepoint.com/:u:/g/personal/zhendongwang_mail_ustc_edu_cn/Ea7hMP82A0xFlOKPlQnBJy0B9gVP-1MJL75mR4QKBMGc2w?e=iOz0zz) | [Deblur GoPro Model](https://mailustceducn-my.sharepoint.com/:u:/g/personal/zhendongwang_mail_ustc_edu_cn/EfCPoTSEKJRAshoE6EAC_3YB7oNkbLUX6AUgWSCwoJe0oA?e=jai90x)
 - [KBNet](https://github.com/zhangyi-3/KBNet) | [Models](https://mycuhk-my.sharepoint.com/personal/1155135732_link_cuhk_edu_hk/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F1155135732%5Flink%5Fcuhk%5Fedu%5Fhk%2FDocuments%2Fshare%2FKBNet%2FDenoising%2Fpretrained%5Fmodels)
+- [NAFNet](https://github.com/megvii-research/NAFNet) | [Models](https://github.com/megvii-research/NAFNet#results-and-pre-trained-models)
 
 #### DeJPEG
 
diff --git a/src/spandrel/__helpers/main_registry.py b/src/spandrel/__helpers/main_registry.py
index 9f6ddf0c..70012207 100644
--- a/src/spandrel/__helpers/main_registry.py
+++ b/src/spandrel/__helpers/main_registry.py
@@ -22,6 +22,7 @@
     KBNet,
     LaMa,
     MMRealSR,
+    NAFNet,
     OmniSR,
     RealCUGAN,
     RestoreFormer,
@@ -362,6 +363,30 @@ def _detect(state_dict: StateDict) -> bool:
         ),
         load=SAFMN.load,
     ),
+    ArchSupport(
+        id="NAFNet",
+        detect=_has_keys(
+            "intro.weight",
+            "ending.weight",
+            "ups.0.0.weight",
+            "downs.0.weight",
+            "middle_blks.0.beta",
+            "middle_blks.0.gamma",
+            "middle_blks.0.conv1.weight",
+            "middle_blks.0.conv2.weight",
+            "middle_blks.0.conv3.weight",
+            "middle_blks.0.sca.1.weight",
+            "middle_blks.0.conv4.weight",
+            "middle_blks.0.conv5.weight",
+            "middle_blks.0.norm1.weight",
+            "middle_blks.0.norm2.weight",
+            "encoders.0.0.beta",
+            "encoders.0.0.gamma",
+            "decoders.0.0.beta",
+            "decoders.0.0.gamma",
+        ),
+        load=NAFNet.load,
+    ),
     ArchSupport(
         id="ESRGAN",
         detect=lambda state: (
diff --git a/src/spandrel/architectures/NAFNet/__init__.py b/src/spandrel/architectures/NAFNet/__init__.py
new file mode 100644
index 00000000..c3ad5678
--- /dev/null
+++ b/src/spandrel/architectures/NAFNet/__init__.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+from ...__helpers.model_descriptor import (
+    ImageModelDescriptor,
+    StateDict,
+)
+from ..__arch_helpers.state import get_seq_len
+from .arch.NAFNet_arch import NAFNet
+
+
+def load(state_dict: StateDict) -> ImageModelDescriptor[NAFNet]:
+    """Rebuild a NAFNet from ``state_dict`` and wrap it in a descriptor.
+
+    All constructor arguments are recovered from the checkpoint itself: the
+    shape of ``intro.weight`` gives ``width`` and ``img_channel``, and
+    ``get_seq_len`` over the ``middle_blks``, ``encoders`` and ``decoders``
+    prefixes gives the block counts.
+    """
+    # Conv2d weight layout is (out_channels, in_channels, kH, kW)
+    intro_weight = state_dict["intro.weight"]
+    img_channel: int = intro_weight.shape[1]
+    width: int = intro_weight.shape[0]
+    middle_blk_num = get_seq_len(state_dict, "middle_blks")
+    enc_blk_nums = [
+        get_seq_len(state_dict, f"encoders.{i}")
+        for i in range(get_seq_len(state_dict, "encoders"))
+    ]
+    dec_blk_nums = [
+        get_seq_len(state_dict, f"decoders.{i}")
+        for i in range(get_seq_len(state_dict, "decoders"))
+    ]
+
+    model = NAFNet(
+        img_channel=img_channel,
+        width=width,
+        middle_blk_num=middle_blk_num,
+        enc_blk_nums=enc_blk_nums,
+        dec_blk_nums=dec_blk_nums,
+    )
+
+    return ImageModelDescriptor(
+        model,
+        state_dict,
+        architecture="NAFNet",
+        purpose="Restoration",
+        tags=[f"{width}w"],
+        supports_half=False,  # TODO: Test this
+        supports_bfloat16=True,
+        scale=1,
+        input_channels=img_channel,
+        output_channels=img_channel,
+    )
diff --git a/src/spandrel/architectures/NAFNet/arch/LICENSE b/src/spandrel/architectures/NAFNet/arch/LICENSE
new file mode 100644
index 00000000..94111063
--- /dev/null
+++ b/src/spandrel/architectures/NAFNet/arch/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 megvii-model
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/src/spandrel/architectures/NAFNet/arch/NAFNet_arch.py b/src/spandrel/architectures/NAFNet/arch/NAFNet_arch.py
new file mode 100644
index 00000000..4c3d7062
--- /dev/null
+++ b/src/spandrel/architectures/NAFNet/arch/NAFNet_arch.py
@@ -0,0 +1,254 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2022 megvii-model. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+"""
+Simple Baselines for Image Restoration
+
+@article{chen2022simple,
+  title={Simple Baselines for Image Restoration},
+  author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
+  journal={arXiv preprint arXiv:2204.04676},
+  year={2022}
+}
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .arch_util import LayerNorm2d
+
+
+class SimpleGate(nn.Module):
+    """Split the channels into two halves and multiply them element-wise."""
+
+    def forward(self, x):
+        x1, x2 = x.chunk(2, dim=1)
+        return x1 * x2
+
+
+class NAFBlock(nn.Module):
+    """Residual block made of a gated conv branch and a gated 1x1 branch.
+
+    Each branch output is scaled by a learnable per-channel factor
+    (``beta`` / ``gamma``, both initialised to zero) before it is added
+    back to the branch input.
+    """
+
+    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0):
+        super().__init__()
+        dw_channel = c * DW_Expand
+        self.conv1 = nn.Conv2d(
+            in_channels=c,
+            out_channels=dw_channel,
+            kernel_size=1,
+            padding=0,
+            stride=1,
+            groups=1,
+            bias=True,
+        )
+        self.conv2 = nn.Conv2d(
+            in_channels=dw_channel,
+            out_channels=dw_channel,
+            kernel_size=3,
+            padding=1,
+            stride=1,
+            groups=dw_channel,
+            bias=True,
+        )
+        self.conv3 = nn.Conv2d(
+            in_channels=dw_channel // 2,
+            out_channels=c,
+            kernel_size=1,
+            padding=0,
+            stride=1,
+            groups=1,
+            bias=True,
+        )
+
+        # Simplified Channel Attention
+        self.sca = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(
+                in_channels=dw_channel // 2,
+                out_channels=dw_channel // 2,
+                kernel_size=1,
+                padding=0,
+                stride=1,
+                groups=1,
+                bias=True,
+            ),
+        )
+
+        # SimpleGate
+        self.sg = SimpleGate()
+
+        ffn_channel = FFN_Expand * c
+        self.conv4 = nn.Conv2d(
+            in_channels=c,
+            out_channels=ffn_channel,
+            kernel_size=1,
+            padding=0,
+            stride=1,
+            groups=1,
+            bias=True,
+        )
+        self.conv5 = nn.Conv2d(
+            in_channels=ffn_channel // 2,
+            out_channels=c,
+            kernel_size=1,
+            padding=0,
+            stride=1,
+            groups=1,
+            bias=True,
+        )
+
+        self.norm1 = LayerNorm2d(c)
+        self.norm2 = LayerNorm2d(c)
+
+        self.dropout1 = (
+            nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()
+        )
+        self.dropout2 = (
+            nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()
+        )
+
+        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+
+    def forward(self, inp):
+        x = inp
+
+        x = self.norm1(x)
+
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.sg(x)
+        x = x * self.sca(x)
+        x = self.conv3(x)
+
+        x = self.dropout1(x)
+
+        y = inp + x * self.beta
+
+        x = self.conv4(self.norm2(y))
+        x = self.sg(x)
+        x = self.conv5(x)
+
+        x = self.dropout2(x)
+
+        return y + x * self.gamma
+
+
+class NAFNet(nn.Module):
+    """U-shaped restoration network assembled from ``NAFBlock`` stages.
+
+    Args:
+        img_channel: Channels of the input image; the output has the same
+            number of channels.
+        width: Channels produced by the intro convolution. The count doubles
+            after each encoder stage and halves after each decoder stage.
+        middle_blk_num: Number of blocks between the encoder and decoder
+            stages.
+        enc_blk_nums: Number of blocks in each encoder stage; ``None`` means
+            no encoder stages.
+        dec_blk_nums: Number of blocks in each decoder stage; ``None`` means
+            no decoder stages.
+    """
+
+    def __init__(
+        self,
+        img_channel: int = 3,
+        width: int = 16,
+        middle_blk_num: int = 1,
+        enc_blk_nums: list[int] | None = None,
+        dec_blk_nums: list[int] | None = None,
+    ):
+        super().__init__()
+
+        # None is the default so that no list object is shared between calls.
+        enc_blk_nums = [] if enc_blk_nums is None else enc_blk_nums
+        dec_blk_nums = [] if dec_blk_nums is None else dec_blk_nums
+
+        self.intro = nn.Conv2d(
+            in_channels=img_channel,
+            out_channels=width,
+            kernel_size=3,
+            padding=1,
+            stride=1,
+            groups=1,
+            bias=True,
+        )
+        self.ending = nn.Conv2d(
+            in_channels=width,
+            out_channels=img_channel,
+            kernel_size=3,
+            padding=1,
+            stride=1,
+            groups=1,
+            bias=True,
+        )
+
+        self.encoders = nn.ModuleList()
+        self.decoders = nn.ModuleList()
+        self.middle_blks = nn.ModuleList()  # placeholder, replaced below
+        self.ups = nn.ModuleList()
+        self.downs = nn.ModuleList()
+
+        chan = width
+        for num in enc_blk_nums:
+            self.encoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
+            self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
+            chan = chan * 2
+
+        self.middle_blks = nn.Sequential(
+            *[NAFBlock(chan) for _ in range(middle_blk_num)]
+        )
+
+        for num in dec_blk_nums:
+            self.ups.append(
+                nn.Sequential(
+                    nn.Conv2d(chan, chan * 2, 1, bias=False), nn.PixelShuffle(2)
+                )
+            )
+            chan = chan // 2
+            self.decoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
+
+        # every encoder stage ends in a stride-2 conv, so H and W must divide by this
+        self.padder_size = 2 ** len(self.encoders)
+
+    def forward(self, inp):
+        """Run the network on ``inp`` (N, C, H, W); the result keeps H and W."""
+        _, _, H, W = inp.shape
+        inp = self.check_image_size(inp)
+
+        x = self.intro(inp)
+
+        encs = []
+
+        for encoder, down in zip(self.encoders, self.downs):
+            x = encoder(x)
+            encs.append(x)
+            x = down(x)
+
+        x = self.middle_blks(x)
+
+        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
+            x = up(x)
+            x = x + enc_skip
+            x = decoder(x)
+
+        x = self.ending(x)
+        x = x + inp
+
+        return x[:, :, :H, :W]
+
+    def check_image_size(self, x):
+        """Zero-pad the bottom/right so H and W are multiples of padder_size."""
+        _, _, h, w = x.size()
+        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
+        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
+        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
+        return x
diff --git a/src/spandrel/architectures/NAFNet/arch/arch_util.py b/src/spandrel/architectures/NAFNet/arch/arch_util.py
new file mode 100644
index 00000000..28c08387
--- /dev/null
+++ b/src/spandrel/architectures/NAFNet/arch/arch_util.py
@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+
+
+class LayerNormFunction(torch.autograd.Function):
+    """Normalisation over dim 1 with a hand-written backward pass."""
+
+    @staticmethod
+    def forward(ctx, x, weight, bias, eps):  # type: ignore
+        ctx.eps = eps
+        _N, C, _H, _W = x.size()
+        mu = x.mean(1, keepdim=True)
+        var = (x - mu).pow(2).mean(1, keepdim=True)
+        y = (x - mu) / (var + eps).sqrt()
+        ctx.save_for_backward(y, var, weight)
+        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output):  # type: ignore
+        eps = ctx.eps
+
+        _N, C, _H, _W = grad_output.size()
+        y, var, weight = ctx.saved_tensors
+        g = grad_output * weight.view(1, C, 1, 1)
+        mean_g = g.mean(dim=1, keepdim=True)
+
+        mean_gy = (g * y).mean(dim=1, keepdim=True)
+        gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
+        return (
+            gx,
+            (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
+            grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
+            None,
+        )
+
+
+class LayerNorm2d(nn.Module):
+    """Per-channel affine layer norm for (N, C, H, W) tensors."""
+
+    def __init__(self, channels, eps=1e-6):
+        super().__init__()
+        self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
+        self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
+        self.eps = eps
+
+    def forward(self, x):
+        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
diff --git a/tests/__snapshots__/test_NAFNet.ambr b/tests/__snapshots__/test_NAFNet.ambr
new file mode 100644
index 00000000..fd8b717a
--- /dev/null
+++ b/tests/__snapshots__/test_NAFNet.ambr
@@ -0,0 +1,17 @@
+# serializer version: 1
+# name: test_NAFNet_GoPro_width32
+  ImageModelDescriptor(
+    architecture='NAFNet',
+    input_channels=3,
+    output_channels=3,
+    purpose='Restoration',
+    scale=1,
+    size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
+    supports_bfloat16=True,
+    supports_half=False,
+    tags=list([
+      '32w',
+    ]),
+    tiling=,
+  )
+# ---
diff --git a/tests/images/outputs/32x32/NAFNet-GoPro-width32.png b/tests/images/outputs/32x32/NAFNet-GoPro-width32.png
new file mode 100644
index 00000000..7365ff92
Binary files /dev/null and b/tests/images/outputs/32x32/NAFNet-GoPro-width32.png differ
diff --git a/tests/images/outputs/blurry-face/NAFNet-GoPro-width32.png b/tests/images/outputs/blurry-face/NAFNet-GoPro-width32.png
new file mode 100644
index 00000000..c7f5df3e
Binary files /dev/null and b/tests/images/outputs/blurry-face/NAFNet-GoPro-width32.png differ
diff --git a/tests/test_NAFNet.py b/tests/test_NAFNet.py
new file mode 100644
index 00000000..c50cf680
--- /dev/null
+++ b/tests/test_NAFNet.py
@@ -0,0 +1,54 @@
+from spandrel.architectures.NAFNet import NAFNet, load
+
+from .util import (
+    ModelFile,
+    TestImage,
+    assert_image_inference,
+    assert_loads_correctly,
+    assert_size_requirements,
+    disallowed_props,
+)
+
+
+def test_load():
+    assert_loads_correctly(
+        load,
+        lambda: NAFNet(
+            img_channel=3,
+            width=32,
+            middle_blk_num=12,
+            enc_blk_nums=[2, 2, 4, 8],
+            dec_blk_nums=[2, 2, 2, 2],
+        ),
+        lambda: NAFNet(
+            img_channel=3,
+            width=32,
+            middle_blk_num=1,
+            enc_blk_nums=[1, 1, 1, 28],
+            dec_blk_nums=[1, 1, 1, 1],
+        ),
+        condition=lambda a, b: (a.padder_size == b.padder_size),
+    )
+
+
+def test_size_requirements():
+    file = ModelFile.from_url(
+        "https://drive.google.com/file/d/1Fr2QadtDCEXg6iwWX8OzeZLbHOx2t5Bj/view",
+        name="NAFNet-GoPro-width32.pth",
+    )
+    assert_size_requirements(file.load_model())
+
+
+def test_NAFNet_GoPro_width32(snapshot):
+    file = ModelFile.from_url(
+        "https://drive.google.com/file/d/1Fr2QadtDCEXg6iwWX8OzeZLbHOx2t5Bj/view",
+        name="NAFNet-GoPro-width32.pth",
+    )
+    model = file.load_model()
+    assert model == snapshot(exclude=disallowed_props)
+    assert isinstance(model.model, NAFNet)
+    assert_image_inference(
+        file,
+        model,
+        [TestImage.BLURRY_FACE, TestImage.SR_32],
+    )