Skip to content

Commit

Permalink
Fix RPLKSR DySample (#297)
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyballentine authored Jul 14, 2024
1 parent 7d4923b commit 9813195
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions libs/spandrel/spandrel/architectures/PLKSR/__arch/RealPLKSR.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,6 @@ def __init__(
in_ch: int = 3
out_ch: int = 3

self.dysample = dysample
self.upscaling_factor = upscaling_factor

if not self.training:
dropout = 0

Expand All @@ -144,8 +141,8 @@ def __init__(
torch.repeat_interleave, repeats=upscaling_factor**2, dim=1
)

if dysample and upscaling_factor != 1:
groups = out_ch if 3 * upscaling_factor**2 < 4 else 4
if dysample:
groups = out_ch if upscaling_factor % 2 != 0 else 4
self.to_img = DySample(
in_ch * upscaling_factor**2,
out_ch,
Expand All @@ -158,6 +155,4 @@ def __init__(

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.feats(x) + self.repeat_op(x)
if not self.dysample or (self.dysample and self.upscaling_factor != 1):
x = self.to_img(x)
return x
return self.to_img(x)

0 comments on commit 9813195

Please sign in to comment.