Skip to content

Commit

Permalink
fix: reorganizing imports
Browse files Browse the repository at this point in the history
  • Loading branch information
jaydrennan committed Jan 26, 2024
1 parent 7ac3f6a commit fba435f
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 19 deletions.
18 changes: 13 additions & 5 deletions imaginairy/api/upscale.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from PIL import Image
from typing import TYPE_CHECKING, Union

from imaginairy.config import DEFAULT_UPSCALE_MODEL
from imaginairy.enhancers.upscale import upscale_image
from imaginairy.schema import LazyLoadingImage

if TYPE_CHECKING:
from PIL import Image

from imaginairy.schema import LazyLoadingImage


def upscale(
img: LazyLoadingImage | Image.Image | str,
img: "Union[LazyLoadingImage, Image.Image, str]",
upscale_model: str = DEFAULT_UPSCALE_MODEL,
tile_size: int = 512,
tile_pad: int = 50,
repetition: int = 1,
device=None,
) -> Image.Image:
) -> "Image.Image":
"""
Upscales an image using a specified super-resolution model.
Expand All @@ -31,6 +34,11 @@ def upscale(
Returns:
Image.Image: The upscaled image as a PIL Image object.
"""
from PIL import Image

from imaginairy.enhancers.upscale import upscale_image
from imaginairy.schema import LazyLoadingImage

if isinstance(img, str):
if img.startswith("https://"):
img = LazyLoadingImage(url=img)
Expand Down
18 changes: 12 additions & 6 deletions imaginairy/enhancers/upscale.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging

from PIL import Image
from typing import TYPE_CHECKING, Union

from imaginairy.config import DEFAULT_UPSCALE_MODEL
from imaginairy.schema import LazyLoadingImage
from imaginairy.utils import get_device

if TYPE_CHECKING:
from PIL import Image

from imaginairy.schema import LazyLoadingImage


upscale_model_lookup = {
# RealESRGAN
"ultrasharp": "https://huggingface.co/lokCX/4x-Ultrasharp/resolve/1856559b50de25116a7c07261177dd128f1f5664/4x-UltraSharp.pth",
Expand All @@ -22,13 +26,13 @@


def upscale_image(
img: LazyLoadingImage | Image.Image,
img: "Union[LazyLoadingImage, Image.Image]",
upscaler_model: str = DEFAULT_UPSCALE_MODEL,
tile_size: int = 512,
tile_pad: int = 50,
repetition: int = 1,
device=None,
) -> Image.Image:
) -> "Image.Image":
"""
Upscales an image using a specified super-resolution model.
Expand Down Expand Up @@ -95,12 +99,14 @@ def upscale_image(
return image


def load_image(img: LazyLoadingImage | Image.Image):
def load_image(img: "Union[LazyLoadingImage, Image.Image]"):
"""
Converts a LazyLoadingImage or PIL Image into a PyTorch tensor.
"""
from torchvision import transforms

from imaginairy.schema import LazyLoadingImage

if isinstance(img, LazyLoadingImage):
img = img.as_pillow()
transform = transforms.ToTensor()
Expand Down
5 changes: 2 additions & 3 deletions imaginairy/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,9 @@ def glob_expand_paths(paths):
p = os.path.expanduser(p)
if os.path.exists(p) and os.path.isfile(p):
expanded_paths.append(p)
elif os.path.exists(p) and os.path.isdir(p):
expanded_paths.extend(glob.glob(os.path.expanduser(p)))
else:
print(f"Warning: {p} does not exist.")
expanded_paths.extend(glob.glob(os.path.expanduser(p)))

return expanded_paths


Expand Down
14 changes: 9 additions & 5 deletions imaginairy/utils/tile_up.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import logging
import math
from typing import TYPE_CHECKING

import torch
from torch import Tensor
if TYPE_CHECKING:
import torch
from torch import Tensor

logger = logging.getLogger(__name__)


def tile_process(
img: Tensor,
img: "Tensor",
scale: int,
model: torch.nn.Module,
model: "torch.nn.Module",
tile_size: int = 512,
tile_pad: int = 50,
) -> Tensor:
) -> "Tensor":
"""
Process an image by tiling it, processing each tile, and then merging them back into one image.
Expand All @@ -27,6 +29,8 @@ def tile_process(
Returns:
Tensor: The processed output image.
"""
import torch

batch, channel, height, width = img.shape
output_height = height * scale
output_width = width * scale
Expand Down

0 comments on commit fba435f

Please sign in to comment.