Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pad images to size requirements in call API #137

Merged
merged 4 commits into from
Jan 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 99 additions & 11 deletions src/spandrel/__helpers/model_descriptor.py
Original file line number Diff line number Diff line change
def check(self, width: int, height: int) -> bool:
    """
    Returns whether the given width and height satisfy the size requirements.

    Equivalent to `get_padding(width, height) == (0, 0)`; the explicit
    minimum check below is a redundant fast path (`get_padding` already
    clamps to `self.minimum`), kept for clarity.
    """
    if width < self.minimum or height < self.minimum:
        return False
    return self.get_padding(width, height) == (0, 0)

def get_padding(self, width: int, height: int) -> tuple[int, int]:
    """
    Given an image size, this returns the minimum amount of padding necessary to satisfy the size requirements. The returned padding is in the format `(pad_width, pad_height)` and is guaranteed to be non-negative.
    """

    # Round `value` up to the next multiple of `mod` (identity when aligned).
    def round_up(value: int, mod: int) -> int:
        remainder = value % mod
        return value if remainder == 0 else value + (mod - remainder)

    # Clamp to the minimum size first, then align to the required multiple.
    padded_w = round_up(max(width, self.minimum), self.multiple_of)
    padded_h = round_up(max(height, self.minimum), self.multiple_of)

    # A square requirement forces both sides to the larger of the two.
    if self.square:
        padded_w = padded_h = max(padded_w, padded_h)

    return padded_w - width, padded_h - height

if width % self.multiple_of != 0 or height % self.multiple_of != 0:
return False

if self.square and width != height:
return False
def _pad(t: torch.Tensor, req: "SizeRequirements"):
    """
    Pads the last two dimensions (height, width) of `t` so they satisfy the
    given size requirements.

    Returns a tuple `(did_pad, tensor)`. `did_pad` is True iff any padding
    was added; when it is False, the input tensor is returned unchanged.

    Note: `SizeRequirements` is a forward reference (quoted) so this module
    remains importable regardless of definition order.
    """
    w = t.shape[-1]
    h = t.shape[-2]

    pad_w, pad_h = req.get_padding(w, h)

    if pad_w or pad_h:
        # reflect padding only allows a maximum padding of size - 1
        reflect_pad_w = min(pad_w, w - 1)
        reflect_pad_h = min(pad_h, h - 1)
        t = torch.nn.functional.pad(t, (0, reflect_pad_w, 0, reflect_pad_h), "reflect")

        # do the rest of the padding (if any) with replicate, which has no such restrictions
        pad_w -= reflect_pad_w
        pad_h -= reflect_pad_h
        t = torch.nn.functional.pad(t, (0, pad_w, 0, pad_h), "replicate")

        return True, t
    else:
        return False, t


Purpose = Literal["SR", "FaceSR", "Inpainting", "Restoration"]
Expand Down Expand Up @@ -200,6 +234,8 @@ def __init__(
Size requirements for the input image. E.g. minimum size.

Requirements are specific to individual models and may be different for models of the same architecture.

Users of spandrel's call API can largely ignore size requirements, because the call API will automatically pad the input image to satisfy the requirements. Size requirements might still be useful for user code that tiles images by allowing it to pick an optimal tile size to avoid padding.
"""
self.tiling: ModelTiling = tiling
"""
Expand Down Expand Up @@ -424,16 +460,36 @@ def __call__(self, image: Tensor) -> Tensor:
"""
Takes a single image tensor as input and returns a single image tensor as output.

The `image` tensor must be a 4D tensor with shape `(1, input_channels, H, W)`. The width and height are expected to satisfy the `size_requirements` of the model. The data type (float32, float16, bfloat16) and device of the `image` tensor must be the same as the model. The range of the `image` tensor must be ``[0, 1]``.
The `image` tensor must be a 4D tensor with shape `(1, input_channels, H, W)`. The data type (float32, float16, bfloat16) and device of the `image` tensor must be the same as the model. The range of the `image` tensor must be ``[0, 1]``.

The output tensor will be a 4D tensor with shape `(1, output_channels, H*scale, W*scale)`. The data type and device of the output tensor will be the same as the `image` tensor. The range of the output tensor will be ``[0, 1]``.

If the width and height of the `image` tensor do not satisfy the `size_requirements` of the model, then the `image` tensor will be padded to satisfy the requirements. The additional padding will be removed from the output tensor before returning it. If the image already satisfies the requirements, then no padding will be added.
"""
if len(image.shape) != 4:
raise ValueError(
f"Expected image tensor to have 4 dimensions, but got {image.shape}"
)

_, _, h, w = image.shape

# satisfy size requirements
did_pad, image = _pad(image, self.size_requirements)

# call model
output = self._call_fn(self.model, image)
assert isinstance(
output, Tensor
), f"Expected {type(self.model).__name__} model to returns a tensor, but got {type(output)}"
return output.clamp_(0, 1)

# guarantee range
output = output.clamp_(0, 1)

# remove padding
if did_pad:
output = output[..., : h * self.scale, : w * self.scale]

return output


class MaskedImageModelDescriptor(ModelBase[T], Generic[T]):
Expand Down Expand Up @@ -484,18 +540,50 @@ def __call__(self, image: Tensor, mask: Tensor) -> Tensor:

The data type (float32, float16, bfloat16) and device of the `image` and `mask` tensors must be the same as the model.

The `image` tensor must be a 4D tensor with shape `(1, input_channels, H, W)`. The width and height are expected to satisfy the `size_requirements` of the model. The range of the `image` tensor must be ``[0, 1]``.
The `image` tensor must be a 4D tensor with shape `(1, input_channels, H, W)`. The range of the `image` tensor must be ``[0, 1]``.

The `mask` tensor must be a 4D tensor with shape `(1, 1, H, W)`. The width and height must be the same as `image` tensor. The values of the `mask` tensor must be either 0 (keep) or 1 (inpaint).

The output tensor will be a 4D tensor with shape `(1, output_channels, H, W)`. The data type and device of the output tensor will be the same as the `image` tensor. The range of the output tensor will be ``[0, 1]``.

If the width and height of the `image` tensor do not satisfy the `size_requirements` of the model, then the `image` tensor will be padded to satisfy the requirements. The additional padding will be removed from the output tensor before returning it. If the image already satisfies the requirements, then no padding will be added.
"""
if len(image.shape) != 4:
raise ValueError(
f"Expected image tensor to have 4 dimensions, but got {image.shape}"
)
if len(mask.shape) != 4:
raise ValueError(
f"Expected mask tensor to have 4 dimensions, but got {mask.shape}"
)

_, _, h, w = image.shape

# check mask
mask_shape = torch.Size([1, 1, h, w])
if mask.shape != mask_shape:
raise ValueError(
f"Expected mask shape to be {mask_shape}, but got {mask.shape}"
)

# satisfy size requirements
did_pad, image = _pad(image, self.size_requirements)
_, mask = _pad(mask, self.size_requirements)

# call model
output = self._call_fn(self.model, image, mask)
assert isinstance(
output, Tensor
), f"Expected {type(self.model).__name__} model to returns a tensor, but got {type(output)}"
return output.clamp_(0, 1)

# guarantee range
output = output.clamp_(0, 1)

# remove padding
if did_pad:
output = output[..., : h * self.scale, : w * self.scale]

return output


ModelDescriptor = Union[
Expand Down