Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure that exported architecture class names match the arch ID/name #284

Merged
merged 2 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions libs/spandrel/spandrel/architectures/Compact/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from spandrel.util import KeyCondition, get_scale_and_output_channels, get_seq_len

from ...__helpers.model_descriptor import Architecture, ImageModelDescriptor, StateDict
from .__arch.SRVGG import SRVGGNetCompact
from .__arch.SRVGG import SRVGGNetCompact as Compact


class CompactArch(Architecture[SRVGGNetCompact]):
class CompactArch(Architecture[Compact]):
def __init__(
self,
) -> None:
Expand All @@ -20,7 +20,7 @@ def __init__(
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[SRVGGNetCompact]:
def load(self, state_dict: StateDict) -> ImageModelDescriptor[Compact]:
state = state_dict

highest_num = get_seq_len(state, "body") - 1
Expand All @@ -32,7 +32,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SRVGGNetCompact]:
pixelshuffle_shape = state[f"body.{highest_num}.bias"].shape[0]
scale, out_nc = get_scale_and_output_channels(pixelshuffle_shape, in_nc)

model = SRVGGNetCompact(
model = Compact(
num_in_ch=in_nc,
num_out_ch=out_nc,
num_feat=num_feat,
Expand Down
8 changes: 4 additions & 4 deletions libs/spandrel/spandrel/architectures/ESRGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
SizeRequirements,
StateDict,
)
from .__arch.RRDB import RRDBNet
from .__arch.RRDB import RRDBNet as ESRGAN


def _new_to_old_arch(state: StateDict, state_map: dict, num_blocks: int):
Expand Down Expand Up @@ -132,7 +132,7 @@ def _to_old_arch(state: StateDict) -> StateDict:
return _new_to_old_arch(state, state_map, num_blocks)


class ESRGANArch(Architecture[RRDBNet]):
class ESRGANArch(Architecture[ESRGAN]):
def __init__(self) -> None:
super().__init__(
id="ESRGAN",
Expand Down Expand Up @@ -163,7 +163,7 @@ def __init__(self) -> None:
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[RRDBNet]:
def load(self, state_dict: StateDict) -> ImageModelDescriptor[ESRGAN]:
# default values
in_nc: int = 3
out_nc: int = 3
Expand Down Expand Up @@ -196,7 +196,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RRDBNet]:
else:
shuffle_factor = None

model = RRDBNet(
model = ESRGAN(
in_nc=in_nc,
out_nc=out_nc,
num_filters=num_filters,
Expand Down
8 changes: 4 additions & 4 deletions libs/spandrel/spandrel/architectures/GFPGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
SizeRequirements,
StateDict,
)
from .__arch.gfpganv1_clean_arch import GFPGANv1Clean
from .__arch.gfpganv1_clean_arch import GFPGANv1Clean as GFPGAN


class GFPGANArch(Architecture[GFPGANv1Clean]):
class GFPGANArch(Architecture[GFPGAN]):
def __init__(self) -> None:
super().__init__(
id="GFPGAN",
Expand All @@ -22,7 +22,7 @@ def __init__(self) -> None:
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[GFPGANv1Clean]:
def load(self, state_dict: StateDict) -> ImageModelDescriptor[GFPGAN]:
out_size = 512
num_style_feat = 512
channel_multiplier = 2
Expand All @@ -34,7 +34,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[GFPGANv1Clean]:
narrow = 1
sft_half = True

model = GFPGANv1Clean(
model = GFPGAN(
out_size=out_size,
num_style_feat=num_style_feat,
channel_multiplier=channel_multiplier,
Expand Down
8 changes: 4 additions & 4 deletions libs/spandrel/spandrel/architectures/SAFMNBCIE/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
SizeRequirements,
StateDict,
)
from .__arch.safmn_bcie import SAFMN_BCIE
from .__arch.safmn_bcie import SAFMN_BCIE as SAFMNBCIE


class SAFMNBCIEArch(Architecture[SAFMN_BCIE]):
class SAFMNBCIEArch(Architecture[SAFMNBCIE]):
def __init__(self) -> None:
super().__init__(
id="SAFMNBCIE",
Expand All @@ -37,7 +37,7 @@ def __init__(self) -> None:
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMN_BCIE]:
def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMNBCIE]:
dim: int
n_blocks: int = 6
num_layers: int = 6
Expand All @@ -55,7 +55,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMN_BCIE]:
hidden_dim = state_dict["feats.0.layers.0.ccm.ccm.0.weight"].shape[0]
ffn_scale = hidden_dim / dim

model = SAFMN_BCIE(
model = SAFMNBCIE(
dim=dim,
n_blocks=n_blocks,
num_layers=num_layers,
Expand Down
22 changes: 11 additions & 11 deletions tests/test_Compact.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from spandrel.architectures.Compact import CompactArch, SRVGGNetCompact
from spandrel.architectures.Compact import Compact, CompactArch

from .util import (
ModelFile,
Expand All @@ -16,14 +16,14 @@
def test_load():
assert_loads_correctly(
CompactArch(),
lambda: SRVGGNetCompact(),
lambda: SRVGGNetCompact(num_in_ch=1, num_out_ch=1),
lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3),
lambda: SRVGGNetCompact(num_in_ch=4, num_out_ch=4),
lambda: SRVGGNetCompact(num_in_ch=1, num_out_ch=3),
lambda: SRVGGNetCompact(num_feat=32),
lambda: SRVGGNetCompact(num_conv=5),
lambda: SRVGGNetCompact(upscale=3),
lambda: Compact(),
lambda: Compact(num_in_ch=1, num_out_ch=1),
lambda: Compact(num_in_ch=3, num_out_ch=3),
lambda: Compact(num_in_ch=4, num_out_ch=4),
lambda: Compact(num_in_ch=1, num_out_ch=3),
lambda: Compact(num_feat=32),
lambda: Compact(num_conv=5),
lambda: Compact(upscale=3),
)


Expand All @@ -45,7 +45,7 @@ def test_Compact_realesr_general_x4v3(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, SRVGGNetCompact)
assert isinstance(model.model, Compact)
assert_image_inference(
file,
model,
Expand All @@ -59,7 +59,7 @@ def test_Compact_community(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, SRVGGNetCompact)
assert isinstance(model.model, Compact)
assert_image_inference(
file,
model,
Expand Down
36 changes: 18 additions & 18 deletions tests/test_ESRGAN.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from spandrel.architectures.ESRGAN import ESRGANArch, RRDBNet
from spandrel.architectures.ESRGAN import ESRGAN, ESRGANArch

from .util import (
ModelFile,
Expand All @@ -16,11 +16,11 @@
def test_load():
assert_loads_correctly(
ESRGANArch(),
lambda: RRDBNet(in_nc=3, out_nc=3, num_filters=64, num_blocks=23, scale=4),
lambda: RRDBNet(in_nc=1, out_nc=3, num_filters=32, num_blocks=11, scale=2),
lambda: RRDBNet(in_nc=1, out_nc=1, num_filters=64, num_blocks=23, scale=1),
lambda: RRDBNet(in_nc=4, out_nc=4, num_filters=64, num_blocks=23, scale=8),
lambda: RRDBNet(scale=4, plus=True),
lambda: ESRGAN(in_nc=3, out_nc=3, num_filters=64, num_blocks=23, scale=4),
lambda: ESRGAN(in_nc=1, out_nc=3, num_filters=32, num_blocks=11, scale=2),
lambda: ESRGAN(in_nc=1, out_nc=1, num_filters=64, num_blocks=23, scale=1),
lambda: ESRGAN(in_nc=4, out_nc=4, num_filters=64, num_blocks=23, scale=8),
lambda: ESRGAN(scale=4, plus=True),
)


Expand All @@ -47,7 +47,7 @@ def test_ESRGAN_community(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -61,7 +61,7 @@ def test_ESRGAN_community_2x(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -75,7 +75,7 @@ def test_ESRGAN_community_4x(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -89,7 +89,7 @@ def test_ESRGAN_community_8x(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -103,7 +103,7 @@ def test_BSRGAN(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -117,7 +117,7 @@ def test_BSRGAN_2x(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -131,7 +131,7 @@ def test_RealSR_DPED(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -145,7 +145,7 @@ def test_RealSR_JPEG(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -159,7 +159,7 @@ def test_RealESRGAN_x4plus(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -173,7 +173,7 @@ def test_RealESRGAN_x2plus(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -187,7 +187,7 @@ def test_RealESRGAN_x4plus_anime_6B(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -201,7 +201,7 @@ def test_RealESRNet_x4plus(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand Down
10 changes: 5 additions & 5 deletions tests/test_GFPGAN.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from spandrel.architectures.GFPGAN import GFPGANArch, GFPGANv1Clean
from spandrel.architectures.GFPGAN import GFPGAN, GFPGANArch
from tests.test_CodeFormer import assert_loads_correctly

from .util import (
Expand All @@ -15,7 +15,7 @@
def test_load():
assert_loads_correctly(
GFPGANArch(),
lambda: GFPGANv1Clean(),
lambda: GFPGAN(),
)


Expand All @@ -25,7 +25,7 @@ def test_GFPGAN_1_2(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, GFPGANv1Clean)
assert isinstance(model.model, GFPGAN)


def test_GFPGAN_1_3(snapshot):
Expand All @@ -34,7 +34,7 @@ def test_GFPGAN_1_3(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, GFPGANv1Clean)
assert isinstance(model.model, GFPGAN)


def test_GFPGAN_1_4(snapshot):
Expand All @@ -43,7 +43,7 @@ def test_GFPGAN_1_4(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, GFPGANv1Clean)
assert isinstance(model.model, GFPGAN)
assert_image_inference(
model_file=file,
model=model,
Expand Down
10 changes: 5 additions & 5 deletions tests/test_SAFMNBCIE.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from spandrel.architectures.SAFMNBCIE import SAFMN_BCIE, SAFMNBCIEArch
from spandrel.architectures.SAFMNBCIE import SAFMNBCIE, SAFMNBCIEArch

from .util import (
ModelFile,
Expand All @@ -16,13 +16,13 @@
def test_load():
assert_loads_correctly(
SAFMNBCIEArch(),
lambda: SAFMN_BCIE(
lambda: SAFMNBCIE(
dim=36, n_blocks=8, num_layers=1, ffn_scale=2.0, upscaling_factor=4
),
lambda: SAFMN_BCIE(
lambda: SAFMNBCIE(
dim=36, n_blocks=8, num_layers=3, ffn_scale=3.0, upscaling_factor=3
),
lambda: SAFMN_BCIE(
lambda: SAFMNBCIE(
dim=8, n_blocks=3, num_layers=4, ffn_scale=5.0, upscaling_factor=2
),
)
Expand All @@ -41,5 +41,5 @@ def test_SAFMN_BCIE(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, SAFMN_BCIE)
assert isinstance(model.model, SAFMNBCIE)
assert_image_inference(file, model, [TestImage.JPEG_15])
Loading