From 5a57f58797e353fa54a58cd617a96553712366aa Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Mon, 6 May 2024 17:08:40 +0200 Subject: [PATCH] Unify logic for detecting pixelshuffle --- .../spandrel/architectures/ATD/__init__.py | 12 ++----- .../spandrel/architectures/DAT/__init__.py | 8 ++--- .../spandrel/architectures/DRCT/__init__.py | 21 ++---------- .../spandrel/architectures/GRL/__init__.py | 13 +++---- .../spandrel/architectures/HAT/__init__.py | 7 ++-- .../spandrel/architectures/RGT/__init__.py | 10 ++---- .../architectures/Swin2SR/__init__.py | 8 ++--- .../spandrel/architectures/SwinIR/__init__.py | 12 ++----- libs/spandrel/spandrel/util/__init__.py | 34 +++++++++++++++++-- .../architectures/SRFormer/__init__.py | 9 ++--- 10 files changed, 55 insertions(+), 79 deletions(-) diff --git a/libs/spandrel/spandrel/architectures/ATD/__init__.py b/libs/spandrel/spandrel/architectures/ATD/__init__.py index 23a6c0cb..2bfe594f 100644 --- a/libs/spandrel/spandrel/architectures/ATD/__init__.py +++ b/libs/spandrel/spandrel/architectures/ATD/__init__.py @@ -2,7 +2,7 @@ from typing_extensions import override -from spandrel.util import KeyCondition, get_seq_len +from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len from ...__helpers.model_descriptor import ( Architecture, @@ -114,15 +114,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[ATD]: upscale = 4 elif "conv_before_upsample.0.weight" in state_dict: upsampler = "pixelshuffle" - upscale = 1 - for i in range(0, 10, 2): - if f"upsample.{i}.weight" not in state_dict: - break - num_feat = state_dict[f"upsample.{i}.weight"].shape[1] - - upscale *= math.isqrt( - state_dict[f"upsample.{i}.weight"].shape[0] // num_feat - ) + upscale, _ = get_pixelshuffle_params(state_dict, "upsample") elif "conv_last.weight" in state_dict: upsampler = "" upscale = 1 diff --git a/libs/spandrel/spandrel/architectures/DAT/__init__.py b/libs/spandrel/spandrel/architectures/DAT/__init__.py index 7f87a524..e4139991 100644 --- a/libs/spandrel/spandrel/architectures/DAT/__init__.py +++ b/libs/spandrel/spandrel/architectures/DAT/__init__.py @@ -2,7 +2,7 @@ from typing_extensions import override -from spandrel.util import KeyCondition, get_seq_len +from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len from ...__helpers.model_descriptor import ( Architecture, @@ -107,11 +107,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DAT]: resi_connection = "1conv" if "conv_after_body.weight" in state_dict else "3conv" if upsampler == "pixelshuffle": - upscale = 1 - for i in range(0, get_seq_len(state_dict, "upsample"), 2): - num_feat = state_dict[f"upsample.{i}.weight"].shape[1] - shape = state_dict[f"upsample.{i}.weight"].shape[0] - upscale *= int(math.sqrt(shape // num_feat)) + upscale, num_feat = get_pixelshuffle_params(state_dict, "upsample") elif upsampler == "pixelshuffledirect": num_feat = state_dict["upsample.0.weight"].shape[1] upscale = int( diff --git a/libs/spandrel/spandrel/architectures/DRCT/__init__.py b/libs/spandrel/spandrel/architectures/DRCT/__init__.py index c03cb217..64230dc9 100644 --- a/libs/spandrel/spandrel/architectures/DRCT/__init__.py +++ b/libs/spandrel/spandrel/architectures/DRCT/__init__.py @@ -2,7 +2,7 @@ from typing_extensions import override -from spandrel.util import KeyCondition, get_seq_len +from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len from ...__helpers.model_descriptor import ( Architecture, @@ -13,23 +13,6 @@ from .arch.drct_arch import DRCT -def _get_upscale_pixelshuffle( - state_dict: StateDict, key_prefix: str = "upsample" -) -> int: - upscale = 1 - - for i in range(0, 10, 2): - key = f"{key_prefix}.{i}.weight" - if key not in state_dict: - break - - shape = state_dict[key].shape - num_feat = shape[1] - upscale *= math.isqrt(shape[0] // num_feat) - - return upscale - - class DRCTArch(Architecture[DRCT]): def __init__(self) -> None: super().__init__( @@ -105,7 +88,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DRCT]: if "conv_last.weight" in state_dict: upsampler = "pixelshuffle" - upscale = _get_upscale_pixelshuffle(state_dict, "upsample") + upscale, _ = get_pixelshuffle_params(state_dict, "upsample") else: upsampler = "" upscale = 1 diff --git a/libs/spandrel/spandrel/architectures/GRL/__init__.py b/libs/spandrel/spandrel/architectures/GRL/__init__.py index 1a12308f..43dbc7e9 100644 --- a/libs/spandrel/spandrel/architectures/GRL/__init__.py +++ b/libs/spandrel/spandrel/architectures/GRL/__init__.py @@ -6,7 +6,12 @@ import torch from typing_extensions import override -from spandrel.util import KeyCondition, get_scale_and_output_channels, get_seq_len +from spandrel.util import ( + KeyCondition, + get_pixelshuffle_params, + get_scale_and_output_channels, + get_seq_len, +) from ...__helpers.canonicalize import remove_common_prefix from ...__helpers.model_descriptor import Architecture, ImageModelDescriptor, StateDict @@ -50,7 +55,6 @@ def _get_output_params(state_dict: StateDict, in_channels: int): upsampler: str upscale: int - num_out_feats = 64 # hard-coded if ( "conv_before_upsample.0.weight" in state_dict and "upsample.up.0.weight" in state_dict @@ -58,10 +62,7 @@ def _get_output_params(state_dict: StateDict, in_channels: int): upsampler = "pixelshuffle" out_channels = state_dict["conv_last.weight"].shape[0] - upscale = 1 - for i in range(0, get_seq_len(state_dict, "upsample.up"), 2): - shape = state_dict[f"upsample.up.{i}.weight"].shape[0] - upscale *= int(math.sqrt(shape // num_out_feats)) + upscale, _ = get_pixelshuffle_params(state_dict, "upsample.up") elif "upsample.up.0.weight" in state_dict: upsampler = "pixelshuffledirect" upscale, out_channels = get_scale_and_output_channels( diff --git a/libs/spandrel/spandrel/architectures/HAT/__init__.py b/libs/spandrel/spandrel/architectures/HAT/__init__.py index 42d5dc0e..3dfdb9a4 100644 --- a/libs/spandrel/spandrel/architectures/HAT/__init__.py +++ b/libs/spandrel/spandrel/architectures/HAT/__init__.py @@ -2,7 +2,7 @@ from typing_extensions import override -from spandrel.util import KeyCondition, get_seq_len +from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len from ...__helpers.model_descriptor import ( Architecture, @@ -109,10 +109,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HAT]: embed_dim = state_dict["conv_first.weight"].shape[0] num_feat = state_dict["conv_last.weight"].shape[1] - upscale = 1 - for i in range(0, get_seq_len(state_dict, "upsample"), 2): - shape = state_dict[f"upsample.{i}.weight"].shape[0] - upscale *= int(math.sqrt(shape // num_feat)) + upscale, _ = get_pixelshuffle_params(state_dict, "upsample", num_feat) window_size = int(math.sqrt(state_dict["relative_position_index_SA"].shape[0])) overlap_ratio = _get_overlap_ratio( diff --git a/libs/spandrel/spandrel/architectures/RGT/__init__.py b/libs/spandrel/spandrel/architectures/RGT/__init__.py index 1909ec2c..7bd715a2 100644 --- a/libs/spandrel/spandrel/architectures/RGT/__init__.py +++ b/libs/spandrel/spandrel/architectures/RGT/__init__.py @@ -4,7 +4,7 @@ from typing_extensions import override -from spandrel.util import KeyCondition, get_seq_len +from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len from ...__helpers.model_descriptor import ( Architecture, @@ -133,13 +133,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RGT]: ) break - upscale = 1 - for i in range(0, 10, 2): - key = f"upsample.{i}.weight" - if key in state_dict: - shape = state_dict[key].shape - num_feat = shape[1] - upscale *= math.isqrt(shape[0] // num_feat) + upscale, _ = get_pixelshuffle_params(state_dict, "upsample") split_size = _get_split_size(state_dict) diff --git a/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py b/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py index 15ca7f51..93ade19c 100644 --- a/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py +++ b/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py @@ -2,7 +2,7 @@ from typing_extensions import override -from spandrel.util import KeyCondition, get_seq_len +from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len from ...__helpers.model_descriptor import ( Architecture, @@ -102,11 +102,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Swin2SR]: math.sqrt(state_dict["upsample.0.weight"].shape[0] // in_chans) ) else: - num_feat = 64 # hard-coded constant - upscale = 1 - for i in range(0, get_seq_len(state_dict, "upsample"), 2): - shape = state_dict[f"upsample.{i}.weight"].shape[0] - upscale *= int(math.sqrt(shape // num_feat)) + upscale, _ = get_pixelshuffle_params(state_dict, "upsample") window_size = int( math.sqrt( diff --git a/libs/spandrel/spandrel/architectures/SwinIR/__init__.py b/libs/spandrel/spandrel/architectures/SwinIR/__init__.py index 7424ef25..c237a780 100644 --- a/libs/spandrel/spandrel/architectures/SwinIR/__init__.py +++ b/libs/spandrel/spandrel/architectures/SwinIR/__init__.py @@ -3,7 +3,7 @@ from torch import nn from typing_extensions import override -from spandrel.util import KeyCondition, get_seq_len +from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len from ...__helpers.model_descriptor import ( Architecture, @@ -84,15 +84,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwinIR]: for _upsample_key in upsample_keys: upscale *= 2 elif upsampler == "pixelshuffle": - upsample_keys = [ - x - for x in state_dict - if "upsample" in x and "conv" not in x and "bias" not in x - ] - for upsample_key in upsample_keys: - shape = state_dict[upsample_key].shape[0] - upscale *= math.sqrt(shape // num_feat) - upscale = int(upscale) + upscale, num_feat = get_pixelshuffle_params(state_dict, "upsample") elif upsampler == "pixelshuffledirect": upscale = int( math.sqrt(state_dict["upsample.0.bias"].shape[0] // num_out_ch) diff --git a/libs/spandrel/spandrel/util/__init__.py b/libs/spandrel/spandrel/util/__init__.py index c78186d2..6527fb5e 100644 --- a/libs/spandrel/spandrel/util/__init__.py +++ b/libs/spandrel/spandrel/util/__init__.py @@ -111,6 +111,35 @@ def is_square(n: int) -> bool: ) +def get_pixelshuffle_params( + state_dict: Mapping[str, object], + upsample_key: str = "upsample", + default_nf: int = 64, +) -> tuple[int, int]: + """ + This will detect the upscale factor and number of features of a pixelshuffle module in the state dict. + + A pixelshuffle module is a sequence of alternating up convolutions and pixelshuffle. + The class of this module is commonyl called `Upsample`. + Examples of such modules can be found in most SISR architectures, such as SwinIR, HAT, RGT, and many more. + """ + upscale = 1 + num_feat = default_nf + + for i in range(0, 10, 2): + key = f"{upsample_key}.{i}.weight" + if key not in state_dict: + break + + tensor = state_dict[key] + # we'll assume that the state dict contains tensors + shape: tuple[int, ...] = tensor.shape # type: ignore + num_feat = shape[1] + upscale *= math.isqrt(shape[0] // num_feat) + + return upscale, num_feat + + def store_hyperparameters(*, extra_parameters: Mapping[str, object] = {}): """ Stores the hyperparameters of a class in a `hyperparameters` attribute. @@ -170,9 +199,10 @@ def new_init(self: C, **kwargs): __all__ = [ - "KeyCondition", "get_first_seq_index", - "get_seq_len", + "get_pixelshuffle_params", "get_scale_and_output_channels", + "get_seq_len", + "KeyCondition", "store_hyperparameters", ] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py index b6f931f0..7cbbe71d 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py @@ -8,7 +8,7 @@ SizeRequirements, StateDict, ) -from spandrel.util import KeyCondition, get_seq_len +from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len from .arch.SRFormer import SRFormer @@ -76,12 +76,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SRFormer]: upscale = 4 # only supported scale elif "conv_before_upsample.0.weight" in state_dict: upsampler = "pixelshuffle" - - num_feat = 64 # hard-coded constant - upscale = 1 - for i in range(0, get_seq_len(state_dict, "upsample"), 2): - shape = state_dict[f"upsample.{i}.weight"].shape[0] - upscale *= int(math.sqrt(shape // num_feat)) + upscale, _ = get_pixelshuffle_params(state_dict, "upsample") elif "upsample.0.weight" in state_dict: upsampler = "pixelshuffledirect" upscale = int(