Skip to content

Commit

Permalink
Ensure that exported architecture class names match the arch ID/name (#…
Browse files Browse the repository at this point in the history
…284)

* Ensure that exported architecture class names match the arch ID/name

* Fixed tests
  • Loading branch information
RunDevelopment authored Jul 11, 2024
1 parent a6094e7 commit fb21c7d
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 55 deletions.
8 changes: 4 additions & 4 deletions libs/spandrel/spandrel/architectures/Compact/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from spandrel.util import KeyCondition, get_scale_and_output_channels, get_seq_len

from ...__helpers.model_descriptor import Architecture, ImageModelDescriptor, StateDict
from .__arch.SRVGG import SRVGGNetCompact
from .__arch.SRVGG import SRVGGNetCompact as Compact


class CompactArch(Architecture[SRVGGNetCompact]):
class CompactArch(Architecture[Compact]):
def __init__(
self,
) -> None:
Expand All @@ -20,7 +20,7 @@ def __init__(
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[SRVGGNetCompact]:
def load(self, state_dict: StateDict) -> ImageModelDescriptor[Compact]:
state = state_dict

highest_num = get_seq_len(state, "body") - 1
Expand All @@ -32,7 +32,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SRVGGNetCompact]:
pixelshuffle_shape = state[f"body.{highest_num}.bias"].shape[0]
scale, out_nc = get_scale_and_output_channels(pixelshuffle_shape, in_nc)

