diff --git a/README.md b/README.md index d7ed9147..db9f0003 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar - [Restormer](https://github.com/swz30/Restormer) (+) | [Models](https://github.com/swz30/Restormer/releases/tag/v1.0) - [FFTformer](https://github.com/kkkls/FFTformer) | [Models](https://github.com/kkkls/FFTformer/releases/tag/pretrain_model) - [M3SNet](https://github.com/Tombs98/M3SNet) (+) | [Models](https://drive.google.com/drive/folders/1y4BEX7LagtXVO98ZItSbJJl7WWM3gnbD) +- [MPRNet](https://github.com/swz30/MPRNet) (+) | [Deblurring](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing), [Deraining](https://drive.google.com/file/d/1O3WEJbcat7eTY6doXWeorAbQ1l_WmMnM/view?usp=sharing), [Denoising](https://drive.google.com/file/d/1LODPt9kYmxwU98g96UrRA0_Eh5HYcsRw/view?usp=sharing) #### DeJPEG diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py index 2192a922..63448b3e 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py @@ -7,6 +7,7 @@ DDColor, FeMaSR, M3SNet, + MPRNet, Restormer, SRFormer, ) @@ -22,4 +23,5 @@ ArchSupport.from_architecture(FeMaSR.FeMaSRArch()), ArchSupport.from_architecture(M3SNet.M3SNetArch()), ArchSupport.from_architecture(Restormer.RestormerArch()), + ArchSupport.from_architecture(MPRNet.MPRNetArch()), ) diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__init__.py new file mode 100644 index 00000000..ff4248e6 --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__init__.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from typing_extensions import override + +from spandrel import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from spandrel.util import KeyCondition, get_seq_len + +from .arch.MPRNet import MPRNet + + +class MPRNetArch(Architecture[MPRNet]): + def __init__(self) -> None: + super().__init__( + id="MPRNet", + detect=KeyCondition.has_all( + "shallow_feat1.0.weight", + "shallow_feat1.1.CA.conv_du.0.weight", + "shallow_feat1.1.CA.conv_du.2.weight", + "shallow_feat1.1.body.0.weight", + "shallow_feat1.1.body.2.weight", + "shallow_feat2.0.weight", + "shallow_feat3.0.weight", + "stage1_encoder.encoder_level1.0.CA.conv_du.0.weight", + "stage1_encoder.encoder_level1.0.CA.conv_du.2.weight", + "stage1_encoder.encoder_level1.0.body.2.weight", + "stage1_encoder.encoder_level1.1.body.2.weight", + "stage1_encoder.encoder_level2.1.body.2.weight", + "stage1_encoder.encoder_level3.0.CA.conv_du.0.weight", + "stage1_decoder.decoder_level1.0.CA.conv_du.0.weight", + "stage1_decoder.decoder_level1.0.body.0.weight", + "stage1_decoder.decoder_level2.0.CA.conv_du.0.weight", + "stage1_decoder.decoder_level3.0.CA.conv_du.0.weight", + "stage1_decoder.skip_attn1.CA.conv_du.0.weight", + "stage1_decoder.skip_attn2.CA.conv_du.0.weight", + "stage1_decoder.up32.up.1.weight", + "stage2_encoder.encoder_level1.0.CA.conv_du.0.weight", + "stage2_decoder.decoder_level1.0.CA.conv_du.0.weight", + "sam12.conv1.weight", + "sam12.conv3.weight", + "sam23.conv3.weight", + "concat12.weight", + "concat23.weight", + "tail.weight", + "stage3_orsnet.orb1.body.0.CA.conv_du.0.weight", + 
"stage3_orsnet.orb1.body.0.CA.conv_du.2.weight", + "stage3_orsnet.orb1.body.0.body.0.weight", + "stage3_orsnet.orb1.body.0.body.2.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[MPRNet]: + # in_c: int = 3 + # out_c: int = 3 + # n_feat: int = 40 + # scale_unetfeats: int = 20 + # scale_orsnetfeats: int = 16 + # num_cab: int = 8 + # kernel_size: int = 3 + # reduction = 4 + # bias = False + + in_c = state_dict["shallow_feat1.0.weight"].shape[1] + n_feat = state_dict["shallow_feat1.0.weight"].shape[0] + kernel_size = state_dict["shallow_feat1.0.weight"].shape[2] + bias = "shallow_feat1.0.bias" in state_dict + reduction = n_feat // state_dict["shallow_feat1.1.CA.conv_du.0.weight"].shape[0] + + out_c = state_dict["tail.weight"].shape[0] + scale_orsnetfeats = state_dict["tail.weight"].shape[1] - n_feat + scale_unetfeats = ( + state_dict["stage1_encoder.encoder_level2.0.CA.conv_du.0.weight"].shape[1] + - n_feat + ) + + num_cab = get_seq_len(state_dict, "stage3_orsnet.orb1.body") - 1 + + model = MPRNet( + in_c=in_c, + out_c=out_c, + n_feat=n_feat, + scale_unetfeats=scale_unetfeats, + scale_orsnetfeats=scale_orsnetfeats, + num_cab=num_cab, + kernel_size=kernel_size, + reduction=reduction, + bias=bias, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[f"{n_feat}nf"], + supports_half=False, # TODO: verify + supports_bfloat16=True, + scale=1, + input_channels=in_c, + output_channels=out_c, + size_requirements=SizeRequirements(multiple_of=8), + call_fn=lambda model, x: model(x)[0], + ) diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/arch/LICENSE b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/arch/LICENSE new file mode 100644 index 00000000..7a331d39 --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/arch/LICENSE @@ -0,0 +1,29 @@ +## ACADEMIC PUBLIC LICENSE + +### Permissions +:heavy_check_mark: Non-Commercial use +:heavy_check_mark: Modification +:heavy_check_mark: Distribution +:heavy_check_mark: Private use + +### Limitations +:x: Commercial Use +:x: Liability +:x: Warranty + +### Conditions +:information_source: License and copyright notice +:information_source: Same License + +MPRNet is free for use in noncommercial settings: at academic institutions for teaching and research use, and at non-profit research organizations. +You can use MPRNet in your research, academic work, non-commercial work, projects and personal work. We only ask you to credit us appropriately. + +You have the right to use the software, to distribute copies, to receive source code, to change the software and distribute your modifications or the modified software. +If you distribute verbatim or modified copies of this software, they must be distributed under this license. +This license guarantees that you're safe when using MPRNet in your work, for teaching or research. +This license guarantees that MPRNet will remain available free of charge for nonprofit use. +You can modify MPRNet to your purposes, and you can also share your modifications. + +If you would like to use MPRNet in commercial settings, contact us so we can discuss options. 
Send an email to waqas.zamir@inceptioniai.org + + diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/arch/MPRNet.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/arch/MPRNet.py new file mode 100644 index 00000000..2ef3e000 --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/arch/MPRNet.py @@ -0,0 +1,545 @@ +""" +## Multi-Stage Progressive Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao +## https://arxiv.org/abs/2102.02808 +""" + +from __future__ import annotations + +import torch +import torch.nn as nn + +from spandrel.util import store_hyperparameters + + +########################################################################## +def conv(in_channels: int, out_channels: int, kernel_size: int, bias=False, stride=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size, + padding=(kernel_size // 2), + bias=bias, + stride=stride, + ) + + +########################################################################## +## Channel Attention Layer +class CALayer(nn.Module): + def __init__(self, channel: int, reduction: int = 16, bias=False): + super().__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), + nn.Sigmoid(), + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + + +########################################################################## +## Channel Attention Block (CAB) +class CAB(nn.Module): + def __init__( + self, n_feat: int, kernel_size: int, reduction: int, bias: bool, act: nn.Module + ): + super().__init__() + modules_body = [ + conv(n_feat, n_feat, kernel_size, bias=bias), + act, + conv(n_feat, n_feat, kernel_size, bias=bias), + ] + + self.CA = CALayer(n_feat, reduction, bias=bias) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res = self.CA(res) + res += x + return res + + +########################################################################## +## Supervised Attention Module +class SAM(nn.Module): + def __init__(self, n_feat: int, kernel_size: int, bias: bool): + super().__init__() + self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias) + self.conv2 = conv(n_feat, 3, kernel_size, bias=bias) + self.conv3 = conv(3, n_feat, kernel_size, bias=bias) + + def forward(self, x, x_img): + x1 = self.conv1(x) + img = self.conv2(x) + x_img + x2 = torch.sigmoid(self.conv3(img)) + x1 = x1 * x2 + x1 = x1 + x + return x1, img + + +########################################################################## +## U-Net + + +class Encoder(nn.Module): + def __init__( + self, + n_feat: int, + kernel_size: int, + reduction: int, + act: nn.Module, + bias: bool, + scale_unetfeats: int, + csff: bool, + ): + super().__init__() + + encoder_level1 = [ + CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2) + ] + encoder_level2 = [ + CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) + for _ in range(2) + ] + encoder_level3 = [ + CAB( + n_feat + (scale_unetfeats * 2), + kernel_size, + reduction, + bias=bias, + act=act, + ) + for _ in range(2) + ] + + self.encoder_level1 = 
nn.Sequential(*encoder_level1) + self.encoder_level2 = nn.Sequential(*encoder_level2) + self.encoder_level3 = nn.Sequential(*encoder_level3) + + self.down12 = DownSample(n_feat, scale_unetfeats) + self.down23 = DownSample(n_feat + scale_unetfeats, scale_unetfeats) + + # Cross Stage Feature Fusion (CSFF) + if csff: + self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) + self.csff_enc2 = nn.Conv2d( + n_feat + scale_unetfeats, + n_feat + scale_unetfeats, + kernel_size=1, + bias=bias, + ) + self.csff_enc3 = nn.Conv2d( + n_feat + (scale_unetfeats * 2), + n_feat + (scale_unetfeats * 2), + kernel_size=1, + bias=bias, + ) + + self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) + self.csff_dec2 = nn.Conv2d( + n_feat + scale_unetfeats, + n_feat + scale_unetfeats, + kernel_size=1, + bias=bias, + ) + self.csff_dec3 = nn.Conv2d( + n_feat + (scale_unetfeats * 2), + n_feat + (scale_unetfeats * 2), + kernel_size=1, + bias=bias, + ) + + def forward(self, x, encoder_outs=None, decoder_outs=None): + enc1 = self.encoder_level1(x) + if (encoder_outs is not None) and (decoder_outs is not None): + enc1 = ( + enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0]) + ) + + x = self.down12(enc1) + + enc2 = self.encoder_level2(x) + if (encoder_outs is not None) and (decoder_outs is not None): + enc2 = ( + enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1]) + ) + + x = self.down23(enc2) + + enc3 = self.encoder_level3(x) + if (encoder_outs is not None) and (decoder_outs is not None): + enc3 = ( + enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2]) + ) + + return [enc1, enc2, enc3] + + +class Decoder(nn.Module): + def __init__( + self, + n_feat: int, + kernel_size: int, + reduction: int, + act: nn.Module, + bias: bool, + scale_unetfeats: int, + ): + super().__init__() + + decoder_level1 = [ + CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2) + ] + decoder_level2 = [ + CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) + for _ in range(2) + ] + decoder_level3 = [ + CAB( + n_feat + (scale_unetfeats * 2), + kernel_size, + reduction, + bias=bias, + act=act, + ) + for _ in range(2) + ] + + self.decoder_level1 = nn.Sequential(*decoder_level1) + self.decoder_level2 = nn.Sequential(*decoder_level2) + self.decoder_level3 = nn.Sequential(*decoder_level3) + + self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act) + self.skip_attn2 = CAB( + n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act + ) + + self.up21 = SkipUpSample(n_feat, scale_unetfeats) + self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats) + + def forward(self, outs): + enc1, enc2, enc3 = outs + dec3 = self.decoder_level3(enc3) + + x = self.up32(dec3, self.skip_attn2(enc2)) + dec2 = self.decoder_level2(x) + + x = self.up21(dec2, self.skip_attn1(enc1)) + dec1 = self.decoder_level1(x) + + return [dec1, dec2, dec3] + + +########################################################################## +##---------- Resizing Modules ---------- +class DownSample(nn.Module): + def __init__(self, in_channels: int, s_factor: int): + super().__init__() + self.down = nn.Sequential( + nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=False), + nn.Conv2d( + in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False + ), + ) + + def forward(self, x): + x = self.down(x) + return x + + +class UpSample(nn.Module): + def __init__(self, in_channels: int, s_factor: int): + 
super().__init__() + self.up = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Conv2d( + in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False + ), + ) + + def forward(self, x): + x = self.up(x) + return x + + +class SkipUpSample(nn.Module): + def __init__(self, in_channels: int, s_factor: int): + super().__init__() + self.up = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Conv2d( + in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False + ), + ) + + def forward(self, x, y): + x = self.up(x) + x = x + y + return x + + +########################################################################## +## Original Resolution Block (ORB) +class ORB(nn.Module): + def __init__( + self, + n_feat: int, + kernel_size: int, + reduction: int, + act: nn.Module, + bias: bool, + num_cab: int, + ): + super().__init__() + modules_body: list[nn.Module] = [ + CAB(n_feat, kernel_size, reduction, bias=bias, act=act) + for _ in range(num_cab) + ] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +########################################################################## +class ORSNet(nn.Module): + def __init__( + self, + n_feat: int, + scale_orsnetfeats: int, + kernel_size: int, + reduction: int, + act: nn.Module, + bias: bool, + scale_unetfeats: int, + num_cab: int, + ): + super().__init__() + + self.orb1 = ORB( + n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab + ) + self.orb2 = ORB( + n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab + ) + self.orb3 = ORB( + n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab + ) + + self.up_enc1 = UpSample(n_feat, scale_unetfeats) + self.up_dec1 = UpSample(n_feat, scale_unetfeats) + + self.up_enc2 = nn.Sequential( + UpSample(n_feat + scale_unetfeats, scale_unetfeats), + UpSample(n_feat, scale_unetfeats), + ) + self.up_dec2 = nn.Sequential( + UpSample(n_feat + scale_unetfeats, scale_unetfeats), + UpSample(n_feat, scale_unetfeats), + ) + + self.conv_enc1 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + self.conv_enc2 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + self.conv_enc3 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + + self.conv_dec1 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + self.conv_dec2 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + self.conv_dec3 = nn.Conv2d( + n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias + ) + + def forward(self, x, encoder_outs, decoder_outs): + x = self.orb1(x) + x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0]) + + x = self.orb2(x) + x = ( + x + + self.conv_enc2(self.up_enc1(encoder_outs[1])) + + self.conv_dec2(self.up_dec1(decoder_outs[1])) + ) + + x = self.orb3(x) + x = ( + x + + self.conv_enc3(self.up_enc2(encoder_outs[2])) + + self.conv_dec3(self.up_dec2(decoder_outs[2])) + ) + + return x + + +########################################################################## +@store_hyperparameters() +class MPRNet(nn.Module): + hyperparameters = {} + + def __init__( + self, + in_c: int = 3, + out_c: int = 3, + n_feat: int = 40, + scale_unetfeats: int = 20, + scale_orsnetfeats: int = 16, + num_cab: int = 8, + kernel_size: int = 3, + 
reduction=4, + bias=False, + ): + super().__init__() + + act = nn.PReLU() + self.shallow_feat1 = nn.Sequential( + conv(in_c, n_feat, kernel_size, bias=bias), + CAB(n_feat, kernel_size, reduction, bias=bias, act=act), + ) + self.shallow_feat2 = nn.Sequential( + conv(in_c, n_feat, kernel_size, bias=bias), + CAB(n_feat, kernel_size, reduction, bias=bias, act=act), + ) + self.shallow_feat3 = nn.Sequential( + conv(in_c, n_feat, kernel_size, bias=bias), + CAB(n_feat, kernel_size, reduction, bias=bias, act=act), + ) + + # Cross Stage Feature Fusion (CSFF) + self.stage1_encoder = Encoder( + n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False + ) + self.stage1_decoder = Decoder( + n_feat, kernel_size, reduction, act, bias, scale_unetfeats + ) + + self.stage2_encoder = Encoder( + n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True + ) + self.stage2_decoder = Decoder( + n_feat, kernel_size, reduction, act, bias, scale_unetfeats + ) + + self.stage3_orsnet = ORSNet( + n_feat, + scale_orsnetfeats, + kernel_size, + reduction, + act, + bias, + scale_unetfeats, + num_cab, + ) + + self.sam12 = SAM(n_feat, kernel_size=1, bias=bias) + self.sam23 = SAM(n_feat, kernel_size=1, bias=bias) + + self.concat12 = conv(n_feat * 2, n_feat, kernel_size, bias=bias) + self.concat23 = conv( + n_feat * 2, n_feat + scale_orsnetfeats, kernel_size, bias=bias + ) + self.tail = conv(n_feat + scale_orsnetfeats, out_c, kernel_size, bias=bias) + + def forward(self, x3_img): + # Original-resolution Image for Stage 3 + H = x3_img.size(2) + W = x3_img.size(3) + + # Multi-Patch Hierarchy: Split Image into four non-overlapping patches + + # Two Patches for Stage 2 + x2top_img = x3_img[:, :, 0 : int(H / 2), :] + x2bot_img = x3_img[:, :, int(H / 2) : H, :] + + # Four Patches for Stage 1 + x1ltop_img = x2top_img[:, :, :, 0 : int(W / 2)] + x1rtop_img = x2top_img[:, :, :, int(W / 2) : W] + x1lbot_img = x2bot_img[:, :, :, 0 : int(W / 2)] + x1rbot_img = x2bot_img[:, :, :, int(W / 2) : W] + + ##------------------------------------------- + ##-------------- Stage 1--------------------- + ##------------------------------------------- + ## Compute Shallow Features + x1ltop = self.shallow_feat1(x1ltop_img) + x1rtop = self.shallow_feat1(x1rtop_img) + x1lbot = self.shallow_feat1(x1lbot_img) + x1rbot = self.shallow_feat1(x1rbot_img) + + ## Process features of all 4 patches with Encoder of Stage 1 + feat1_ltop = self.stage1_encoder(x1ltop) + feat1_rtop = self.stage1_encoder(x1rtop) + feat1_lbot = self.stage1_encoder(x1lbot) + feat1_rbot = self.stage1_encoder(x1rbot) + + ## Concat deep features + feat1_top = [torch.cat((k, v), 3) for k, v in zip(feat1_ltop, feat1_rtop)] + feat1_bot = [torch.cat((k, v), 3) for k, v in zip(feat1_lbot, feat1_rbot)] + + ## Pass features through Decoder of Stage 1 + res1_top = self.stage1_decoder(feat1_top) + res1_bot = self.stage1_decoder(feat1_bot) + + ## Apply Supervised Attention Module (SAM) + x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img) + x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img) + + ## Output image at Stage 1 + stage1_img = torch.cat([stage1_img_top, stage1_img_bot], 2) + ##------------------------------------------- + ##-------------- Stage 2--------------------- + ##------------------------------------------- + ## Compute Shallow Features + x2top = self.shallow_feat2(x2top_img) + x2bot = self.shallow_feat2(x2bot_img) + + ## Concatenate SAM features of Stage 1 with shallow features of Stage 2 + x2top_cat = 
self.concat12(torch.cat([x2top, x2top_samfeats], 1)) + x2bot_cat = self.concat12(torch.cat([x2bot, x2bot_samfeats], 1)) + + ## Process features of both patches with Encoder of Stage 2 + feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top) + feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot) + + ## Concat deep features + feat2 = [torch.cat((k, v), 2) for k, v in zip(feat2_top, feat2_bot)] + + ## Pass features through Decoder of Stage 2 + res2 = self.stage2_decoder(feat2) + + ## Apply SAM + x3_samfeats, stage2_img = self.sam23(res2[0], x3_img) + + ##------------------------------------------- + ##-------------- Stage 3--------------------- + ##------------------------------------------- + ## Compute Shallow Features + x3 = self.shallow_feat3(x3_img) + + ## Concatenate SAM features of Stage 2 with shallow features of Stage 3 + x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1)) + + x3_cat = self.stage3_orsnet(x3_cat, feat2, res2) + + stage3_img = self.tail(x3_cat) + + return stage3_img + x3_img, stage2_img, stage1_img diff --git a/tests/__snapshots__/test_MPRNet.ambr b/tests/__snapshots__/test_MPRNet.ambr new file mode 100644 index 00000000..90ca65a6 --- /dev/null +++ b/tests/__snapshots__/test_MPRNet.ambr @@ -0,0 +1,58 @@ +# serializer version: 1 +# name: test_deblurring + ImageModelDescriptor( + architecture=MPRNetArch( + id='MPRNet', + name='MPRNet', + ), + input_channels=3, + output_channels=3, + purpose='Restoration', + scale=1, + size_requirements=SizeRequirements(minimum=0, multiple_of=8, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + '96nf', + ]), + tiling=, + ) +# --- +# name: test_denoising + ImageModelDescriptor( + architecture=MPRNetArch( + id='MPRNet', + name='MPRNet', + ), + input_channels=3, + output_channels=3, + purpose='Restoration', + scale=1, + size_requirements=SizeRequirements(minimum=0, multiple_of=8, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + '80nf', + ]), + tiling=, + ) +# --- +# name: test_deraining + ImageModelDescriptor( + architecture=MPRNetArch( + id='MPRNet', + name='MPRNet', + ), + input_channels=3, + output_channels=3, + purpose='Restoration', + scale=1, + size_requirements=SizeRequirements(minimum=0, multiple_of=8, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + '40nf', + ]), + tiling=, + ) +# --- diff --git a/tests/images/outputs/blurry-face/MPRNet_model_deblurring.png b/tests/images/outputs/blurry-face/MPRNet_model_deblurring.png new file mode 100644 index 00000000..bafbea1b Binary files /dev/null and b/tests/images/outputs/blurry-face/MPRNet_model_deblurring.png differ diff --git a/tests/test_MPRNet.py b/tests/test_MPRNet.py new file mode 100644 index 00000000..358c16a6 --- /dev/null +++ b/tests/test_MPRNet.py @@ -0,0 +1,72 @@ +from spandrel_extra_arches.architectures.MPRNet import MPRNet, MPRNetArch + +from .util import ( + ModelFile, + TestImage, + assert_image_inference, + assert_loads_correctly, + assert_size_requirements, + disallowed_props, + skip_if_unchanged, +) + +skip_if_unchanged(__file__) + + +def test_load(): + assert_loads_correctly( + MPRNetArch(), + lambda: MPRNet(), + lambda: MPRNet(in_c=4, out_c=1), + lambda: MPRNet(n_feat=20), + lambda: MPRNet(kernel_size=5), + lambda: MPRNet(bias=True), + lambda: MPRNet(reduction=8), + lambda: MPRNet(scale_orsnetfeats=32), + lambda: MPRNet(scale_unetfeats=10), + lambda: MPRNet(num_cab=4), + check_safe_tensors=False, + ) + + +def test_size_requirements(): + file = 
ModelFile.from_url( + "https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view", + name="MPRNet_model_deblurring.pth", + ) + assert_size_requirements(file.load_model()) + + +def test_deblurring(snapshot): + file = ModelFile.from_url( + "https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view", + name="MPRNet_model_deblurring.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, MPRNet) + assert_image_inference( + file, + model, + [TestImage.BLURRY_FACE], + ) + + +def test_deraining(snapshot): + file = ModelFile.from_url( + "https://drive.google.com/file/d/1O3WEJbcat7eTY6doXWeorAbQ1l_WmMnM/view", + name="MPRNet_model_deraining.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, MPRNet) + + +def test_denoising(snapshot): + file = ModelFile.from_url( + "https://drive.google.com/file/d/1LODPt9kYmxwU98g96UrRA0_Eh5HYcsRw/view", + name="MPRNet_model_denoising.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, MPRNet)
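
For reviewers, a minimal end-to-end usage sketch. It assumes the registry API described in spandrel's README (`MAIN_REGISTRY`, `EXTRA_REGISTRY`, `ModelLoader`); the checkpoint filename is simply the name used in the tests above, and the input tensor is random:

```python
import torch
from spandrel import MAIN_REGISTRY, ModelLoader, ImageModelDescriptor
from spandrel_extra_arches import EXTRA_REGISTRY

# MPRNet lives in spandrel_extra_arches (non-commercial license), so the
# extra registry has to be added before the checkpoint is loaded.
MAIN_REGISTRY.add(*EXTRA_REGISTRY)

model = ModelLoader().load_from_file("MPRNet_model_deblurring.pth")
assert isinstance(model, ImageModelDescriptor)
assert model.architecture.id == "MPRNet"

model.eval()
with torch.no_grad():
    # H and W must be multiples of 8 (SizeRequirements(multiple_of=8));
    # call_fn keeps only the stage-3 output, so a single tensor comes back.
    restored = model(torch.rand(1, 3, 256, 256))
print(restored.shape)  # torch.Size([1, 3, 256, 256])
```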
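
The `load()` method above recovers every constructor argument from tensor shapes in the state dict. A small self-check of that arithmetic against a default-constructed model (all key names come from the architecture code in this diff):

```python
from spandrel_extra_arches.architectures.MPRNet import MPRNet

# Defaults: in_c=3, out_c=3, n_feat=40, scale_unetfeats=20,
# scale_orsnetfeats=16, num_cab=8, kernel_size=3, reduction=4, bias=False.
sd = MPRNet().state_dict()

# shallow_feat1.0 = conv(in_c, n_feat, kernel_size) -> (n_feat, in_c, k, k)
assert tuple(sd["shallow_feat1.0.weight"].shape) == (40, 3, 3, 3)

# CALayer's first 1x1 conv maps n_feat -> n_feat // reduction,
# so reduction = 40 // 10 = 4.
assert tuple(sd["shallow_feat1.1.CA.conv_du.0.weight"].shape) == (10, 40, 1, 1)

# Level-2 encoder CABs run on n_feat + scale_unetfeats channels,
# so scale_unetfeats = 60 - 40 = 20.
assert sd["stage1_encoder.encoder_level2.0.CA.conv_du.0.weight"].shape[1] == 60

# tail = conv(n_feat + scale_orsnetfeats, out_c),
# so scale_orsnetfeats = 56 - 40 = 16.
assert tuple(sd["tail.weight"].shape) == (3, 56, 3, 3)

# Each ORB body holds num_cab CABs plus one trailing conv, which is why
# load() subtracts 1 from get_seq_len(state_dict, "stage3_orsnet.orb1.body").
```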
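
The `multiple_of=8` size requirement is presumably because stage 1 splits the input into four (H/2, W/2) patches and each U-Net branch then halves the resolution twice more, so H and W need to be divisible by 8 for the skip connections to line up. A quick sanity check of the raw module's three-tuple output (the descriptor's `call_fn` keeps only the first, stage-3 image):

```python
import torch
from spandrel_extra_arches.architectures.MPRNet import MPRNet

net = MPRNet().eval()
with torch.no_grad():
    stage3, stage2, stage1 = net(torch.rand(1, 3, 64, 64))

# All three stage outputs are full-resolution restored images.
print(stage3.shape, stage2.shape, stage1.shape)
# torch.Size([1, 3, 64, 64]) for each
```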