diff --git a/README.md b/README.md index 8d98d73..8cacdb3 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar - [DRCT](https://github.com/ming053l/DRCT) - [PLKSR](https://github.com/dslisleedh/PLKSR) and [RealPLKSR](https://github.com/muslll/neosr/blob/master/neosr/archs/realplksr_arch.py) | [Models](https://drive.google.com/drive/u/1/folders/1lIkZ00y9cRQpLU9qmCIB2XtS-2ZoqKq8) - [SeemoRe](https://github.com/eduardzamfir/seemoredetails) | [Models](https://drive.google.com/drive/folders/15jtvcS4jL_6QqEwaRodEN8FBrqVPrO2u?usp=share_link) +- [MoSR](https://github.com/umzi2/MoSR) | [Models](https://drive.google.com/drive/u/0/folders/1HPy7M4Zzq8oxhdsQ2cnfqy73klmQWp_r) #### Face Restoration diff --git a/libs/spandrel/spandrel/__helpers/main_registry.py b/libs/spandrel/spandrel/__helpers/main_registry.py index 190562d..b713469 100644 --- a/libs/spandrel/spandrel/__helpers/main_registry.py +++ b/libs/spandrel/spandrel/__helpers/main_registry.py @@ -27,6 +27,7 @@ LaMa, MixDehazeNet, MMRealSR, + MoSR, NAFNet, OmniSR, RealCUGAN, @@ -86,4 +87,5 @@ ArchSupport.from_architecture(RetinexFormer.RetinexFormerArch()), ArchSupport.from_architecture(HVICIDNet.HVICIDNetArch()), ArchSupport.from_architecture(SeemoRe.SeemoReArch()), + ArchSupport.from_architecture(MoSR.MoSRArch()), ) diff --git a/libs/spandrel/spandrel/architectures/MoSR/__arch/LICENSE b/libs/spandrel/spandrel/architectures/MoSR/__arch/LICENSE new file mode 100644 index 0000000..f4af5c8 --- /dev/null +++ b/libs/spandrel/spandrel/architectures/MoSR/__arch/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 umzi + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of 
import torch
from torch import nn
from torch.nn.init import trunc_normal_

from spandrel.architectures.__arch_helpers.dysample import DySample
from spandrel.util import store_hyperparameters
from spandrel.util.timm.__drop import DropPath


class GPS(nn.Module):
    """Geo ensemble PixelShuffle.

    A single convolution produces ``scale * scale * out_ch * 8`` channels.
    These are viewed as 8 groups, averaged over the group axis, and the
    resulting ``scale * scale * out_ch`` channels are rearranged by
    ``nn.PixelShuffle(scale)``.

    Args:
        dim: Number of input feature channels.
        scale: Upscaling factor passed to ``nn.PixelShuffle``.
        out_ch: Number of output image channels.
        kernel_size: Kernel size of the projection convolution; padding is
            ``kernel_size // 2``.
    """

    def __init__(
        self,
        dim: int,
        scale: int,
        out_ch: int = 3,
        # Own parameters
        kernel_size: int = 3,
    ):
        super().__init__()
        # NOTE: the attribute names below are state-dict keys; do not rename.
        self.in_to_k = nn.Conv2d(
            dim, scale * scale * out_ch * 8, kernel_size, 1, kernel_size // 2
        )
        self.ps = nn.PixelShuffle(scale)

    def forward(self, x):
        """Project, average the 8 channel groups, then pixel-shuffle."""
        return self.ps(self._geo_ensemble(x))

    def _geo_ensemble(self, x):
        """Return the mean over the 8 channel groups of ``in_to_k(x)``."""
        x = self.in_to_k(x)
        # (B, 8 * C, H, W) -> (B, 8, C, H, W); assumes a batched 4-D input.
        x = x.reshape(x.shape[0], 8, -1, x.shape[-2], x.shape[-1])
        return x.mean(dim=1)


class LayerNorm(nn.Module):
    """Normalization over dimension 1 with a learnable per-channel affine.

    The mean and the (biased) variance are computed over dimension 1 and the
    ``weight`` / ``bias`` vectors of length ``dim`` are broadcast as
    ``[:, None, None]``, i.e. the input is expected to have its channel axis
    at position 1 followed by two trailing axes.

    Args:
        dim: Number of channels (length of ``weight`` and ``bias``).
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


class ConvBlock(nn.Module):
    r"""Two 3x3 conv + Mish layers with a parallel 1x1 convolution, summed.

    https://github.com/joshyZhou/AST/blob/main/model.py#L22

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels of both branches.
        strides: Stride applied to every convolution in the block.
    """

    def __init__(self, in_channel: int, out_channel: int, strides: int = 1):
        super().__init__()
        self.strides = strides
        self.in_channel = in_channel
        self.out_channel = out_channel
        # NOTE(review): ``strides`` is applied to both 3x3 convolutions but
        # only once on the 1x1 branch, so for ``strides != 1`` the two branch
        # outputs would differ in spatial size. Within this file the block is
        # only instantiated with the default ``strides=1``.
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channel, out_channel, kernel_size=3, stride=strides, padding=1
            ),
            nn.Mish(),
            nn.Conv2d(
                out_channel, out_channel, kernel_size=3, stride=strides, padding=1
            ),
            nn.Mish(),
        )
        self.conv11 = nn.Conv2d(
            in_channel, out_channel, kernel_size=1, stride=strides, padding=0
        )

    def forward(self, x):
        """Return ``block(x) + conv11(x)``."""
        return self.block(x) + self.conv11(x)


class GatedCNNBlock(nn.Module):
    r"""
    modernized mambaout main unit
    https://github.com/yuweihao/MambaOut/blob/main/models/mambaout.py#L119

    Args:
        dim: Number of input and output channels.
        expansion_ratio: ``hidden = int(expansion_ratio * dim)`` channels are
            used for the gate and for the value branch.
        conv_ratio: ``int(conv_ratio * dim)`` value channels go through the
            depth-wise convolution.
        kernel_size: Kernel size of the depth-wise convolution.
        drop_path: Probability handed to ``DropPath``; ``nn.Identity`` is
            used when it is not greater than ``0.0``.
    """

    def __init__(
        self,
        dim: int,
        expansion_ratio: float = 8 / 3,
        conv_ratio: float = 1.0,
        kernel_size: int = 7,
        drop_path: float = 0.5,
    ):
        super().__init__()
        self.norm = LayerNorm(dim)
        hidden = int(expansion_ratio * dim)
        self.fc1 = nn.Conv2d(dim, hidden * 2, 3, 1, 1)

        self.act = nn.Mish()
        conv_channels = int(conv_ratio * dim)
        # fc1 output is split into: gate (hidden), pass-through part of the
        # value (hidden - conv_channels) and convolved part (conv_channels).
        # NOTE(review): ``conv_channels > hidden`` yields a negative split
        # size, which only fails once ``forward`` runs. It is deliberately
        # not rejected here so that constructing such a model keeps working.
        self.split_indices = [hidden, hidden - conv_channels, conv_channels]

        self.conv = nn.Conv2d(
            conv_channels,
            conv_channels,
            kernel_size,
            1,
            kernel_size // 2,
            groups=conv_channels,
        )
        self.fc2 = nn.Conv2d(hidden, dim, 3, 1, 1)
        # ``self.training`` is always True right after ``nn.Module.__init__``,
        # so the former ``or not self.training`` clause could never select
        # DropPath; the decision depends on ``drop_path`` alone.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        """Truncated-normal weights (std 0.02) and zero bias for conv/linear."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        shortcut = x
        x = self.norm(x)
        g, i, c = torch.split(self.fc1(x), self.split_indices, dim=1)
        c = self.conv(c)
        x = self.act(self.fc2(self.act(g) * torch.cat((i, c), dim=1)))
        x = self.drop_path(x)
        # NOTE(review): the constant 0.5 offset on the residual is part of the
        # architecture as written; presumably trained weights rely on it.
        return x + (shortcut - 0.5)


@store_hyperparameters()
class MoSR(nn.Module):
    """Mamba Out Super-Resolution.

    The network is ``gblocks`` (input conv, ``n_block`` gated blocks and a
    small conv tail) plus a ``ConvBlock`` shortcut computed from the input,
    followed by one of three upsamplers.

    Args:
        in_ch: Number of input image channels.
        out_ch: Number of output image channels. Overridden with ``in_ch``
            for the ``"ps"`` and ``"gps"`` upsamplers.
        upscale: Upscaling factor.
        n_block: Number of ``GatedCNNBlock`` modules.
        dim: Feature width.
        upsampler: ``"ps"`` (conv + PixelShuffle), ``"gps"`` (``GPS``) or
            ``"dys"`` (``DySample``).
        drop_path: Upper end of the per-block drop-path rate, which is spread
            with ``torch.linspace(0, drop_path, n_block)``.
        kernel_size: Kernel size of each block's depth-wise convolution.
        expansion_ratio: See ``GatedCNNBlock``.
        conv_ratio: See ``GatedCNNBlock``.

    Raises:
        ValueError: If ``upsampler`` is not one of the supported names.
    """

    hyperparameters = {}

    def __init__(
        self,
        *,
        in_ch: int = 3,
        out_ch: int = 3,
        upscale: int = 4,
        n_block: int = 24,
        dim: int = 64,
        upsampler: str = "ps",  # "ps" "dys" "gps"
        drop_path: float = 0.0,
        kernel_size: int = 7,
        expansion_ratio: float = 1.5,
        conv_ratio: float = 1.0,
    ):
        super().__init__()
        # Fail fast, before any layer is allocated. The message is assembled
        # from adjacent literals so that no source indentation leaks into it.
        if upsampler not in ("ps", "gps", "dys"):
            raise ValueError(
                f"upsampler: {upsampler} not supported, "
                'choose one of these options: ["ps", "gps", "dys"]'
            )

        if upsampler in ["ps", "gps"]:
            out_ch = in_ch
        dp_rates = [x.item() for x in torch.linspace(0, drop_path, n_block)]

        # Sequential indices are state-dict keys: 0 is the input conv,
        # 1..n_block are the gated blocks, the last five entries are the tail.
        head = [nn.Conv2d(in_ch, dim, 3, 1, 1)]
        body = [
            GatedCNNBlock(
                dim=dim,
                expansion_ratio=expansion_ratio,
                kernel_size=kernel_size,
                conv_ratio=conv_ratio,
                drop_path=dp_rates[index],
            )
            for index in range(n_block)
        ]
        tail = [
            nn.Conv2d(dim, dim * 2, 3, 1, 1),
            nn.Mish(),
            nn.Conv2d(dim * 2, dim, 3, 1, 1),
            nn.Mish(),
            nn.Conv2d(dim, dim, 1, 1),
        ]
        self.gblocks = nn.Sequential(*head, *body, *tail)

        self.shortcut = ConvBlock(in_ch, dim)

        if upsampler == "ps":
            self.upsampler = nn.Sequential(
                nn.Conv2d(dim, out_ch * (upscale**2), 3, 1, 1), nn.PixelShuffle(upscale)
            )
        elif upsampler == "gps":
            self.upsampler = GPS(dim, upscale, out_ch)
        else:  # "dys"; any other value was rejected above
            self.upsampler = DySample(dim, out_ch, upscale)

    def forward(self, x):
        # Both branches consume the original input; see the NOTE in
        # GatedCNNBlock.forward about the 0.5 offset.
        x = self.gblocks(x) + (self.shortcut(x) - 0.5)
        return self.upsampler(x)
import math

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len

from ...__helpers.model_descriptor import Architecture, ImageModelDescriptor, StateDict
from .__arch.mosr_arch import MoSR


def _ratio_for_channels(channels: int, dim: int) -> float:
    """Return a ratio ``r`` for which ``int(r * dim) == channels``.

    The model derives its channel counts as ``int(ratio * dim)``. The plain
    quotient ``channels / dim`` is used whenever it survives that round trip.
    Floating-point rounding can make ``(channels / dim) * dim`` land just
    below ``channels`` (for example ``(1 / 49) * 49 < 1``), which would
    truncate to ``channels - 1``. In that case a ratio aimed at the middle of
    the target integer interval is returned instead.
    """
    ratio = channels / dim
    if int(ratio * dim) != channels:
        ratio = (channels + 0.5) / dim
    return ratio


class MoSRArch(Architecture[MoSR]):
    """Architecture descriptor that detects and loads MoSR state dicts."""

    def __init__(self) -> None:
        super().__init__(
            id="MoSR",
            # Keys of the input convolution and of the first gated block.
            detect=KeyCondition.has_all(
                "gblocks.0.weight",
                "gblocks.0.bias",
                "gblocks.1.norm.weight",
                "gblocks.1.norm.bias",
                "gblocks.1.fc1.weight",
                "gblocks.1.fc1.bias",
                "gblocks.1.conv.weight",
                "gblocks.1.conv.bias",
                "gblocks.1.fc2.weight",
                "gblocks.1.fc2.bias",
            ),
        )

    @override
    def load(self, state_dict: StateDict) -> ImageModelDescriptor[MoSR]:
        """Infer the MoSR hyperparameters from ``state_dict`` and build the model.

        ``drop_path`` cannot be read from the weights and keeps its default.
        """
        # default values
        in_ch = 3
        out_ch = 3
        upscale = 4
        n_block = 24
        dim = 64
        upsampler = "ps"  # "ps" "dys", "gps"
        drop_path = 0.0
        kernel_size = 7
        expansion_ratio = 1.5
        conv_ratio = 1.0

        # gblocks holds 1 input conv, n_block blocks and a 5-module tail.
        n_block = get_seq_len(state_dict, "gblocks") - 6
        in_ch = state_dict["gblocks.0.weight"].shape[1]
        dim = state_dict["gblocks.0.weight"].shape[0]

        # Calculate expansion ratio and convolution ratio.
        # fc1 outputs ``hidden * 2`` channels; the depth-wise conv has
        # ``conv_channels`` output channels.
        hidden = state_dict["gblocks.1.fc1.weight"].shape[0] // 2
        conv_channels = state_dict["gblocks.1.conv.weight"].shape[0]
        expansion_ratio = _ratio_for_channels(hidden, dim)
        conv_ratio = _ratio_for_channels(conv_channels, dim)
        kernel_size = state_dict["gblocks.1.conv.weight"].shape[2]

        # Determine upsampler type and calculate upscale
        if "upsampler.init_pos" in state_dict:
            upsampler = "dys"
            out_ch = state_dict["upsampler.end_conv.weight"].shape[0]
            upscale = math.isqrt(state_dict["upsampler.offset.weight"].shape[0] // 8)
        elif "upsampler.in_to_k.weight" in state_dict:
            upsampler = "gps"
            out_ch = in_ch
            # in_to_k has ``scale * scale * out_ch * 8`` output channels.
            upscale = math.isqrt(
                state_dict["upsampler.in_to_k.weight"].shape[0] // 8 // out_ch
            )
        else:
            upsampler = "ps"
            out_ch = in_ch
            # The conv before PixelShuffle has ``out_ch * upscale**2`` channels.
            upscale = math.isqrt(state_dict["upsampler.0.weight"].shape[0] // out_ch)

        model = MoSR(
            in_ch=in_ch,
            out_ch=out_ch,
            upscale=upscale,
            n_block=n_block,
            dim=dim,
            upsampler=upsampler,
            drop_path=drop_path,
            kernel_size=kernel_size,
            expansion_ratio=expansion_ratio,
            conv_ratio=conv_ratio,
        )

        return ImageModelDescriptor(
            model,
            state_dict,
            architecture=self,
            purpose="SR",
            tags=[],
            supports_half=True,
            supports_bfloat16=True,
            scale=upscale,
            input_channels=in_ch,
            output_channels=out_ch,
        )


__all__ = ["MoSRArch", "MoSR"]
from spandrel.architectures.MoSR import MoSR, MoSRArch

from .util import (
    ModelFile,
    TestImage,
    assert_image_inference,
    assert_loads_correctly,
    assert_size_requirements,
    disallowed_props,
)

# Download locations of the pretrained checkpoints used below.
_URL_2X_NOMOS2 = "https://drive.google.com/file/d/1MkVu7lIAyrGc1Rb7ediDKoNDcBc6PAdF/view?usp=drive_link"
_URL_4X_NOMOS2 = "https://drive.google.com/file/d/1zlpBQu74sguLCjvPpAL4p9t8QqcRktuL/view?usp=drive_link"
_URL_4X_NOMOS2_T = "https://drive.google.com/file/d/1_C9f6yHS-XZu0bSz3x9kvHy0M7gWHUWB/view?usp=drive_link"

# Input sizes every inference test runs on.
_SR_IMAGES = [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64]


def _assert_snapshot_and_inference(url: str, name: str, snapshot) -> None:
    """Load the checkpoint, compare it to the snapshot and run inference."""
    model_file = ModelFile.from_url(url, name=name)
    descriptor = model_file.load_model()
    assert descriptor == snapshot(exclude=disallowed_props)
    assert isinstance(descriptor.model, MoSR)
    assert_image_inference(model_file, descriptor, _SR_IMAGES)


def test_load():
    """Every listed configuration must be detected and rebuilt correctly."""
    assert_loads_correctly(
        MoSRArch(),
        lambda: MoSR(),
        lambda: MoSR(in_ch=1, out_ch=1),
        lambda: MoSR(n_block=5),
        lambda: MoSR(dim=48),
        lambda: MoSR(upscale=2),
        lambda: MoSR(upsampler="dys"),
        lambda: MoSR(upsampler="gps"),
        lambda: MoSR(kernel_size=7),
        lambda: MoSR(expansion_ratio=2.0),
        lambda: MoSR(conv_ratio=2.0),
    )


def test_size_requirements():
    """Check the size requirements of both 4x checkpoints, in order."""
    checkpoints = (
        (_URL_4X_NOMOS2, "4x_nomos2_mosr.pth"),
        (_URL_4X_NOMOS2_T, "4x_nomos2_mosr_t.pth"),
    )
    for url, name in checkpoints:
        model_file = ModelFile.from_url(url, name=name)
        assert_size_requirements(model_file.load_model())


def test_2x_nomos2_mosr(snapshot):
    """Snapshot and inference check for the 2x checkpoint."""
    _assert_snapshot_and_inference(_URL_2X_NOMOS2, "2x_nomos2_mosr.pth", snapshot)


def test_4x_nomos2_mosr(snapshot):
    """Snapshot and inference check for the 4x checkpoint."""
    _assert_snapshot_and_inference(_URL_4X_NOMOS2, "4x_nomos2_mosr.pth", snapshot)