From fba435feb6c0f44c7760c1914f33f51304d6196d Mon Sep 17 00:00:00 2001 From: jaydrennan Date: Fri, 26 Jan 2024 15:42:22 -0800 Subject: [PATCH] fix: reorganizing imports --- imaginairy/api/upscale.py | 18 +++++++++++++----- imaginairy/enhancers/upscale.py | 18 ++++++++++++------ imaginairy/utils/__init__.py | 5 ++--- imaginairy/utils/tile_up.py | 14 +++++++++----- 4 files changed, 36 insertions(+), 19 deletions(-) diff --git a/imaginairy/api/upscale.py b/imaginairy/api/upscale.py index 83694672..16c39fb6 100644 --- a/imaginairy/api/upscale.py +++ b/imaginairy/api/upscale.py @@ -1,18 +1,21 @@ -from PIL import Image +from typing import TYPE_CHECKING, Union from imaginairy.config import DEFAULT_UPSCALE_MODEL -from imaginairy.enhancers.upscale import upscale_image -from imaginairy.schema import LazyLoadingImage + +if TYPE_CHECKING: + from PIL import Image + + from imaginairy.schema import LazyLoadingImage def upscale( - img: LazyLoadingImage | Image.Image | str, + img: "Union[LazyLoadingImage, Image.Image, str]", upscale_model: str = DEFAULT_UPSCALE_MODEL, tile_size: int = 512, tile_pad: int = 50, repetition: int = 1, device=None, -) -> Image.Image: +) -> "Image.Image": """ Upscales an image using a specified super-resolution model. @@ -31,6 +34,11 @@ def upscale( Returns: Image.Image: The upscaled image as a PIL Image object. """ + from PIL import Image + + from imaginairy.enhancers.upscale import upscale_image + from imaginairy.schema import LazyLoadingImage + if isinstance(img, str): if img.startswith("https://"): img = LazyLoadingImage(url=img) diff --git a/imaginairy/enhancers/upscale.py b/imaginairy/enhancers/upscale.py index 2734526e..e076a34f 100644 --- a/imaginairy/enhancers/upscale.py +++ b/imaginairy/enhancers/upscale.py @@ -1,11 +1,15 @@ import logging - -from PIL import Image +from typing import TYPE_CHECKING, Union from imaginairy.config import DEFAULT_UPSCALE_MODEL -from imaginairy.schema import LazyLoadingImage from imaginairy.utils import get_device +if TYPE_CHECKING: + from PIL import Image + + from imaginairy.schema import LazyLoadingImage + + upscale_model_lookup = { # RealESRGAN "ultrasharp": "https://huggingface.co/lokCX/4x-Ultrasharp/resolve/1856559b50de25116a7c07261177dd128f1f5664/4x-UltraSharp.pth", @@ -22,13 +26,13 @@ def upscale_image( - img: LazyLoadingImage | Image.Image, + img: "Union[LazyLoadingImage, Image.Image]", upscaler_model: str = DEFAULT_UPSCALE_MODEL, tile_size: int = 512, tile_pad: int = 50, repetition: int = 1, device=None, -) -> Image.Image: +) -> "Image.Image": """ Upscales an image using a specified super-resolution model. @@ -95,12 +99,14 @@ def upscale_image( return image -def load_image(img: LazyLoadingImage | Image.Image): +def load_image(img: "Union[LazyLoadingImage, Image.Image]"): """ Converts a LazyLoadingImage or PIL Image into a PyTorch tensor. """ from torchvision import transforms + from imaginairy.schema import LazyLoadingImage + if isinstance(img, LazyLoadingImage): img = img.as_pillow() transform = transforms.ToTensor() diff --git a/imaginairy/utils/__init__.py b/imaginairy/utils/__init__.py index fc9284a9..d13b76bc 100644 --- a/imaginairy/utils/__init__.py +++ b/imaginairy/utils/__init__.py @@ -240,10 +240,9 @@ def glob_expand_paths(paths): p = os.path.expanduser(p) if os.path.exists(p) and os.path.isfile(p): expanded_paths.append(p) - elif os.path.exists(p) and os.path.isdir(p): - expanded_paths.extend(glob.glob(os.path.expanduser(p))) else: - print(f"Warning: {p} does not exist.") + expanded_paths.extend(glob.glob(os.path.expanduser(p))) + return expanded_paths diff --git a/imaginairy/utils/tile_up.py b/imaginairy/utils/tile_up.py index 50ba7b42..cc1ec779 100644 --- a/imaginairy/utils/tile_up.py +++ b/imaginairy/utils/tile_up.py @@ -1,19 +1,21 @@ import logging import math +from typing import TYPE_CHECKING -import torch -from torch import Tensor +if TYPE_CHECKING: + import torch + from torch import Tensor logger = logging.getLogger(__name__) def tile_process( - img: Tensor, + img: "Tensor", scale: int, - model: torch.nn.Module, + model: "torch.nn.Module", tile_size: int = 512, tile_pad: int = 50, -) -> Tensor: +) -> "Tensor": """ Process an image by tiling it, processing each tile, and then merging them back into one image. @@ -27,6 +29,8 @@ def tile_process( Returns: Tensor: The processed output image. """ + import torch + batch, channel, height, width = img.shape output_height = height * scale output_width = width * scale