model = SRVGGNetCompact(
model = Compact(
num_in_ch=in_nc,
num_out_ch=out_nc,
num_feat=num_feat,
Expand Down
8 changes: 4 additions & 4 deletions libs/spandrel/spandrel/architectures/ESRGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
SizeRequirements,
StateDict,
)
from .__arch.RRDB import RRDBNet
from .__arch.RRDB import RRDBNet as ESRGAN


def _new_to_old_arch(state: StateDict, state_map: dict, num_blocks: int):
Expand Down Expand Up @@ -132,7 +132,7 @@ def _to_old_arch(state: StateDict) -> StateDict:
return _new_to_old_arch(state, state_map, num_blocks)


class ESRGANArch(Architecture[RRDBNet]):
class ESRGANArch(Architecture[ESRGAN]):
def __init__(self) -> None:
super().__init__(
id="ESRGAN",
Expand Down Expand Up @@ -163,7 +163,7 @@ def __init__(self) -> None:
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[RRDBNet]:
def load(self, state_dict: StateDict) -> ImageModelDescriptor[ESRGAN]:
# default values
in_nc: int = 3
out_nc: int = 3
Expand Down Expand Up @@ -196,7 +196,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RRDBNet]:
else:
shuffle_factor = None

model = RRDBNet(
model = ESRGAN(
in_nc=in_nc,
out_nc=out_nc,
num_filters=num_filters,
Expand Down
8 changes: 4 additions & 4 deletions libs/spandrel/spandrel/architectures/GFPGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
SizeRequirements,
StateDict,
)
from .__arch.gfpganv1_clean_arch import GFPGANv1Clean
from .__arch.gfpganv1_clean_arch import GFPGANv1Clean as GFPGAN


class GFPGANArch(Architecture[GFPGANv1Clean]):
class GFPGANArch(Architecture[GFPGAN]):
def __init__(self) -> None:
super().__init__(
id="GFPGAN",
Expand All @@ -22,7 +22,7 @@ def __init__(self) -> None:
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[GFPGANv1Clean]:
def load(self, state_dict: StateDict) -> ImageModelDescriptor[GFPGAN]:
out_size = 512
num_style_feat = 512
channel_multiplier = 2
Expand All @@ -34,7 +34,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[GFPGANv1Clean]:
narrow = 1
sft_half = True

model = GFPGANv1Clean(
model = GFPGAN(
out_size=out_size,
num_style_feat=num_style_feat,
channel_multiplier=channel_multiplier,
Expand Down
8 changes: 4 additions & 4 deletions libs/spandrel/spandrel/architectures/SAFMNBCIE/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
SizeRequirements,
StateDict,
)
from .__arch.safmn_bcie import SAFMN_BCIE
from .__arch.safmn_bcie import SAFMN_BCIE as SAFMNBCIE


class SAFMNBCIEArch(Architecture[SAFMN_BCIE]):
class SAFMNBCIEArch(Architecture[SAFMNBCIE]):
def __init__(self) -> None:
super().__init__(
id="SAFMNBCIE",
Expand All @@ -37,7 +37,7 @@ def __init__(self) -> None:
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMN_BCIE]:
def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMNBCIE]:
dim: int
n_blocks: int = 6
num_layers: int = 6
Expand All @@ -55,7 +55,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMN_BCIE]:
hidden_dim = state_dict["feats.0.layers.0.ccm.ccm.0.weight"].shape[0]
ffn_scale = hidden_dim / dim

model = SAFMN_BCIE(
model = SAFMNBCIE(
dim=dim,
n_blocks=n_blocks,
num_layers=num_layers,
Expand Down
22 changes: 11 additions & 11 deletions tests/test_Compact.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from spandrel.architectures.Compact import CompactArch, SRVGGNetCompact
from spandrel.architectures.Compact import Compact, CompactArch

from .util import (
ModelFile,
Expand All @@ -16,14 +16,14 @@
def test_load():
assert_loads_correctly(
CompactArch(),
lambda: SRVGGNetCompact(),
lambda: SRVGGNetCompact(num_in_ch=1, num_out_ch=1),
lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3),
lambda: SRVGGNetCompact(num_in_ch=4, num_out_ch=4),
lambda: SRVGGNetCompact(num_in_ch=1, num_out_ch=3),
lambda: SRVGGNetCompact(num_feat=32),
lambda: SRVGGNetCompact(num_conv=5),
lambda: SRVGGNetCompact(upscale=3),
lambda: Compact(),
lambda: Compact(num_in_ch=1, num_out_ch=1),
lambda: Compact(num_in_ch=3, num_out_ch=3),
lambda: Compact(num_in_ch=4, num_out_ch=4),
lambda: Compact(num_in_ch=1, num_out_ch=3),
lambda: Compact(num_feat=32),
lambda: Compact(num_conv=5),
lambda: Compact(upscale=3),
)


Expand All @@ -45,7 +45,7 @@ def test_Compact_realesr_general_x4v3(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, SRVGGNetCompact)
assert isinstance(model.model, Compact)
assert_image_inference(
file,
model,
Expand All @@ -59,7 +59,7 @@ def test_Compact_community(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, SRVGGNetCompact)
assert isinstance(model.model, Compact)
assert_image_inference(
file,
model,
Expand Down
36 changes: 18 additions & 18 deletions tests/test_ESRGAN.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from spandrel.architectures.ESRGAN import ESRGANArch, RRDBNet
from spandrel.architectures.ESRGAN import ESRGAN, ESRGANArch

from .util import (
ModelFile,
Expand All @@ -16,11 +16,11 @@
def test_load():
assert_loads_correctly(
ESRGANArch(),
lambda: RRDBNet(in_nc=3, out_nc=3, num_filters=64, num_blocks=23, scale=4),
lambda: RRDBNet(in_nc=1, out_nc=3, num_filters=32, num_blocks=11, scale=2),
lambda: RRDBNet(in_nc=1, out_nc=1, num_filters=64, num_blocks=23, scale=1),
lambda: RRDBNet(in_nc=4, out_nc=4, num_filters=64, num_blocks=23, scale=8),
lambda: RRDBNet(scale=4, plus=True),
lambda: ESRGAN(in_nc=3, out_nc=3, num_filters=64, num_blocks=23, scale=4),
lambda: ESRGAN(in_nc=1, out_nc=3, num_filters=32, num_blocks=11, scale=2),
lambda: ESRGAN(in_nc=1, out_nc=1, num_filters=64, num_blocks=23, scale=1),
lambda: ESRGAN(in_nc=4, out_nc=4, num_filters=64, num_blocks=23, scale=8),
lambda: ESRGAN(scale=4, plus=True),
)


Expand All @@ -47,7 +47,7 @@ def test_ESRGAN_community(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -61,7 +61,7 @@ def test_ESRGAN_community_2x(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -75,7 +75,7 @@ def test_ESRGAN_community_4x(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -89,7 +89,7 @@ def test_ESRGAN_community_8x(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -103,7 +103,7 @@ def test_BSRGAN(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -117,7 +117,7 @@ def test_BSRGAN_2x(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -131,7 +131,7 @@ def test_RealSR_DPED(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -145,7 +145,7 @@ def test_RealSR_JPEG(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -159,7 +159,7 @@ def test_RealESRGAN_x4plus(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -173,7 +173,7 @@ def test_RealESRGAN_x2plus(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -187,7 +187,7 @@ def test_RealESRGAN_x4plus_anime_6B(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand All @@ -201,7 +201,7 @@ def test_RealESRNet_x4plus(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RRDBNet)
assert isinstance(model.model, ESRGAN)
assert_image_inference(
file,
model,
Expand Down
10 changes: 5 additions & 5 deletions tests/test_GFPGAN.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from spandrel.architectures.GFPGAN import GFPGANArch, GFPGANv1Clean
from spandrel.architectures.GFPGAN import GFPGAN, GFPGANArch
from tests.test_CodeFormer import assert_loads_correctly

from .util import (
Expand All @@ -15,7 +15,7 @@
def test_load():
assert_loads_correctly(
GFPGANArch(),
lambda: GFPGANv1Clean(),
lambda: GFPGAN(),
)


Expand All @@ -25,7 +25,7 @@ def test_GFPGAN_1_2(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, GFPGANv1Clean)
assert isinstance(model.model, GFPGAN)


def test_GFPGAN_1_3(snapshot):
Expand All @@ -34,7 +34,7 @@ def test_GFPGAN_1_3(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, GFPGANv1Clean)
assert isinstance(model.model, GFPGAN)


def test_GFPGAN_1_4(snapshot):
Expand All @@ -43,7 +43,7 @@ def test_GFPGAN_1_4(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, GFPGANv1Clean)
assert isinstance(model.model, GFPGAN)
assert_image_inference(
model_file=file,
model=model,
Expand Down
10 changes: 5 additions & 5 deletions tests/test_SAFMNBCIE.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from spandrel.architectures.SAFMNBCIE import SAFMN_BCIE, SAFMNBCIEArch
from spandrel.architectures.SAFMNBCIE import SAFMNBCIE, SAFMNBCIEArch

from .util import (
ModelFile,
Expand All @@ -16,13 +16,13 @@
def test_load():
assert_loads_correctly(
SAFMNBCIEArch(),
lambda: SAFMN_BCIE(
lambda: SAFMNBCIE(
dim=36, n_blocks=8, num_layers=1, ffn_scale=2.0, upscaling_factor=4
),
lambda: SAFMN_BCIE(
lambda: SAFMNBCIE(
dim=36, n_blocks=8, num_layers=3, ffn_scale=3.0, upscaling_factor=3
),
lambda: SAFMN_BCIE(
lambda: SAFMNBCIE(
dim=8, n_blocks=3, num_layers=4, ffn_scale=5.0, upscaling_factor=2
),
)
Expand All @@ -41,5 +41,5 @@ def test_SAFMN_BCIE(snapshot):
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, SAFMN_BCIE)
assert isinstance(model.model, SAFMNBCIE)
assert_image_inference(file, model, [TestImage.JPEG_15])

0 comments on commit fb21c7d

Please sign in to comment.