Skip to content

Commit

Permalink
feature: integrates spandrel for upscaling
Browse files Browse the repository at this point in the history
  • Loading branch information
jaydrennan committed Jan 21, 2024
1 parent 1bf53e4 commit 65deb1a
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 33 deletions.
4 changes: 2 additions & 2 deletions imaginairy/api/generate_compvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _generate_single_image(
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.enhancers.upscale import upscale_image
from imaginairy.modules.midas.api import torch_image_to_depth_map
from imaginairy.samplers import SOLVER_LOOKUP
from imaginairy.samplers.editing import CFGEditingDenoiser
Expand Down Expand Up @@ -534,7 +534,7 @@ def _generate_composition_image(
result = _generate_single_image(composition_prompt, dtype=dtype)
img = result.images["generated"]
while img.width < target_width:
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.enhancers.upscale import upscale_image

img = upscale_image(img)

Expand Down
6 changes: 3 additions & 3 deletions imaginairy/api/generate_refiners.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def generate_single_image(
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.enhancers.upscale import upscale_image
from imaginairy.samplers import SolverName
from imaginairy.schema import ImagineResult
from imaginairy.utils import get_device, randn_seeded
Expand Down Expand Up @@ -587,7 +587,7 @@ def _generate_composition_image(
)
img = result.images["generated"]
while img.width < target_width:
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.enhancers.upscale import upscale_image

if prompt.fix_faces:
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
Expand All @@ -596,7 +596,7 @@ def _generate_composition_image(
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
with logging_context.timing("upscaling"):
img = upscale_image(img, ultrasharp=True)
img = upscale_image(img, upscaler_model="realsharp")

img = img.resize(
(target_width, target_height),
Expand Down
28 changes: 17 additions & 11 deletions imaginairy/cli/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@
type=float,
help="How faithful to the original should face enhancement be. 1 = best fidelity, 0 = best looking face.",
)
@click.option(
"--upscale-model",
multiple=True,
type=str,
help="Specify one or more upscale models to use.",
)
@click.command("upscale")
def upscale_cmd(image_filepaths, outdir, fix_faces, fix_faces_fidelity):
def upscale_cmd(image_filepaths, outdir, fix_faces, fix_faces_fidelity, upscale_model):
"""
Upscale an image 4x using AI.
"""
Expand All @@ -32,7 +38,7 @@ def upscale_cmd(image_filepaths, outdir, fix_faces, fix_faces_fidelity):
from tqdm import tqdm

from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.enhancers.upscale import upscale_image
from imaginairy.schema import LazyLoadingImage
from imaginairy.utils import glob_expand_paths

Expand All @@ -44,12 +50,12 @@ def upscale_cmd(image_filepaths, outdir, fix_faces, fix_faces_fidelity):
img = LazyLoadingImage(url=p)
else:
img = LazyLoadingImage(filepath=p)
logger.info(
f"Upscaling {p} from {img.width}x{img.height} to {img.width * 4}x{img.height * 4} and saving it to {savepath}"
)

img = upscale_image(img)
if fix_faces:
img = enhance_faces(img, fidelity=fix_faces_fidelity)

img.save(os.path.join(outdir, os.path.basename(p)))
for model in upscale_model:
logger.info(
f"Upscaling {p} from {img.width}x{img.height} to {img.width * 4}x{img.height * 4} and saving it to {savepath}"
)

img = upscale_image(img, model)
if fix_faces:
img = enhance_faces(img, fidelity=fix_faces_fidelity)
img.save(os.path.join(outdir, os.path.basename(p)))
46 changes: 46 additions & 0 deletions imaginairy/enhancers/upscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch
import torchvision.transforms.functional as F
from spandrel import ImageModelDescriptor, ModelLoader
from torchvision import transforms

from imaginairy.schema import LazyLoadingImage
from imaginairy.utils import get_device
from imaginairy.utils.downloads import get_cached_url_path

upscale_models = {
"ultrasharp": "https://huggingface.co/lokCX/4x-Ultrasharp/resolve/1856559b50de25116a7c07261177dd128f1f5664/4x-UltraSharp.pth",
"realesrgan": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
"HAT": "https://huggingface.co/Acly/hat/resolve/main/HAT_SRx4_ImageNet-pretrain.pth?download=true",
}


def upscale_image(img: LazyLoadingImage, upscaler_model: str = "realesrgan"):
model_path = get_cached_url_path(upscale_models[upscaler_model])
model = ModelLoader().load_from_file(model_path)

assert isinstance(model, ImageModelDescriptor)

device = get_device()
model.to(device).eval()

image_tensor = load_image(img, device)

def process(image: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
return model(image)

upscaled_img = process(image_tensor)

upscaled_img = upscaled_img.squeeze(0)
image = F.to_pil_image(upscaled_img)

return image


def load_image(img: LazyLoadingImage, device: str):

transform = transforms.ToTensor()
image_tensor = transform(img.as_pillow())

image_tensor = image_tensor.unsqueeze(0)
return image_tensor.to(device)
43 changes: 26 additions & 17 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,17 @@ colorama==0.4.6
# mkdocs-material
coverage==7.4.0
# via -r requirements-dev.in
diffusers==0.25.0
diffusers==0.25.1
# via imaginAIry (setup.py)
einops==0.7.0
# via imaginAIry (setup.py)
# via
# imaginAIry (setup.py)
# spandrel
exceptiongroup==1.2.0
# via
# anyio
# pytest
fastapi==0.108.0
fastapi==0.109.0
# via imaginAIry (setup.py)
filelock==3.13.1
# via
Expand All @@ -66,7 +68,7 @@ ftfy==6.1.3
# open-clip-torch
ghp-import==2.1.0
# via mkdocs
griffe==0.38.1
griffe==0.39.1
# via mkdocstrings-python
h11==0.14.0
# via
Expand Down Expand Up @@ -94,23 +96,23 @@ iniconfig==2.0.0
# via pytest
jaxtyping==0.2.25
# via imaginAIry (setup.py)
jinja2==3.1.2
jinja2==3.1.3
# via
# mkdocs
# mkdocs-material
# mkdocstrings
# torch
kornia==0.7.1
# via imaginAIry (setup.py)
markdown==3.5.1
markdown==3.5.2
# via
# mkdocs
# mkdocs-autorefs
# mkdocs-click
# mkdocs-material
# mkdocstrings
# pymdown-extensions
markupsafe==2.1.3
markupsafe==2.1.4
# via
# jinja2
# mkdocs
Expand All @@ -126,15 +128,15 @@ mkdocs-autorefs==0.5.0
# via mkdocstrings
mkdocs-click==0.8.1
# via -r requirements-dev.in
mkdocs-material==9.5.3
mkdocs-material==9.5.4
# via -r requirements-dev.in
mkdocs-material-extensions==1.3.1
# via mkdocs-material
mkdocstrings[python]==0.24.0
# via
# -r requirements-dev.in
# mkdocstrings-python
mkdocstrings-python==1.7.5
mkdocstrings-python==1.8.0
# via mkdocstrings
mpmath==1.3.0
# via sympy
Expand All @@ -152,11 +154,12 @@ numpy==1.24.4
# jaxtyping
# opencv-python
# scipy
# spandrel
# torchvision
# transformers
omegaconf==2.3.0
# via imaginAIry (setup.py)
open-clip-torch==2.23.0
open-clip-torch==2.24.0
# via imaginAIry (setup.py)
opencv-python==4.9.0.80
# via imaginAIry (setup.py)
Expand All @@ -183,11 +186,11 @@ platformdirs==4.1.0
# mkdocstrings
pluggy==1.3.0
# via pytest
protobuf==4.25.1
protobuf==4.25.2
# via
# imaginAIry (setup.py)
# open-clip-torch
psutil==5.9.7
psutil==5.9.8
# via imaginAIry (setup.py)
pydantic==2.5.3
# via
Expand Down Expand Up @@ -246,15 +249,16 @@ requests==2.31.0
# transformers
responses==0.24.1
# via -r requirements-dev.in
ruff==0.1.11
ruff==0.1.14
# via -r requirements-dev.in
safetensors==0.4.1
# via
# diffusers
# imaginAIry (setup.py)
# spandrel
# timm
# transformers
scipy==1.10.1
scipy==1.12.0
# via
# imaginAIry (setup.py)
# torchdiffeq
Expand All @@ -266,7 +270,9 @@ sniffio==1.3.0
# via
# anyio
# httpx
starlette==0.32.0.post1
spandrel==0.2.1
# via imaginAIry (setup.py)
starlette==0.35.1
# via fastapi
sympy==1.12
# via torch
Expand All @@ -289,6 +295,7 @@ torch==2.1.2
# imaginAIry (setup.py)
# kornia
# open-clip-torch
# spandrel
# timm
# torchdiffeq
# torchvision
Expand All @@ -298,6 +305,7 @@ torchvision==0.16.2
# via
# imaginAIry (setup.py)
# open-clip-torch
# spandrel
# timm
tqdm==4.66.1
# via
Expand All @@ -309,7 +317,7 @@ transformers==4.36.2
# via imaginAIry (setup.py)
typeguard==2.13.3
# via jaxtyping
types-pillow==10.1.0.20240106
types-pillow==10.2.0.20240111
# via -r requirements-dev.in
types-psutil==5.9.5.20240106
# via -r requirements-dev.in
Expand All @@ -326,14 +334,15 @@ typing-extensions==4.9.0
# mypy
# pydantic
# pydantic-core
# spandrel
# torch
# uvicorn
urllib3==2.1.0
# via
# requests
# responses
# types-requests
uvicorn==0.25.0
uvicorn==0.26.0
# via imaginAIry (setup.py)
watchdog==3.0.0
# via mkdocs
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def get_git_revision_hash() -> str:
"triton>=2.0.0; sys_platform!='darwin' and platform_machine!='aarch64' and sys_platform == 'linux'",
"kornia>=0.6",
"uvicorn>=0.16.0",
"spandrel>=0.1.8",
# "xformers>=0.0.22; sys_platform!='darwin' and platform_machine!='aarch64'",
],
# don't specify maximum python versions as it can cause very long dependency resolution issues as the resolver
Expand Down

0 comments on commit 65deb1a

Please sign in to comment.