Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MoSR #307

Merged
merged 3 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar
- [DRCT](https://github.com/ming053l/DRCT)
- [PLKSR](https://github.com/dslisleedh/PLKSR) and [RealPLKSR](https://github.com/muslll/neosr/blob/master/neosr/archs/realplksr_arch.py) | [Models](https://drive.google.com/drive/u/1/folders/1lIkZ00y9cRQpLU9qmCIB2XtS-2ZoqKq8)
- [SeemoRe](https://github.com/eduardzamfir/seemoredetails) | [Models](https://drive.google.com/drive/folders/15jtvcS4jL_6QqEwaRodEN8FBrqVPrO2u?usp=share_link)
- [MoSR](https://github.com/umzi2/MoSR) | [Models](https://drive.google.com/drive/u/0/folders/1HPy7M4Zzq8oxhdsQ2cnfqy73klmQWp_r)

#### Face Restoration

Expand Down
2 changes: 2 additions & 0 deletions libs/spandrel/spandrel/__helpers/main_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
LaMa,
MixDehazeNet,
MMRealSR,
MoSR,
NAFNet,
OmniSR,
RealCUGAN,
Expand Down Expand Up @@ -86,4 +87,5 @@
ArchSupport.from_architecture(RetinexFormer.RetinexFormerArch()),
ArchSupport.from_architecture(HVICIDNet.HVICIDNetArch()),
ArchSupport.from_architecture(SeemoRe.SeemoReArch()),
ArchSupport.from_architecture(MoSR.MoSRArch()),
)
21 changes: 21 additions & 0 deletions libs/spandrel/spandrel/architectures/MoSR/__arch/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 umzi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Empty file.
201 changes: 201 additions & 0 deletions libs/spandrel/spandrel/architectures/MoSR/__arch/mosr_arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import torch
from torch import nn
from torch.nn.init import trunc_normal_

from spandrel.architectures.__arch_helpers.dysample import DySample
from spandrel.util import store_hyperparameters
from spandrel.util.timm.__drop import DropPath


class GPS(nn.Module):
"""Geo ensemble PixelShuffle"""

def __init__(
self,
dim,
scale,
out_ch=3,
# Own parameters
kernel_size: int = 3,
):
super().__init__()
self.in_to_k = nn.Conv2d(
dim, scale * scale * out_ch * 8, kernel_size, 1, kernel_size // 2
)
self.ps = nn.PixelShuffle(scale)

def forward(self, x):
rgb = self._geo_ensemble(x)
rgb = self.ps(rgb)
return rgb

def _geo_ensemble(self, x):
x = self.in_to_k(x)
x = x.reshape(x.shape[0], 8, -1, x.shape[-2], x.shape[-1])
x = x.mean(dim=1)
return x


class LayerNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim))
self.eps = eps

def forward(self, x):
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
return self.weight[:, None, None] * x + self.bias[:, None, None]


class ConvBlock(nn.Module):
r"""https://github.com/joshyZhou/AST/blob/main/model.py#L22"""

def __init__(self, in_channel: int, out_channel: int, strides: int = 1):
super().__init__()
self.strides = strides
self.in_channel = in_channel
self.out_channel = out_channel
self.block = nn.Sequential(
nn.Conv2d(
in_channel, out_channel, kernel_size=3, stride=strides, padding=1
),
nn.Mish(),
nn.Conv2d(
out_channel, out_channel, kernel_size=3, stride=strides, padding=1
),
nn.Mish(),
)
self.conv11 = nn.Conv2d(
in_channel, out_channel, kernel_size=1, stride=strides, padding=0
)

def forward(self, x):
out1 = self.block(x)
out2 = self.conv11(x)
out = out1 + out2
return out


class GatedCNNBlock(nn.Module):
r"""
modernized mambaout main unit
https://github.com/yuweihao/MambaOut/blob/main/models/mambaout.py#L119
"""

def __init__(
self,
dim: int,
expansion_ratio: float = 8 / 3,
conv_ratio: float = 1.0,
kernel_size: int = 7,
drop_path: float = 0.5,
):
super().__init__()
self.norm = LayerNorm(dim)
hidden = int(expansion_ratio * dim)
self.fc1 = nn.Conv2d(dim, hidden * 2, 3, 1, 1)

self.act = nn.Mish()
conv_channels = int(conv_ratio * dim)
self.split_indices = [hidden, hidden - conv_channels, conv_channels]

self.conv = nn.Conv2d(
conv_channels,
conv_channels,
kernel_size,
1,
kernel_size // 2,
groups=conv_channels,
)
self.fc2 = nn.Conv2d(hidden, dim, 3, 1, 1)
self.drop_path = (
DropPath(drop_path)
if drop_path > 0.0 or not self.training
else nn.Identity()
)
self.apply(self._init_weights)

@staticmethod
def _init_weights(m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)

def forward(self, x):
shortcut = x
x = self.norm(x)
g, i, c = torch.split(self.fc1(x), self.split_indices, dim=1)
c = self.conv(c)
x = self.act(self.fc2(self.act(g) * torch.cat((i, c), dim=1)))
x = self.drop_path(x)
return x + (shortcut - 0.5)


@store_hyperparameters()
class MoSR(nn.Module):
"""Mamba Out Super-Resolution"""

hyperparameters = {}

def __init__(
self,
*,
in_ch: int = 3,
out_ch: int = 3,
upscale: int = 4,
n_block: int = 24,
dim: int = 64,
upsampler: str = "ps", # "ps" "dys" "gps"
drop_path: float = 0.0,
kernel_size: int = 7,
expansion_ratio: float = 1.5,
conv_ratio: float = 1.0,
):
super().__init__()
if upsampler in ["ps", "gps"]:
out_ch = in_ch
dp_rates = [x.item() for x in torch.linspace(0, drop_path, n_block)]
self.gblocks = nn.Sequential(
*[nn.Conv2d(in_ch, dim, 3, 1, 1)]
+ [
GatedCNNBlock(
dim=dim,
expansion_ratio=expansion_ratio,
kernel_size=kernel_size,
conv_ratio=conv_ratio,
drop_path=dp_rates[index],
)
for index in range(n_block)
]
+ [
nn.Conv2d(dim, dim * 2, 3, 1, 1),
nn.Mish(),
nn.Conv2d(dim * 2, dim, 3, 1, 1),
nn.Mish(),
nn.Conv2d(dim, dim, 1, 1),
]
)

self.shortcut = ConvBlock(in_ch, dim)

if upsampler == "ps":
self.upsampler = nn.Sequential(
nn.Conv2d(dim, out_ch * (upscale**2), 3, 1, 1), nn.PixelShuffle(upscale)
)
elif upsampler == "gps":
self.upsampler = GPS(dim, upscale, out_ch)
elif upsampler == "dys":
self.upsampler = DySample(dim, out_ch, upscale)
else:
raise ValueError(
f'upsampler: {upsampler} not supported, choose one of these options: \
["ps", "gps", "dys"]'
)

def forward(self, x):
x = self.gblocks(x) + (self.shortcut(x) - 0.5)
return self.upsampler(x)
97 changes: 97 additions & 0 deletions libs/spandrel/spandrel/architectures/MoSR/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import math

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len

from ...__helpers.model_descriptor import Architecture, ImageModelDescriptor, StateDict
from .__arch.mosr_arch import MoSR


class MoSRArch(Architecture[MoSR]):
def __init__(self) -> None:
super().__init__(
id="MoSR",
detect=KeyCondition.has_all(
"gblocks.0.weight",
"gblocks.0.bias",
"gblocks.1.norm.weight",
"gblocks.1.norm.bias",
"gblocks.1.fc1.weight",
"gblocks.1.fc1.bias",
"gblocks.1.conv.weight",
"gblocks.1.conv.bias",
"gblocks.1.fc2.weight",
"gblocks.1.fc2.bias",
),
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[MoSR]:
# default values
in_ch = 3
out_ch = 3
upscale = 4
n_block = 24
dim = 64
upsampler = "ps" # "ps" "dys", "gps"
drop_path = 0.0
kernel_size = 7
expansion_ratio = 1.5
conv_ratio = 1.0

n_block = get_seq_len(state_dict, "gblocks") - 6
in_ch = state_dict["gblocks.0.weight"].shape[1]
dim = state_dict["gblocks.0.weight"].shape[0]

# Calculate expansion ratio and convolution ratio
expansion_ratio = (
state_dict["gblocks.1.fc1.weight"].shape[0]
/ state_dict["gblocks.1.fc1.weight"].shape[1]
) / 2
conv_ratio = state_dict["gblocks.1.conv.weight"].shape[0] / dim
kernel_size = state_dict["gblocks.1.conv.weight"].shape[2]
# Determine upsampler type and calculate upscale
if "upsampler.init_pos" in state_dict:
upsampler = "dys"
out_ch = state_dict["upsampler.end_conv.weight"].shape[0]
upscale = math.isqrt(state_dict["upsampler.offset.weight"].shape[0] // 8)
elif "upsampler.in_to_k.weight" in state_dict:
upsampler = "gps"
out_ch = in_ch
upscale = math.isqrt(
state_dict["upsampler.in_to_k.weight"].shape[0] // 8 // out_ch
)
else:
upsampler = "ps"
out_ch = in_ch
upscale = math.isqrt(state_dict["upsampler.0.weight"].shape[0] // out_ch)

model = MoSR(
in_ch=in_ch,
out_ch=out_ch,
upscale=upscale,
n_block=n_block,
dim=dim,
upsampler=upsampler,
drop_path=drop_path,
kernel_size=kernel_size,
expansion_ratio=expansion_ratio,
conv_ratio=conv_ratio,
)

return ImageModelDescriptor(
model,
state_dict,
architecture=self,
purpose="SR",
tags=[],
supports_half=True,
supports_bfloat16=True,
scale=upscale,
input_channels=in_ch,
output_channels=out_ch,
)


__all__ = ["MoSRArch", "MoSR"]
37 changes: 37 additions & 0 deletions tests/__snapshots__/test_MoSR.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# serializer version: 1
# name: test_2x_nomos2_mosr
ImageModelDescriptor(
architecture=MoSRArch(
id='MoSR',
name='MoSR',
),
input_channels=3,
output_channels=3,
purpose='SR',
scale=2,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
]),
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
# name: test_4x_nomos2_mosr
ImageModelDescriptor(
architecture=MoSRArch(
id='MoSR',
name='MoSR',
),
input_channels=3,
output_channels=3,
purpose='SR',
scale=4,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
]),
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
Binary file added tests/images/outputs/16x16/2x_nomos2_mosr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/images/outputs/16x16/4x_nomos2_mosr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/images/outputs/32x32/2x_nomos2_mosr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/images/outputs/32x32/4x_nomos2_mosr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/images/outputs/64x64/2x_nomos2_mosr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/images/outputs/64x64/4x_nomos2_mosr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading