From 11af1daa3630919f3906d47a8a927f9c19996be5 Mon Sep 17 00:00:00 2001 From: RunDevelopment Date: Sat, 20 Jan 2024 23:10:49 +0100 Subject: [PATCH] Handle large padding --- src/spandrel/__helpers/model_descriptor.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/spandrel/__helpers/model_descriptor.py b/src/spandrel/__helpers/model_descriptor.py index 6919f7ef..37fb9fa5 100644 --- a/src/spandrel/__helpers/model_descriptor.py +++ b/src/spandrel/__helpers/model_descriptor.py @@ -101,7 +101,17 @@ def _pad(t: torch.Tensor, req: SizeRequirements): pad_w, pad_h = req.get_padding(w, h) if pad_w or pad_h: - return True, torch.nn.functional.pad(t, (0, pad_w, 0, pad_h), "reflect") + # reflect padding only allows a maximum padding of size - 1 + reflect_pad_w = min(pad_w, w - 1) + reflect_pad_h = min(pad_h, h - 1) + t = torch.nn.functional.pad(t, (0, reflect_pad_w, 0, reflect_pad_h), "reflect") + + # do the rest of the padding (if any) with replicate, which has no such restrictions + pad_w -= reflect_pad_w + pad_h -= reflect_pad_h + t = torch.nn.functional.pad(t, (0, pad_w, 0, pad_h), "replicate") + + return True, t else: return False, t