Skip to content

Commit

Permalink
Add support for DCTLSA (#147)
Browse files Browse the repository at this point in the history
* Add support for DCTLSA

* New arch api

* finishing touches

* readme
  • Loading branch information
RunDevelopment authored Feb 23, 2024
1 parent e03da55 commit 211e0cd
Show file tree
Hide file tree
Showing 13 changed files with 721 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar
- [CRAFT](https://github.com/AVC2-UESTC/CRAFT-SR) | [Models](https://drive.google.com/file/d/13wAmc93BPeBUBQ24zUZOuUpdBFG2aAY5/view?usp=sharing)
- [SAFMN](https://github.com/sunny2109/SAFMN) | [Models](https://drive.google.com/drive/folders/12O_xgwfgc76DsYbiClYnl6ErCDrsi_S9?usp=share_link)
- [RGT](https://github.com/zhengchen1999/RGT) | [RGT Models](https://drive.google.com/drive/folders/1zxrr31Kp2D_N9a-OUAPaJEn_yTaSXTfZ?usp=drive_link), [RGT-S Models](https://drive.google.com/drive/folders/1j46WHs1Gvyif1SsZXKy1Y1IrQH0gfIQ1?usp=drive_link)
- [DCTLSA](https://github.com/zengkun301/DCTLSA) | [Models](https://github.com/zengkun301/DCTLSA/tree/main/pretrained)

#### Face Restoration

Expand Down
2 changes: 2 additions & 0 deletions src/spandrel/__helpers/main_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..architectures import (
CRAFT,
DAT,
DCTLSA,
DITN,
ESRGAN,
FBCNN,
Expand Down Expand Up @@ -73,6 +74,7 @@
ArchSupport.from_architecture(RealCUGAN.RealCUGANArch()),
ArchSupport.from_architecture(DDColor.DDColorArch()),
ArchSupport.from_architecture(SAFMN.SAFMNArch()),
ArchSupport.from_architecture(DCTLSA.DCTLSAArch()),
ArchSupport.from_architecture(FFTformer.FFTformerArch()),
ArchSupport.from_architecture(NAFNet.NAFNetArch()),
ArchSupport.from_architecture(M3SNet.M3SNetArch()),
Expand Down
84 changes: 84 additions & 0 deletions src/spandrel/architectures/DCTLSA/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing_extensions import override

from spandrel.util import KeyCondition, get_scale_and_output_channels

from ...__helpers.model_descriptor import (
Architecture,
ImageModelDescriptor,
SizeRequirements,
StateDict,
)
from .arch.dctlsa import DCTLSA


class DCTLSAArch(Architecture[DCTLSA]):
def __init__(self) -> None:
super().__init__(
id="DCTLSA",
detect=KeyCondition.has_all(
"fea_conv.weight",
"B1.body.0.transformer_body.0.blocks.0.attn.qkv.weight",
"B1.body.0.transformer_body.0.blocks.0.attn.local.pointwise_prenorm_1.weight",
"B1.body.1.transformer_body.0.blocks.0.attn.qkv.weight",
"B1.body.1.transformer_body.0.blocks.0.attn.local.pointwise_prenorm_1.weight",
"B6.body.0.transformer_body.0.blocks.0.attn.qkv.weight",
"B6.body.0.transformer_body.0.blocks.0.attn.local.pointwise_prenorm_1.weight",
"B6.body.1.transformer_body.0.blocks.0.attn.qkv.weight",
"B6.body.1.transformer_body.0.blocks.0.attn.local.pointwise_prenorm_1.weight",
"c.0.weight",
"c1.0.weight",
"c2.0.weight",
"c3.0.weight",
"c4.0.weight",
"c5.0.weight",
"LR_conv.weight",
"upsampler.0.weight",
),
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[DCTLSA]:
# defaults
in_nc = 3
nf = 55
num_modules = 6
out_nc = 3
upscale = 4
num_head = 5 # cannot be deduced from state dict

in_nc = state_dict["fea_conv.weight"].shape[1]
nf = state_dict["fea_conv.weight"].shape[0]
num_modules = state_dict["c.0.weight"].shape[1] // nf

# good old pixelshuffle
x = state_dict["upsampler.0.weight"].shape[0]
upscale, out_nc = get_scale_and_output_channels(x, in_nc)

model = DCTLSA(
in_nc=in_nc,
nf=nf,
num_modules=num_modules,
out_nc=out_nc,
upscale=upscale,
num_head=num_head,
)

tags = [
f"{nf}nf",
f"{num_modules}nm",
f"{num_head}nh",
]

return ImageModelDescriptor(
model,
state_dict,
architecture=self,
purpose="Restoration" if upscale == 1 else "SR",
tags=tags,
supports_half=False, # TODO: test
supports_bfloat16=True,
scale=upscale,
input_channels=in_nc,
output_channels=out_nc,
size_requirements=SizeRequirements(minimum=16),
)
1 change: 1 addition & 0 deletions src/spandrel/architectures/DCTLSA/arch/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This project is released under the Apache 2.0 license.
Loading

0 comments on commit 211e0cd

Please sign in to comment.