perf: improved performance in various methods in Image and ImageList #879

Merged · 15 commits · Jul 12, 2024
@@ -13,20 +13,13 @@ def _check_resize_errors(new_width: int, new_height: int) -> None:
_check_bounds("new_height", new_height, lower_bound=_ClosedBound(1))


def _check_crop_errors_and_warnings(
def _check_crop_warnings(
x: int,
y: int,
width: int,
height: int,
min_width: int,
min_height: int,
plural: bool,
) -> None:
_check_bounds("x", x, lower_bound=_ClosedBound(0))
_check_bounds("y", y, lower_bound=_ClosedBound(0))
_check_bounds("width", width, lower_bound=_ClosedBound(1))
_check_bounds("height", height, lower_bound=_ClosedBound(1))

if x >= min_width or y >= min_height:
warnings.warn(
f"The specified bounding rectangle does not contain any content of {'at least one' if plural else 'the'} image. Therefore {'these images' if plural else 'the image'} will be blank.",
@@ -35,6 +28,18 @@ def _check_crop_errors_and_warnings(
)


def _check_crop_errors(
x: int,
y: int,
width: int,
height: int,
) -> None:
_check_bounds("x", x, lower_bound=_ClosedBound(0))
_check_bounds("y", y, lower_bound=_ClosedBound(0))
_check_bounds("width", width, lower_bound=_ClosedBound(1))
_check_bounds("height", height, lower_bound=_ClosedBound(1))


def _check_adjust_brightness_errors_and_warnings(factor: float, plural: bool) -> None:
_check_bounds("factor", factor, lower_bound=_ClosedBound(0))
if factor == 1:
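
The refactor above splits the old _check_crop_errors_and_warnings into _check_crop_errors (hard bounds validation that raises) and _check_crop_warnings (the blank-crop warning, which needs the real image size). A minimal self-contained sketch of the resulting contract — crop_checks_demo and its ValueError are stand-ins, not library code:

    import warnings

    def crop_checks_demo(x: int, y: int, width: int, height: int, img_w: int, img_h: int) -> None:
        # Stand-in for _check_crop_errors: pure validation, no image size required.
        if x < 0 or y < 0 or width < 1 or height < 1:
            raise ValueError("out of bounds")  # the real helper raises OutOfBoundsError
        # Stand-in for _check_crop_warnings: needs the image size, so _EmptyImageList.crop
        # (below) can simply skip it instead of faking min_width/min_height as x + 1/y + 1.
        if x >= img_w or y >= img_h:
            warnings.warn("The specified bounding rectangle does not contain any content of the image.", stacklevel=2)

    crop_checks_demo(64, 0, 32, 32, img_w=48, img_h=48)  # warns: x starts right of the image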
12 changes: 2 additions & 10 deletions src/safeds/data/image/containers/_empty_image_list.py
@@ -10,7 +10,7 @@
_check_adjust_color_balance_errors_and_warnings,
_check_adjust_contrast_errors_and_warnings,
_check_blur_errors_and_warnings,
_check_crop_errors_and_warnings,
_check_crop_errors,
_check_remove_images_with_size_errors,
_check_resize_errors,
_check_sharpen_errors_and_warnings,
@@ -161,15 +161,7 @@ def convert_to_grayscale(self) -> ImageList:

def crop(self, x: int, y: int, width: int, height: int) -> ImageList:
_EmptyImageList._warn_empty_image_list()
_check_crop_errors_and_warnings(
x,
y,
width,
height,
x + 1,
y + 1,
plural=True,
) # Disable x|y >= min_width|min_height check with min_width|min_height=x|y+1
_check_crop_errors(x, y, width, height)
return _EmptyImageList()

def flip_vertically(self) -> ImageList:
157 changes: 105 additions & 52 deletions src/safeds/data/image/containers/_image.py
@@ -14,7 +14,8 @@
_check_adjust_color_balance_errors_and_warnings,
_check_adjust_contrast_errors_and_warnings,
_check_blur_errors_and_warnings,
_check_crop_errors_and_warnings,
_check_crop_errors,
_check_crop_warnings,
_check_resize_errors,
_check_sharpen_errors_and_warnings,
)
@@ -46,7 +47,7 @@ def _filter_edges_kernel() -> Tensor:

if Image._filter_edges_kernel_cache is None:
Image._filter_edges_kernel_cache = (
torch.tensor([[-1.0, -1.0, -1.0], [-1.0, 8.0, -1.0], [-1.0, -1.0, -1.0]])
torch.tensor([[-1.0, -1.0, -1.0], [-1.0, 8.0, -1.0], [-1.0, -1.0, -1.0]], dtype=torch.float16)
.unsqueeze(dim=0)
.unsqueeze(dim=0)
.to(_get_device())
@@ -118,7 +119,13 @@ def from_bytes(data: bytes) -> Image:
return Image(image_tensor=torchvision.io.decode_image(input_tensor).to(_get_device()))

def __init__(self, image_tensor: Tensor) -> None:
self._image_tensor: Tensor = image_tensor
import torch

self._image_tensor: Tensor
if image_tensor.dtype != torch.uint8:
self._image_tensor = torch.clamp(image_tensor, 0, 255).to(torch.uint8)
else:
self._image_tensor = image_tensor
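
The constructor now normalizes every incoming tensor to uint8, clamping first. The order matters: a bare cast from float does not saturate — PyTorch truncates, and out-of-range values wrap (the exact wrapped values are platform-dependent). A small sketch of the difference:

    import torch

    t = torch.tensor([-12.7, 130.4, 300.0])
    print(t.to(torch.uint8))                       # wraps; not a saturating conversion
    print(torch.clamp(t, 0, 255).to(torch.uint8))  # tensor([  0, 130, 255], dtype=torch.uint8)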

def __eq__(self, other: object) -> bool:
"""
@@ -444,18 +451,20 @@ def convert_to_grayscale(self) -> Image:

_init_default_device()

if self.channel == 4:
if self.channel == 1:
return self
elif self.channel == 4:
return Image(
torch.cat(
[
func2.rgb_to_grayscale(self._image_tensor[0:3], num_output_channels=3),
self._image_tensor[3].unsqueeze(dim=0),
self._image_tensor[3:4],
],
),
)
else:
else: # channel == 3
return Image(
func2.rgb_to_grayscale(self._image_tensor[0:3], num_output_channels=self.channel),
func2.rgb_to_grayscale(self._image_tensor[0:3], num_output_channels=3),
)
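
The recurring change from self._image_tensor[3].unsqueeze(dim=0) to self._image_tensor[3:4] (here and in sharpen, invert_colors and find_edges below) selects the same alpha plane; the slice just keeps the channel dimension in a single view operation instead of dropping and re-adding it. A quick check of the equivalence:

    import torch

    rgba = torch.zeros(4, 2, 2, dtype=torch.uint8)  # (C, H, W) image tensor
    a = rgba[3].unsqueeze(dim=0)  # two view ops: index (drops C), unsqueeze (restores it)
    b = rgba[3:4]                 # one view op: the slice keeps C == 1
    assert a.shape == b.shape == (1, 2, 2) and torch.equal(a, b)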

def crop(self, x: int, y: int, width: int, height: int) -> Image:
@@ -489,7 +498,8 @@ def crop(self, x: int, y: int, width: int, height: int) -> Image:

_init_default_device()

_check_crop_errors_and_warnings(x, y, width, height, self.width, self.height, plural=False)
_check_crop_errors(x, y, width, height)
_check_crop_warnings(x, y, self.width, self.height, plural=False)
return Image(func2.crop(self._image_tensor, y, x, height, width))

def flip_vertically(self) -> Image:
@@ -552,22 +562,31 @@ def adjust_brightness(self, factor: float) -> Image:
If factor is smaller than 0.
"""
import torch
from torchvision.transforms.v2 import functional as func2

_init_default_device()

_check_adjust_brightness_errors_and_warnings(factor, plural=False)
if self.channel == 4:
return Image(
torch.cat(
[
func2.adjust_brightness(self._image_tensor[0:3], factor * 1.0),
self._image_tensor[3].unsqueeze(dim=0),
],
),
)
if self._image_tensor.size(dim=-3) != 4:
if factor == 0:
return Image(torch.zeros(self._image_tensor.size(), dtype=torch.uint8))
else:
temp_tensor = self._image_tensor * torch.tensor([factor * 1.0], dtype=torch.float16)
torch.clamp(temp_tensor, 0, 255, out=temp_tensor)
return Image(temp_tensor.to(torch.uint8))
else:
return Image(func2.adjust_brightness(self._image_tensor, factor * 1.0))
img_tensor = torch.empty(self._image_tensor.size(), dtype=torch.uint8)
img_tensor[3] = self._image_tensor[3]
if factor == 0:
torch.zeros(
(3, self._image_tensor.size(dim=-2), self._image_tensor.size(dim=-1)),
dtype=torch.uint8,
out=img_tensor[:, 0:3],
)
else:
temp_tensor = self._image_tensor[0:3] * torch.tensor([factor * 1.0], dtype=torch.float16)
torch.clamp(temp_tensor, 0, 255, out=temp_tensor)
img_tensor[0:3] = temp_tensor[:]
return Image(img_tensor)
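
The rewritten brightness path computes clamp(pixel * factor, 0, 255) by hand in a float16 buffer instead of calling func2.adjust_brightness, and for RGBA images it touches only the RGB planes while copying alpha over unchanged. A minimal sketch of the non-alpha branch, assuming factor >= 0:

    import torch

    rgb = torch.tensor([[[200]], [[100]], [[50]]], dtype=torch.uint8)  # (3, 1, 1) RGB pixel
    factor = 1.5
    temp = rgb * torch.tensor([factor], dtype=torch.float16)  # uint8 * float16 -> float16
    torch.clamp(temp, 0, 255, out=temp)                       # saturate in place
    print(temp.to(torch.uint8).flatten())                     # 255 (clamped from 300), 150, 75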

def add_noise(self, standard_deviation: float) -> Image:
"""
@@ -595,9 +614,12 @@ def add_noise(self, standard_deviation: float) -> Image:
_init_default_device()

_check_add_noise_errors(standard_deviation)
return Image(
self._image_tensor + torch.normal(0, standard_deviation, self._image_tensor.size()).to(_get_device()) * 255,
)
float_tensor = torch.empty(self._image_tensor.size(), dtype=torch.float16)
torch.normal(0, standard_deviation, self._image_tensor.size(), out=float_tensor)
float_tensor *= 255
float_tensor += self._image_tensor
torch.clamp(float_tensor, 0, 255, out=float_tensor)
return Image(float_tensor.to(torch.uint8))
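
add_noise now allocates a single float16 scratch buffer and fills and updates it with out= and in-place ops, where the old code built image + normal(...) * 255 out of several intermediate tensors. The same pattern, sketched (float16 normal sampling may require a recent PyTorch build on CPU):

    import torch

    img = torch.full((3, 4, 4), 128, dtype=torch.uint8)
    noise = torch.empty(img.size(), dtype=torch.float16)  # the one scratch buffer
    torch.normal(0, 0.05, img.size(), out=noise)          # fill in place
    noise *= 255                                          # scale in place
    noise += img                                          # add the image in place
    torch.clamp(noise, 0, 255, out=noise)
    result = noise.to(torch.uint8)                        # single cast back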

def adjust_contrast(self, factor: float) -> Image:
"""
@@ -624,22 +646,26 @@
If factor is smaller than 0.
"""
import torch
from torchvision.transforms.v2 import functional as func2

_init_default_device()

_check_adjust_contrast_errors_and_warnings(factor, plural=False)

factor *= 1.0
adjusted_factor = (1 - factor) / factor
gray_tensor = self.convert_to_grayscale()._image_tensor[0]
mean = torch.mean(gray_tensor, dim=(-2, -1), dtype=torch.float16)
del gray_tensor
mean *= torch.tensor(adjusted_factor, dtype=torch.float16)
tensor = mean.repeat(min(self.channel, 3), self._image_tensor.size(dim=-2), self._image_tensor.size(dim=-1))
tensor += self._image_tensor[0 : min(self.channel, 3)]
tensor *= factor
torch.clamp(tensor, 0, 255, out=tensor)

if self.channel == 4:
return Image(
torch.cat(
[
func2.adjust_contrast(self._image_tensor[0:3], factor * 1.0),
self._image_tensor[3].unsqueeze(dim=0),
],
),
)
return Image(torch.cat([tensor.to(torch.uint8), self._image_tensor[3:4]], dim=0))
else:
return Image(func2.adjust_contrast(self._image_tensor, factor * 1.0))
return Image(tensor.to(torch.uint8))
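
The manual contrast path is the textbook affine map out = mean + factor * (pixel - mean), with mean taken over the grayscale image, factored so the per-pixel work on the float16 buffer is one add and one multiply — that is where adjusted_factor = (1 - factor) / factor comes from. Note that, unlike adjust_color_balance below, this branch has no factor == 0 short-circuit, so the division would raise for that input. Checking the factoring numerically:

    # out = mean + factor * (pixel - mean)                   # textbook form
    #     = factor * (pixel + mean * (1 - factor) / factor)  # factored form used above
    p, m, f = 100.0, 60.0, 2.0
    assert m + f * (p - m) == f * (p + m * (1 - f) / f) == 140.0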

def adjust_color_balance(self, factor: float) -> Image:
"""
@@ -665,10 +691,20 @@
OutOfBoundsError
If factor is smaller than 0.
"""
import torch

_check_adjust_color_balance_errors_and_warnings(factor, self.channel, plural=False)
return Image(
self.convert_to_grayscale()._image_tensor * (1.0 - factor * 1.0) + self._image_tensor * (factor * 1.0),
)

factor *= 1.0
if factor == 0:
return self.convert_to_grayscale()
else:
adjusted_factor = (1 - factor) / factor
tensor = self.convert_to_grayscale()._image_tensor * torch.tensor(adjusted_factor, dtype=torch.float16)
tensor += self._image_tensor
tensor *= factor
torch.clamp(tensor, 0, 255, out=tensor)
return Image(tensor.to(torch.uint8))
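
adjust_color_balance applies the same factoring trick to the blend out = gray * (1 - factor) + pixel * factor, and here factor == 0 is short-circuited to plain grayscale before the (1 - factor) / factor division could be reached. The identity, checked numerically:

    # out = gray * (1 - factor) + pixel * factor
    #     = factor * (gray * (1 - factor) / factor + pixel)
    g, p, f = 80.0, 120.0, 0.5
    assert g * (1 - f) + p * f == f * (g * (1 - f) / f + p) == 100.0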

def blur(self, radius: int) -> Image:
"""
@@ -692,12 +728,30 @@
OutOfBoundsError
If radius is smaller than 0 or equal or greater than the smaller size of the image.
"""
from torchvision.transforms.v2 import functional as func2
import torch

_init_default_device()

float_dtype = torch.float32 if _get_device() != torch.device("cuda") else torch.float16

_check_blur_errors_and_warnings(radius, min(self.width, self.height), plural=False)
return Image(func2.gaussian_blur(self._image_tensor, [radius * 2 + 1, radius * 2 + 1]))

kernel = torch.full(
(self._image_tensor.size(dim=-3), 1, radius * 2 + 1, radius * 2 + 1),
1 / (radius * 2 + 1) ** 2,
dtype=float_dtype,
)
tensor = torch.nn.functional.conv2d(
torch.nn.functional.pad(
self._image_tensor.to(float_dtype),
(radius, radius, radius, radius),
mode="replicate",
),
kernel,
padding="valid",
groups=self._image_tensor.size(dim=-3),
).to(torch.uint8)
return Image(tensor)
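
The new blur builds a uniform (2r+1) x (2r+1) kernel and runs it as a depthwise convolution (groups = channel count, replicate padding). Note this is a box/mean filter, not the Gaussian weighting of the func2.gaussian_blur call it replaces; float32 is kept off CUDA, where half-precision convolutions are slow or unsupported. A self-contained sketch of the depthwise box blur (unbatched 3D conv input needs a recent PyTorch):

    import torch

    r, c = 2, 3                                  # blur radius, channel count
    img = torch.rand(c, 16, 16)                  # float image, just for the demo
    kernel = torch.full((c, 1, 2 * r + 1, 2 * r + 1), 1 / (2 * r + 1) ** 2)
    padded = torch.nn.functional.pad(img, (r, r, r, r), mode="replicate")
    blurred = torch.nn.functional.conv2d(padded, kernel, groups=c)  # no cross-channel mixing
    assert blurred.shape == img.shape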

def sharpen(self, factor: float) -> Image:
"""
@@ -734,7 +788,7 @@ def sharpen(self, factor: float) -> Image:
torch.cat(
[
func2.adjust_sharpness(self._image_tensor[0:3], factor * 1.0),
self._image_tensor[3].unsqueeze(dim=0),
self._image_tensor[3:4],
],
),
)
@@ -759,7 +813,7 @@ def invert_colors(self) -> Image:

if self.channel == 4:
return Image(
torch.cat([func2.invert(self._image_tensor[0:3]), self._image_tensor[3].unsqueeze(dim=0)]),
torch.cat([func2.invert(self._image_tensor[0:3]), self._image_tensor[3:4]]),
)
else:
return Image(func2.invert(self._image_tensor))
@@ -813,20 +867,19 @@ def find_edges(self) -> Image:

_init_default_device()

edges_tensor = torch.clamp(
torch.nn.functional.conv2d(
self.convert_to_grayscale()._image_tensor.float()[0].unsqueeze(dim=0),
Image._filter_edges_kernel(),
padding="same",
).squeeze(dim=1),
0,
255,
).to(torch.uint8)
edges_tensor_float16 = torch.nn.functional.conv2d(
self.convert_to_grayscale()._image_tensor.to(torch.float16)[0:1],
Image._filter_edges_kernel(),
padding="same",
)
torch.clamp(edges_tensor_float16, 0, 255, out=edges_tensor_float16)
if self.channel == 1:
return Image(edges_tensor_float16.to(torch.uint8))
edges_tensor = edges_tensor_float16.to(torch.uint8)
del edges_tensor_float16
if self.channel == 3:
return Image(edges_tensor.repeat(3, 1, 1))
elif self.channel == 4:
else: # self.channel == 4
return Image(
torch.cat([edges_tensor.repeat(3, 1, 1), self._image_tensor[3].unsqueeze(dim=0)]),
torch.cat([edges_tensor.repeat(3, 1, 1), self._image_tensor[3:4]]),
)
else:
return Image(edges_tensor)
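
find_edges now runs the cached 8-neighbour Laplacian kernel (_filter_edges_kernel, created in float16 above) over the grayscale plane in half precision, clamps in place, casts once to uint8 and only then fans the result out to 3 or 4 channels, freeing the float16 intermediate with del first. A reduced sketch of the kernel pass — half-precision convolution assumes a backend with float16 support, e.g. CUDA or a recent CPU build:

    import torch

    kernel = (
        torch.tensor([[-1.0, -1.0, -1.0], [-1.0, 8.0, -1.0], [-1.0, -1.0, -1.0]], dtype=torch.float16)
        .unsqueeze(dim=0)
        .unsqueeze(dim=0)                      # (out_channels=1, in_channels=1, 3, 3)
    )
    gray = torch.randint(0, 256, (1, 16, 16)).to(torch.float16)  # one grayscale plane
    edges = torch.nn.functional.conv2d(gray, kernel, padding="same")
    torch.clamp(edges, 0, 255, out=edges)      # uniform regions cancel to 0; edges stay bright
    result = edges.to(torch.uint8)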