Skip to content

Commit

Permalink
Add support for NAFNet (#141)
Browse files Browse the repository at this point in the history
Co-authored-by: Joey Ballentine <34788790+joeyballentine@users.noreply.github.com>
  • Loading branch information
RunDevelopment and joeyballentine authored Jan 22, 2024
1 parent 6f17937 commit 6893851
Show file tree
Hide file tree
Showing 10 changed files with 431 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar
- [SCUNet](https://github.com/cszn/SCUNet) | [GAN Model](https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth) | [PSNR Model](https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth)
- [Uformer](https://github.com/ZhendongWang6/Uformer) | [Denoise SIDD Model](https://mailustceducn-my.sharepoint.com/:u:/g/personal/zhendongwang_mail_ustc_edu_cn/Ea7hMP82A0xFlOKPlQnBJy0B9gVP-1MJL75mR4QKBMGc2w?e=iOz0zz) | [Deblur GoPro Model](https://mailustceducn-my.sharepoint.com/:u:/g/personal/zhendongwang_mail_ustc_edu_cn/EfCPoTSEKJRAshoE6EAC_3YB7oNkbLUX6AUgWSCwoJe0oA?e=jai90x)
- [KBNet](https://github.com/zhangyi-3/KBNet) | [Models](https://mycuhk-my.sharepoint.com/personal/1155135732_link_cuhk_edu_hk/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F1155135732%5Flink%5Fcuhk%5Fedu%5Fhk%2FDocuments%2Fshare%2FKBNet%2FDenoising%2Fpretrained%5Fmodels)
- [NAFNet](https://github.com/megvii-research/NAFNet) | [Models](https://github.com/megvii-research/NAFNet#results-and-pre-trained-models)

#### DeJPEG

Expand Down
25 changes: 25 additions & 0 deletions src/spandrel/__helpers/main_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
KBNet,
LaMa,
MMRealSR,
NAFNet,
OmniSR,
RealCUGAN,
RestoreFormer,
Expand Down Expand Up @@ -362,6 +363,30 @@ def _detect(state_dict: StateDict) -> bool:
),
load=SAFMN.load,
),
ArchSupport(
id="NAFNet",
detect=_has_keys(
"intro.weight",
"ending.weight",
"ups.0.0.weight",
"downs.0.weight",
"middle_blks.0.beta",
"middle_blks.0.gamma",
"middle_blks.0.conv1.weight",
"middle_blks.0.conv2.weight",
"middle_blks.0.conv3.weight",
"middle_blks.0.sca.1.weight",
"middle_blks.0.conv4.weight",
"middle_blks.0.conv5.weight",
"middle_blks.0.norm1.weight",
"middle_blks.0.norm2.weight",
"encoders.0.0.beta",
"encoders.0.0.gamma",
"decoders.0.0.beta",
"decoders.0.0.gamma",
),
load=NAFNet.load,
),
ArchSupport(
id="ESRGAN",
detect=lambda state: (
Expand Down
46 changes: 46 additions & 0 deletions src/spandrel/architectures/NAFNet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from ...__helpers.model_descriptor import (
ImageModelDescriptor,
StateDict,
)
from ..__arch_helpers.state import get_seq_len
from .arch.NAFNet_arch import NAFNet


def load(state_dict: StateDict) -> ImageModelDescriptor[NAFNet]:
# default values
img_channel: int = 3
width: int = 16
middle_blk_num: int = 1
enc_blk_nums: list[int] = []
dec_blk_nums: list[int] = []

img_channel = state_dict["intro.weight"].shape[1]
width = state_dict["intro.weight"].shape[0]
middle_blk_num = get_seq_len(state_dict, "middle_blks")
for i in range(get_seq_len(state_dict, "encoders")):
enc_blk_nums.append(get_seq_len(state_dict, f"encoders.{i}"))
for i in range(get_seq_len(state_dict, "decoders")):
dec_blk_nums.append(get_seq_len(state_dict, f"decoders.{i}"))

model = NAFNet(
img_channel=img_channel,
width=width,
middle_blk_num=middle_blk_num,
enc_blk_nums=enc_blk_nums,
dec_blk_nums=dec_blk_nums,
)

return ImageModelDescriptor(
model,
state_dict,
architecture="NAFNet",
purpose="Restoration",
tags=[f"{width}w"],
supports_half=False, # TODO: Test this
supports_bfloat16=True,
scale=1,
input_channels=img_channel,
output_channels=img_channel,
)
21 changes: 21 additions & 0 deletions src/spandrel/architectures/NAFNet/arch/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 megvii-model

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
223 changes: 223 additions & 0 deletions src/spandrel/architectures/NAFNet/arch/NAFNet_arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------

"""
Simple Baselines for Image Restoration
@article{chen2022simple,
title={Simple Baselines for Image Restoration},
author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
journal={arXiv preprint arXiv:2204.04676},
year={2022}
}
"""
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from .arch_util import LayerNorm2d


class SimpleGate(nn.Module):
def forward(self, x):
x1, x2 = x.chunk(2, dim=1)
return x1 * x2


class NAFBlock(nn.Module):
def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0):
super().__init__()
dw_channel = c * DW_Expand
self.conv1 = nn.Conv2d(
in_channels=c,
out_channels=dw_channel,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias=True,
)
self.conv2 = nn.Conv2d(
in_channels=dw_channel,
out_channels=dw_channel,
kernel_size=3,
padding=1,
stride=1,
groups=dw_channel,
bias=True,
)
self.conv3 = nn.Conv2d(
in_channels=dw_channel // 2,
out_channels=c,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias=True,
)

# Simplified Channel Attention
self.sca = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(
in_channels=dw_channel // 2,
out_channels=dw_channel // 2,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias=True,
),
)

# SimpleGate
self.sg = SimpleGate()

ffn_channel = FFN_Expand * c
self.conv4 = nn.Conv2d(
in_channels=c,
out_channels=ffn_channel,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias=True,
)
self.conv5 = nn.Conv2d(
in_channels=ffn_channel // 2,
out_channels=c,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias=True,
)

self.norm1 = LayerNorm2d(c)
self.norm2 = LayerNorm2d(c)

self.dropout1 = (
nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()
)
self.dropout2 = (
nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()
)

self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

def forward(self, inp):
x = inp

x = self.norm1(x)

x = self.conv1(x)
x = self.conv2(x)
x = self.sg(x)
x = x * self.sca(x)
x = self.conv3(x)

x = self.dropout1(x)

y = inp + x * self.beta

x = self.conv4(self.norm2(y))
x = self.sg(x)
x = self.conv5(x)

x = self.dropout2(x)

return y + x * self.gamma


class NAFNet(nn.Module):
def __init__(
self,
img_channel: int = 3,
width: int = 16,
middle_blk_num: int = 1,
enc_blk_nums: list[int] = [],
dec_blk_nums: list[int] = [],
):
super().__init__()

self.intro = nn.Conv2d(
in_channels=img_channel,
out_channels=width,
kernel_size=3,
padding=1,
stride=1,
groups=1,
bias=True,
)
self.ending = nn.Conv2d(
in_channels=width,
out_channels=img_channel,
kernel_size=3,
padding=1,
stride=1,
groups=1,
bias=True,
)

self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
self.middle_blks = nn.ModuleList()
self.ups = nn.ModuleList()
self.downs = nn.ModuleList()

chan = width
for num in enc_blk_nums:
self.encoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
chan = chan * 2

self.middle_blks = nn.Sequential(
*[NAFBlock(chan) for _ in range(middle_blk_num)]
)

for num in dec_blk_nums:
self.ups.append(
nn.Sequential(
nn.Conv2d(chan, chan * 2, 1, bias=False), nn.PixelShuffle(2)
)
)
chan = chan // 2
self.decoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))

self.padder_size = 2 ** len(self.encoders)

def forward(self, inp):
_, _, H, W = inp.shape
inp = self.check_image_size(inp)

x = self.intro(inp)

encs = []

for encoder, down in zip(self.encoders, self.downs):
x = encoder(x)
encs.append(x)
x = down(x)

x = self.middle_blks(x)

for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
x = up(x)
x = x + enc_skip
x = decoder(x)

x = self.ending(x)
x = x + inp

return x[:, :, :H, :W]

def check_image_size(self, x):
_, _, h, w = x.size()
mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
return x
44 changes: 44 additions & 0 deletions src/spandrel/architectures/NAFNet/arch/arch_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn


class LayerNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias, eps): # type: ignore
ctx.eps = eps
_N, C, _H, _W = x.size()
mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)
y = (x - mu) / (var + eps).sqrt()
ctx.save_for_backward(y, var, weight)
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
return y

@staticmethod
def backward(ctx, grad_output): # type: ignore
eps = ctx.eps

_N, C, _H, _W = grad_output.size()
y, var, weight = ctx.saved_variables
g = grad_output * weight.view(1, C, 1, 1)
mean_g = g.mean(dim=1, keepdim=True)

mean_gy = (g * y).mean(dim=1, keepdim=True)
gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
return (
gx,
(grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
None,
)


class LayerNorm2d(nn.Module):
def __init__(self, channels, eps=1e-6):
super().__init__()
self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
self.eps = eps

def forward(self, x):
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
Loading

0 comments on commit 6893851

Please sign in to comment.